-
Notifications
You must be signed in to change notification settings - Fork 510
Expand file tree
/
Copy pathmoe.py
More file actions
2157 lines (1955 loc) · 85.6 KB
/
moe.py
File metadata and controls
2157 lines (1955 loc) · 85.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MoE related Layers."""
import enum
import functools
import math
import random
from typing import Iterable, Optional, Tuple, Union
from aqt.jax.v2 import aqt_tensor as aqt
from flax import nnx
import jax
from jax import ad_checkpoint as adc
from jax.experimental import xla_metadata
from jax.sharding import NamedSharding, Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from MaxText import common_types as ctypes
from MaxText.common_types import ShardMode
from MaxText.sharding import maybe_shard_with_logical, create_sharding
from MaxText.sharding import logical_to_mesh_axes
from MaxText.layers import attentions, linears, nnx_wrappers, quantizations
from MaxText.layers.initializers import NdInitializer, default_bias_init, nd_dense_init, variable_to_logically_partitioned
from maxtext.kernels import megablox as mblx
from maxtext.utils import max_logging
from maxtext.utils import max_utils
import numpy as np
import qwix.pallas as qpl
import tokamax
set_xla_metadata = xla_metadata.set_xla_metadata
DISPATCH = "dispatch"
COMBINE = "combine"
def _sort_activations(
inputs: jax.Array,
sort_indices: jax.Array,
use_custom_vjp: bool,
) -> jax.Array:
"""Sort activations by `sort_indices`.
If `use_custom_vjp=True`, then we use a custom backward pass that
reverses the sort order. Specifically, this unsort operation is simply a sort
with `jnp.argsort(sort_indices)` as the sort indices. This is only needed in
the case where the compiler generates a less efficient backward pass op.
Note that `use_custom_vjp=True` assumes that `sort_indices` is a permutation
of `jnp.arange(inputs.shape[0])`.
Args:
inputs: `(tokens, ...)`-shaped array of input activations to sort.
sort_indices: `(tokens,)`-shaped array containing the sort order.
use_custom_vjp: Whether to use the explicit backward pass.
Returns:
`(tokens, ...)`-shaped array of input activations sorted by `sort_indices`.
"""
assert inputs.shape[0] == sort_indices.shape[0]
with jax.named_scope("sort_activations"):
if use_custom_vjp:
return _sort_activations_custom(inputs, sort_indices)
return inputs[sort_indices, ...]
@jax.custom_vjp
def _sort_activations_custom(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array:
"""Sort functions with custom vjp."""
return inputs[sort_indices, ...]
def _sort_activations_custom_fwd(inputs: jax.Array, sort_indices: jax.Array) -> tuple[jax.Array, jax.Array]:
"""Forward pass of the custom vjp for `_sort_activations()`."""
return _sort_activations_custom(inputs, sort_indices), sort_indices
def _sort_activations_custom_bwd(residuals: jax.Array, grads: jax.Array) -> tuple[jax.Array, None]:
"""Backward pass of the custom vjp for `_sort_activations()`."""
sort_indices = residuals
return _sort_activations_custom(grads, jnp.argsort(sort_indices)), None
_sort_activations_custom.defvjp(_sort_activations_custom_fwd, _sort_activations_custom_bwd)
def random_routing(rng_key, gate_logits, num_experts_per_tok):
  """Routes each token to experts drawn uniformly at random.

  Expert ids are drawn independently (with replacement) from
  `[0, num_experts)`, so the same expert can be picked more than once for a
  single token. The returned weights are the raw `gate_logits` values at the
  drawn ids; no normalization is applied here.

  Args:
    rng_key: A JAX PRNGKey for randomness.
    gate_logits: Array of shape (batch_size, sequence_length, num_experts).
      Only its shape and the values at the drawn indices are used.
    num_experts_per_tok: The number of experts to select for each token.

  Returns:
    A tuple `(top_k_weights, top_k_indices)`:
      - top_k_weights: array of shape (batch_size, sequence_length,
        num_experts_per_tok) holding `gate_logits` gathered at the selected
        expert indices.
      - top_k_indices: int32 array of shape (batch_size, sequence_length,
        num_experts_per_tok) with the selected expert ids.
  """
  bs, seq_len, num_experts = gate_logits.shape
  selected_num = bs * seq_len * num_experts_per_tok
  # Directly generate random integers in the range [0, num_experts)
  top_k_indices = jax.random.randint(
      rng_key,
      shape=(selected_num,),
      minval=0,
      maxval=num_experts,
      dtype=jnp.int32,
  )
  top_k_indices = top_k_indices.reshape(bs, seq_len, num_experts_per_tok)
  top_k_weights = jnp.take_along_axis(gate_logits, top_k_indices, axis=-1)
  # Note the order: weights first, indices second (matches `jax.lax.top_k`).
  return top_k_weights, top_k_indices
def calculate_load_balance_updates(top_k_indices, num_experts, rate):
  """Computes the per-expert bias adjustment for loss-free load balancing.

  Used in DeepSeek V3: https://arxiv.org/html/2412.19437v1.
  Implementation reference: https://arxiv.org/pdf/2408.15664.

  Each expert's update is `rate * sign(mean_load - load)`: `+rate` for experts
  assigned fewer tokens than the mean, `-rate` for experts assigned more, and
  0 for experts exactly at the mean.

  Args:
    top_k_indices: Shape (batch, sequence, top_k) expert assignments.
    num_experts: Total number of experts.
    rate: The update rate.

  Returns:
    The value to add to the expert bias, shape (num_experts,).
  """
  assignments = jnp.ravel(top_k_indices)
  load_per_expert = jnp.bincount(assignments, length=num_experts)
  mean_load = assignments.size / num_experts
  return rate * jnp.sign(mean_load - load_per_expert)
class GateLogit(nnx.Module):
  """Router layer that computes MoE gate logits.

  Contracts the input with a learned kernel, optionally applies a scoring
  activation and a learnable bias, and can also return the logits captured
  before the bias is added (used for DeepSeek routing).
  """

  def __init__(
      self,
      in_features_shape: Union[Iterable[int], int],
      out_features_shape: Union[Iterable[int], int],
      model_name: str,
      mesh: Mesh,
      rngs: nnx.Rngs,
      axis: Union[Iterable[int], int] = -1,
      weight_dtype: ctypes.DType = jnp.float32,
      dtype: ctypes.DType = jnp.float32,
      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
      kernel_axes: Tuple[Optional[str], ...] = (),
      use_bias: bool = False,
      score_func: str = "",
      quant: Optional[quantizations.AqtQuantization] = None,
      shard_mode: ShardMode = ShardMode.AUTO,
      matmul_precision: str = "default",
  ) -> None:
    """Initializes the GateLogit module.

    Args:
      in_features_shape: The shape of the input features.
      out_features_shape: The shape of the output features, typically the number of experts.
      model_name: The name of the model. Names starting with "deepseek3" make
        `__call__` also return the pre-bias logits.
      mesh: Device mesh, used to build the explicit output sharding in
        `__call__` when `shard_mode` is `ShardMode.EXPLICIT`.
      rngs: An `nnx.Rngs` object used for initializing parameters.
      axis: The input axis or axes contracted against the kernel.
      weight_dtype: The data type of the kernel (and bias) weights.
      dtype: The data type for the computation.
      kernel_init: The initializer function for the kernel weight matrix.
      kernel_axes: A tuple of logical axis names for partitioning the kernel.
        Its trailing entries are also used as the bias sharding.
      use_bias: Whether to add learnable bias in gate logit scores. When enabled,
        this bias aids expert load balancing (like in DeepSeek V3), and is not
        part of the loss calculation.
      score_func: Name of an activation (resolved with
        `linears._convert_to_activation_function`) applied to the logits before
        the bias is added. An empty string disables it.
      quant: The quantization configuration. If None, no quantization is applied.
      shard_mode: Sharding mode. `ShardMode.EXPLICIT` makes `__call__` request
        an explicit output sharding.
      matmul_precision: The precision level for the matrix multiplication.
    """
    self.in_features_shape = linears.canonicalize_tuple(in_features_shape)
    self.out_features_shape = linears.canonicalize_tuple(out_features_shape)
    self.model_name = model_name
    self.mesh = mesh
    self.axis = linears.canonicalize_tuple(axis)
    self.weight_dtype = weight_dtype
    self.dtype = dtype
    self.kernel_init = kernel_init
    self.kernel_axes = kernel_axes
    self.use_bias = use_bias
    self.score_func = score_func
    self.quant = quant
    self.shard_mode = shard_mode
    self.matmul_precision = matmul_precision

    # Parameter initialization
    # Kernel is laid out as in_features_shape + out_features_shape; the first
    # len(axis) dims are the "in" dims and the rest the "out" dims.
    kernel_shape = self.in_features_shape + self.out_features_shape
    kernel_in_axis = np.arange(len(self.axis))
    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))

    # In AQT serve mode no float kernel is stored; `__call__` substitutes zeros.
    if not quantizations.in_serve_mode(self.quant):
      self.kernel = nnx.Param(
          self.kernel_init(
              rngs.params(),
              kernel_shape,
              self.weight_dtype,
              kernel_in_axis,
              kernel_out_axis,
          ),
          sharding=self.kernel_axes,
      )

    if self.use_bias:
      # Bias covers only the output dims, so reuse the trailing kernel axes.
      bias_axes = self.kernel_axes[-len(self.out_features_shape) :]
      bias_shape = kernel_shape[-len(self.out_features_shape) :]
      self.bias = nnx.Param(
          default_bias_init(rngs.params(), bias_shape, self.weight_dtype),
          sharding=bias_axes,
      )
    else:
      self.bias = None

    if quant:
      dot_general_cls = quant.dot_general_cls(mesh_axes=kernel_axes)
      dot_general_linen = dot_general_cls()
      quant_dot_general = nnx_wrappers.ToNNX(dot_general_linen, rngs=rngs)
      # The wrapper is stored under "<LinenClassName>_0"; presumably this mirrors
      # Linen's auto-naming so variable paths line up — confirm before renaming.
      self._quant_dot_general_name = f"{type(dot_general_linen).__name__}_0"
      setattr(self, self._quant_dot_general_name, quant_dot_general)
      # Run one forward pass on zeros; presumably needed to materialize the
      # quantization wrapper's variables at construction time — TODO confirm.
      dummy_inputs = jnp.zeros((1, *self.in_features_shape), dtype=self.dtype)
      self(dummy_inputs, _initializing=True)
    else:
      self._quant_dot_general_name = None

  @property
  def quant_dot_general(self) -> nnx_wrappers.ToNNX | None:
    """The quantized dot_general wrapper, or None when quantization is off."""
    if self._quant_dot_general_name is None:
      return None
    return getattr(self, self._quant_dot_general_name)

  def __call__(self, inputs: jax.Array, _initializing: bool = False) -> Tuple[jax.Array, Optional[jax.Array]]:
    """Computes the gate logits.

    Args:
      inputs: Input activations, cast to `self.dtype`. The dims in `self.axis`
        are contracted against the kernel's leading `len(self.axis)` dims.
      _initializing: Forwarded to `linears._compute_dot_general_nnx`; the
        constructor passes True for its dummy forward pass.

    Returns:
      A tuple `(output, pre_bias_logits)`. `output` is the logits after the
      optional score activation and bias. `pre_bias_logits` is the
      score-activated output captured before the bias is added, set only when
      `score_func` is non-empty and `model_name` starts with "deepseek3";
      otherwise None.
    """
    inputs = jnp.asarray(inputs, self.dtype)
    norm_axis = linears.normalize_axes(self.axis, inputs.ndim)

    if quantizations.in_serve_mode(self.quant):
      # No float kernel exists in serve mode; a zero placeholder of the right
      # shape is passed instead (the quantized path supplies the real values —
      # assumption based on the AQT serve-mode convention).
      kernel_shape = self.in_features_shape + self.out_features_shape
      kernel = jnp.zeros(kernel_shape, dtype=self.dtype)
    else:
      kernel = self.kernel[...]
    kernel = jnp.asarray(kernel, self.dtype)

    contract_ind = tuple(range(0, len(norm_axis)))
    # In explicit sharding mode the output sharding is spelled out; this
    # three-entry spec assumes a (batch, length, experts) output — TODO confirm
    # inputs are always 3-D here.
    output_sharding = (
        create_sharding(self.mesh, ("activation_batch_no_exp", "activation_length_no_exp", None))
        if self.shard_mode == ShardMode.EXPLICIT
        else None
    )
    output = linears._compute_dot_general_nnx(
        inputs,
        kernel,
        norm_axis,
        contract_ind,
        self.matmul_precision,
        self.quant_dot_general,
        _initializing,
        out_sharding=output_sharding,
    )
    pre_bias_logits = None

    if self.score_func:
      output = linears._convert_to_activation_function(self.score_func)(output)
      if self.model_name.startswith("deepseek3"):
        # DeepSeek routing selects with biased scores but weights with these.
        pre_bias_logits = output

    if self.use_bias:
      bias = jnp.asarray(self.bias[...], self.dtype)
      output += bias
    return output, pre_bias_logits
class RoutedMoE(nnx.Module):
"""Implements a routed MoE block."""
def __init__(
    self,
    config: ctypes.Config,
    num_experts: int,
    num_experts_per_tok: int,
    mesh: jax.sharding.Mesh,
    kernel_init: attentions.NdInitializer,
    kernel_axes: Tuple[Optional[str], ...],
    rngs: nnx.Rngs,
    intermediate_dim: int = 2048,
    weight_dtype: ctypes.DType = jnp.float32,
    dtype: ctypes.DType = jnp.float32,
    quant: Optional[quantizations.AqtQuantization] = None,
) -> None:
  """Initializes the RoutedMoE module.

  Creates the router (`self.gate`), the per-expert FFN kernels `wi_0`, `wi_1`
  (each `(num_experts, emb_dim, intermediate_dim)`) and `wo`
  (`(num_experts, intermediate_dim, emb_dim)`), and per-expert biases when
  `config.mlp_bias` is set.

  Args:
    config: The main config setting.
    num_experts: Number of experts.
    num_experts_per_tok: Number of experts for each token.
    mesh: Mesh, device mesh.
    kernel_init: The initializer function used for the gate kernel and the
      expert kernels.
    kernel_axes: Logical axis names for partitioning the gate kernel. The
      expert kernels use axes chosen from the config below.
    rngs: An `nnx.Rngs` object used for initializing parameters.
    intermediate_dim: Intermediate dimension of MoE.
    weight_dtype: The data type of the kernel weights.
    dtype: The data type for the computation.
    quant: The quantization configuration, also forwarded to the gate. If None,
      no quantization is applied. In serve mode the expert kernels are zero
      placeholders.
  """
  self.config = config
  self.num_experts = num_experts
  self.num_experts_per_tok = num_experts_per_tok
  self.mesh = mesh
  self.kernel_init = kernel_init
  self.kernel_axes = kernel_axes
  self.intermediate_dim = intermediate_dim
  self.weight_dtype = weight_dtype
  self.dtype = dtype
  self.quant = quant
  self.rngs = rngs

  # Logical axis names for each dim of the expert kernels:
  # wi_* is (experts, emb, mlp) and wo is (experts, mlp, emb).
  if self.config.shard_exp_on_fsdp:
    # special sharding for dsv3
    self.wi_kernel_axes = ("embed_no_exp", None, "mlp")
    self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
  elif self.config.use_2d_fsdp_sharding:
    self.wi_kernel_axes = ("embed_no_exp", "mlp", None)
    self.wo_kernel_axes = ("embed_no_exp", "mlp", None)
  else:
    self.wi_kernel_axes = ("exp", "embed_no_exp", "mlp")
    self.wo_kernel_axes = ("exp", "mlp", "embed_no_exp")

  # Mesh axis name(s) whose product is the tensor-parallel degree
  # (see get_tensor_parallelism_size).
  if self.config.attention == "vllm_rpa":
    # vLLM uses 'model' as the tensor parallelism axis name
    self._tensor_parallelism_name = ("model", "attn_dp")
  else:
    self._tensor_parallelism_name = "tensor"

  self.gate = GateLogit(
      in_features_shape=self.config.emb_dim,
      out_features_shape=self.num_experts,
      mesh=self.mesh,
      model_name=self.config.model_name,
      dtype=self.dtype,
      weight_dtype=self.weight_dtype,
      quant=self.quant,
      kernel_init=self.kernel_init,
      kernel_axes=self.kernel_axes,
      use_bias=self.config.routed_bias,
      score_func=self.config.routed_score_func,
      matmul_precision=self.config.matmul_precision,
      shard_mode=config.shard_mode,
      rngs=self.rngs,
  )

  # pylint: disable=protected-access
  # Only the first configured MLP activation is used for the experts.
  self.activation_fn = linears._convert_to_activation_function(self.config.mlp_activations[0])

  # NOTE(review): the expert kernels are 3-D with the expert axis first, so
  # in_axis=[0] / out_axis=[1] refer to (experts, first feature dim). Confirm
  # this matches how `kernel_init` computes fan-in for these kernels.
  kernel_in_axis = np.arange(1)
  kernel_out_axis = np.arange(1, 2)

  if quantizations.in_serve_mode(self.quant):
    # During aqt convert state we delete kernel weight from params to save
    # memory. Instead they are retrieved from the tensors stored in the 'aqt'
    # collection.
    self.wi_0 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
    self.wi_1 = jnp.zeros((num_experts, self.config.emb_dim, intermediate_dim))
    self.wo = jnp.zeros((num_experts, intermediate_dim, self.config.emb_dim))
  else:
    self.wi_0 = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (num_experts, self.config.emb_dim, intermediate_dim),
            weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        sharding=self.wi_kernel_axes,
    )
    self.wi_1 = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (num_experts, self.config.emb_dim, intermediate_dim),
            weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        sharding=self.wi_kernel_axes,
    )
    self.wo = nnx.Param(
        self.kernel_init(
            self.rngs.params(),
            (self.num_experts, self.intermediate_dim, self.config.emb_dim),
            self.weight_dtype,
            kernel_in_axis,
            kernel_out_axis,
        ),
        sharding=self.wo_kernel_axes,
    )

  if self.config.mlp_bias:
    # One bias vector per expert: (experts, mlp) for wi_*, (experts, emb) for wo.
    wi_bias_axes = ("exp", "activation_mlp")
    wo_bias_axes = ("exp", "activation_embed")
    wi_bias_shape = (self.num_experts, self.intermediate_dim)
    wo_bias_shape = (self.num_experts, self.config.emb_dim)
    self.wi_0_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
        sharding=wi_bias_axes,
    )
    self.wi_1_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wi_bias_shape, self.weight_dtype),
        sharding=wi_bias_axes,
    )
    self.wo_bias = nnx.Param(
        default_bias_init(self.rngs.params(), wo_bias_shape, self.weight_dtype),
        sharding=wo_bias_axes,
    )
  else:
    self.wi_0_bias = None
    self.wi_1_bias = None
    self.wo_bias = None
def _maybe_shard_with_logical(self, inputs, logical_name):
  """Calls `maybe_shard_with_logical` with this layer's mesh and sharding settings."""
  sharding_kwargs = dict(
      mesh=self.mesh,
      shard_mode=self.config.shard_mode,
      debug_sharding=self.config.debug_sharding,
  )
  return maybe_shard_with_logical(inputs, logical_name, **sharding_kwargs)
def _logical_to_mesh_axes(self, logical_name):
  """Maps logical axis names to mesh axes using this layer's mesh and axis rules."""
  axis_rules = self.config.logical_axis_rules
  return logical_to_mesh_axes(logical_name, mesh=self.mesh, rules=axis_rules)
def get_expert_parallelism_size(self):
  """Returns the size of the "expert" mesh axis (1 if the mesh has none)."""
  mesh_shape = self.mesh.shape
  return mesh_shape.get("expert", 1)
def get_tensor_parallelism_size(self):
  """Returns the tensor-parallel degree of the mesh.

  `self._tensor_parallelism_name` is either a single mesh axis name or a tuple
  of names. For a tuple, the sizes of all listed axes are multiplied. Axes
  missing from the mesh count as 1.
  """
  axis_names = self._tensor_parallelism_name
  if not isinstance(axis_names, tuple):
    return self.mesh.shape.get(axis_names, 1)
  return math.prod(self.mesh.shape.get(name, 1) for name in axis_names)
def get_tensor_transpose_parallelism_size(self):
  """Returns the size of the "tensor_transpose" mesh axis (1 if absent)."""
  mesh_shape = self.mesh.shape
  return mesh_shape.get("tensor_transpose", 1)
def get_context_autoregressive_parallelism_size(self):
  """Returns the size of the "context_autoregressive" mesh axis (1 if absent)."""
  mesh_shape = self.mesh.shape
  return mesh_shape.get("context_autoregressive", 1)
def should_update_load_balance(self):
  """Determines if loss-free load balancing updates should be applied.

  True only when routed bias is enabled and the bias update rate is positive.
  A falsy `routed_bias` value is returned unchanged.
  """
  routed_bias = self.config.routed_bias
  if not routed_bias:
    return routed_bias
  return self.config.routed_bias_update_rate > 0.0
def get_topk(self, gate_logits, pre_bias_logits, rngs=None):
  """Selects experts per token and computes their routing weights.

  Both returned arrays have shape (batch, sequence, num_experts_per_tok).

  Random routing returns early, with the raw gathered logits as weights.
  Otherwise experts are chosen by DeepSeek routing (model names starting with
  "deepseek3") or plain top-k. The weights are then DeepSeek-scaled for the
  DEEPSEEK decoder block, left as-is for LLAMA4, and softmax-normalized for
  every other block. Finally, `norm_topk_prob` renormalizes them to sum to 1.

  Returns:
    `(top_k_weights, top_k_indices)`.

  Raises:
    ValueError: If random routing is enabled and `rngs` is None.
  """
  cfg = self.config
  if cfg.use_random_routing:
    if rngs is None:
      raise ValueError("The random key cannot be None for random routing.")
    # Reuse the 'dropout' RNG stream to ensure random routing
    return random_routing(rngs.dropout(), gate_logits, self.num_experts_per_tok)

  if cfg.model_name.startswith("deepseek3"):
    weights, indices = self.deepseek_routing(gate_logits, pre_bias_logits)
  else:
    weights, indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok)

  block = cfg.decoder_block
  if block == ctypes.DecoderBlockType.DEEPSEEK:
    weights = self.deepseek_scale_weights(weights)
  elif block != ctypes.DecoderBlockType.LLAMA4:
    weights = jax.nn.softmax(weights.astype(jnp.float32), axis=-1).astype(self.dtype)

  # Qwen3-style renormalization of the selected router weights.
  if cfg.norm_topk_prob:
    weights = weights / weights.sum(axis=-1, keepdims=True)
  return weights, indices
def deepseek_scale_weights(self, weights):
  """Scales weights according to DeepSeek's v3 reference implementation.

  With a sigmoid score function the weights are first normalized to sum to 1
  along the last axis. In all cases they are then multiplied by
  `routed_scaling_factor`.
  """
  # https://github.com/deepseek-ai/DeepSeek-V3/blob/2f7b80eecebf3d1c84da5a0d465f6639ea175012/inference/model.py#L592-L594.
  cfg = self.config
  if cfg.routed_score_func == "sigmoid":
    weights = weights / weights.sum(-1, keepdims=True)
  return weights * cfg.routed_scaling_factor
def expert_group_mask(self, gate_logits: jax.Array) -> jax.Array:
  """Returns a 0/1 mask keeping only experts in the top-scoring groups.

  Experts are split into `config.n_routing_groups` contiguous groups. Each
  group is scored by the sum of its two highest expert scores, and the best
  `config.topk_routing_group` groups are kept.

  Args:
    gate_logits: Array of shape `(batch, seq, num_experts)`.

  Returns:
    float32 array of shape `(batch, seq, num_experts)`: 1 for experts inside
    a selected group, 0 elsewhere.
  """
  n_groups = self.config.n_routing_groups
  experts_per_group = self.num_experts // n_groups
  # (batch, seq, n_groups, experts_per_group)
  grouped = jnp.reshape(gate_logits, gate_logits.shape[:-1] + (n_groups, -1))
  # Score every group by the sum of its two best experts, accumulated in float32.
  best_two, _ = jax.lax.top_k(grouped, k=2)
  group_scores = jnp.sum(jnp.astype(best_two, jnp.float32), axis=-1)
  _, chosen_groups = jax.lax.top_k(group_scores, k=self.config.topk_routing_group)
  # Collapse the one-hot rows of the chosen groups into one 0/1 flag per group.
  group_mask = jnp.sum(
      jax.nn.one_hot(chosen_groups, num_classes=n_groups, dtype=jnp.float32),
      axis=-2,
  )
  # Copy each group's flag to all of its experts, then flatten the group dims.
  per_expert = jnp.broadcast_to(group_mask[..., None], group_mask.shape + (experts_per_group,))
  return jnp.reshape(per_expert, per_expert.shape[:-2] + (self.num_experts,))
def deepseek_routing(self, gate_logits: jax.Array, pre_bias_logits: jax.Array) -> tuple[jax.Array, jax.Array]:
  """DeepSeek routing: choose experts with biased scores, weight them with unbiased ones.

  If `config.n_routing_groups` is -1, plain top-k is used. Otherwise every
  selected expert must come from one of the highest-rated expert groups (see
  `expert_group_mask`); experts outside those groups are masked to -inf.

  Args:
    gate_logits: Array of shape `(batch, seq, num_experts)` (post-bias scores)
      used for selection.
    pre_bias_logits: Array of shape `(batch, seq, num_experts)` from which the
      returned weights are gathered.

  Returns:
    - top_k_weights: `(batch, seq, num_experts_per_tok)` weights of the
      selected experts.
    - top_k_indices: `(batch, seq, num_experts_per_tok)` ids of the selected
      experts.
  """
  if self.config.n_routing_groups == -1:
    allowed = 1
  else:
    allowed = self.expert_group_mask(gate_logits)
  candidate_scores = jnp.where(allowed > 0, gate_logits, -jnp.inf)
  _, top_k_indices = jax.lax.top_k(candidate_scores, k=self.num_experts_per_tok)
  top_k_weights = jnp.take_along_axis(pre_bias_logits, top_k_indices, axis=-1)
  return top_k_weights, top_k_indices
def apply_ffn_activation(self, layer_w0, layer_w1):
  """Combines the two FFN input projections into the gated intermediate activation.

  Default: `activation_fn(w0) * w1`. For GPT-OSS, both projections are first
  clipped by `mlp_activations_limit` (w0 from above only, w1 on both sides),
  and the result is `(w0 * activation_fn(1.702 * w0)) * (w1 + 1)`.
  The result is cast to `self.dtype`.
  """
  cfg = self.config
  with jax.named_scope("ffn_act"):
    if cfg.decoder_block != ctypes.DecoderBlockType.GPT_OSS:
      gated = jnp.multiply(self.activation_fn(layer_w0), layer_w1)
    else:
      limit = cfg.mlp_activations_limit
      gate_proj = jnp.clip(layer_w0, a_min=None, a_max=limit)
      linear_proj = jnp.clip(layer_w1, a_min=-limit, a_max=limit)
      glu = jnp.multiply(gate_proj, self.activation_fn(gate_proj * 1.702))
      gated = jnp.multiply(glu, (linear_proj + 1))
    return gated.astype(self.dtype)
def permute(self, inputs, gate_logits, pre_bias_logits, use_custom_sort_vjp=True, rngs=None, roll_to_expert_id=None):
  """Permute tokens to group by expert to fit gmm call.

  Every token is replicated once per selected expert, and the copies are
  sorted by expert id so that each expert's rows are contiguous.

  Args:
    inputs: `(batch, sequence, emb)` activations.
    gate_logits: Router logits passed to `get_topk`. They are also softmaxed
      for the load-balance loss.
    pre_bias_logits: Pre-bias router logits passed to `get_topk`.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.
    rngs: RNG streams, required only for random routing.
    roll_to_expert_id: If given, expert ids are shifted by
      `-roll_to_expert_id` modulo `num_experts` before sorting and counting.

  Returns:
    A 7-tuple:
      sorted_inputs: `(batch * sequence * num_experts_per_tok, emb)` rows
        grouped by expert, cast to `self.dtype`.
      sorted_selected_experts: The sort permutation; pass it to `unpermute`.
      weights: Router weights from `get_topk`.
      group_size: `(num_experts,)` number of rows per expert.
      sorted_experts: Expert id of each row of `sorted_inputs`.
      lb_loss: Load-balance loss, or None when its weight is 0.
      bias_updates: `(num_experts,)` routed-bias update, or None when disabled.
  """
  # reshape inputs (batch, sequence, emb) to (batch * sequence, emb)
  inputs_shape = inputs.shape
  bsz_times_seq_len = inputs_shape[0] * inputs_shape[1]
  inputs_2d = jnp.reshape(inputs, (bsz_times_seq_len, inputs_shape[2]))
  weights, selected_experts = self.get_topk(gate_logits, pre_bias_logits, rngs)

  lb_loss = None
  if self.config.load_balance_loss_weight > 0.0:
    softmax_probs = jax.nn.softmax(gate_logits.astype(jnp.float32), axis=-1).astype(self.dtype)
    lb_loss = self.load_balance_loss(selected_experts, softmax_probs)

  if self.should_update_load_balance():
    # Size the update by this layer's expert count so it always matches the
    # gate bias (created with shape (self.num_experts,)) and the bincount below.
    bias_updates = calculate_load_balance_updates(
        selected_experts, self.num_experts, self.config.routed_bias_update_rate
    )
  else:
    bias_updates = None

  if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
    # weights will be of shape (batch_size, seq_len, num_experts_per_tok)
    router_scores = jax.nn.sigmoid(weights.astype(jnp.float32))  # weights are top_k_weights here
    # Squeeze router_scores to (batch_size * seq_len, num_experts_per_tok).
    # This multiply broadcasts over emb only when num_experts_per_tok == 1;
    # presumably Llama4 always routes top-1 — confirm.
    inputs_2d = inputs_2d * router_scores.reshape(bsz_times_seq_len, -1)

  # Flattened (token, k) order matches the row order of the jnp.repeat below:
  # row t * num_experts_per_tok + j is token t's j-th chosen expert.
  flatten_selected_experts = jnp.ravel(selected_experts)
  if roll_to_expert_id is not None:
    flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % self.num_experts
  sorted_selected_experts = jnp.argsort(flatten_selected_experts)
  # sort inputs for number of selected experts
  replicated_inputs_2d = jnp.repeat(inputs_2d, self.num_experts_per_tok, axis=0)
  sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts, use_custom_sort_vjp).astype(
      self.dtype
  )
  group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts)
  # Return the experts for each sorted input.
  expert_indices = jnp.arange(self.num_experts)
  sorted_experts = jnp.repeat(
      expert_indices,
      repeats=group_size,
      total_repeat_length=flatten_selected_experts.shape[0],
  )
  return (
      sorted_inputs,
      sorted_selected_experts,
      weights,
      group_size,
      sorted_experts,
      lb_loss,
      bias_updates,
  )
def unpermute(
    self,
    intermediate,
    sorted_selected_experts,
    weights,
    batch_size,
    sequence_length,
    use_custom_sort_vjp=True,
):
  """Restores token order and sums each token's expert outputs by router weight.

  Args:
    intermediate: Expert outputs in the expert-sorted row order from `permute`.
    sorted_selected_experts: The sort permutation returned by `permute`.
    weights: Router weights, reshaped to `(-1, num_experts_per_tok)`.
    batch_size: Batch size of the output.
    sequence_length: Sequence length of the output.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.

  Returns:
    `(batch_size, sequence_length, -1)` combined output, cast to `self.dtype`.
    For LLAMA4, every selected expert gets weight 1. With
    `float32_weight_sum`, the weighted sum is computed in float32.
  """
  inverse_order = jnp.argsort(sorted_selected_experts)
  restored = _sort_activations(intermediate, inverse_order, use_custom_sort_vjp)
  top_k = self.num_experts_per_tok
  combine_weights = jnp.reshape(weights, (-1, top_k))
  per_token_outputs = jnp.reshape(restored, (combine_weights.shape[0], top_k, -1))

  with jax.named_scope("weight_sum"):
    precision = jax.lax.Precision(self.config.matmul_precision)
    if self.config.decoder_block == ctypes.DecoderBlockType.LLAMA4:
      # For Llama4, combine using weights of 1 for selected experts
      combine_weights = jnp.ones_like(combine_weights)
    if self.config.float32_weight_sum:
      per_token_outputs = per_token_outputs.astype(jnp.float32)
      combine_weights = combine_weights.astype(jnp.float32)
    output = jnp.einsum(
        "BKE,BK -> BE",
        per_token_outputs,
        combine_weights,
        precision=precision,
    )
  return output.reshape(batch_size, sequence_length, -1).astype(self.dtype)
@staticmethod
def local_permute(
    inputs,
    global_group_sizes,
    local_expert_size,
    shard_index,
    is_offset=False,
    global_sorted_experts=None,
    use_custom_sort_vjp=True,
):
  """Orders the tokens held by one expert shard by local expert id.

  Args:
    inputs: `[tokens, emb_dim]` activations available on this shard.
    global_group_sizes: `[num_batch_shards, num_experts]` token-assignment
      counts per global expert, one row per batch shard.
    local_expert_size: Number of experts owned by each expert shard.
    shard_index: Index of this expert shard, in `[0, num_expert_parallelism)`.
    is_offset: If True, `inputs` are taken to be ordered by global expert id
      (e.g. by `permute()`), with `global_sorted_experts` giving each row's
      global id. Rows owned by other shards are moved to the end. If False,
      the code assumes `inputs` are the contributions received from each batch
      shard, placed back to back, each already ordered by local expert id.
    global_sorted_experts: Global expert id of every row of `inputs`. Only
      read when `is_offset` is True.
    use_custom_sort_vjp: Forwarded to `_sort_activations`.

  Returns:
    A tuple `(sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids)`:
      sorted_inputs: `inputs` permuted by local expert id.
      sorted_indices: The permutation applied (argsort of the per-row local ids).
      local_group_size: `[local_expert_size]` token count per local expert,
        summed over batch shards.
      sorted_experts_ids: Local expert id of each row of `sorted_inputs`.
        When `is_offset` is True, the value `local_expert_size` marks rows that
        do not belong to this shard.
  """
  first_local_expert = shard_index * local_expert_size
  # [num_batch_shards, local_expert_size]: this shard's columns of the counts.
  counts_per_batch_shard = jax.lax.dynamic_slice_in_dim(
      global_group_sizes,
      first_local_expert,
      local_expert_size,
      axis=1,
  )
  # Every batch shard sends its share here, so totals sum over batch shards.
  local_group_size = jnp.sum(counts_per_batch_shard, axis=0)

  if is_offset:
    owner_shard = jnp.floor_divide(global_sorted_experts, local_expert_size)
    # Rows owned elsewhere get the out-of-range id `local_expert_size`, so the
    # ascending argsort below places them after every local row.
    expert_indices = jnp.where(
        owner_shard == shard_index,
        jnp.mod(global_sorted_experts, local_expert_size),
        local_expert_size,
    )
  else:
    flat_counts = counts_per_batch_shard.reshape(-1)
    # Local expert id of each (batch shard, local expert) segment, in layout order.
    segment_ids = jnp.mod(jnp.arange(flat_counts.shape[0]), local_expert_size)
    expert_indices = jnp.repeat(segment_ids, flat_counts, total_repeat_length=inputs.shape[0])

  sorted_indices = jnp.argsort(expert_indices)
  sorted_inputs = _sort_activations(inputs, sorted_indices, use_custom_sort_vjp)
  sorted_experts_ids = expert_indices[sorted_indices]
  return sorted_inputs, sorted_indices, local_group_size, sorted_experts_ids
@staticmethod
def get_all_to_all_params(
all_shards_group_sizes,
shard_id,
num_expert_parallelism,
is_batch_sharded=True,
):
"""Generates input offsets, send sizes, output offsets, and receive sizes used for ragged_all_to_all."""
class TransformStrategy(enum.Enum):
INPUT_OFFSET = enum.auto()
SEND_SIZE = enum.auto()
OUTPUT_OFFSET = enum.auto()
RECV_SIZE = enum.auto()
def transform_array(input_array, shard_id, strategy, is_batch_sharded):
"""Transforms the input array based on the specified strategy."""
# Prepares it for the usage with `ragged_all_to_all` API. The
# transformation determines how data is sent and received between shards.
if is_batch_sharded:
if strategy == TransformStrategy.INPUT_OFFSET:
# Index of input array for the send
local_array = input_array[shard_id]
return jnp.concatenate((jnp.array([0]), jnp.cumsum(local_array)[:-1]))
elif strategy == TransformStrategy.SEND_SIZE:
# Size of input array for the send
return input_array[shard_id]
elif strategy == TransformStrategy.OUTPUT_OFFSET:
# Received index in the target output
zero_row = jnp.zeros((1,) + input_array.shape[1:], dtype=input_array.dtype)
array_with_zeros = jnp.concatenate((zero_row, input_array), axis=0)
cumulated_array = jnp.cumsum(array_with_zeros, axis=0, dtype=input_array.dtype)
return cumulated_array[shard_id]
elif strategy == TransformStrategy.RECV_SIZE:
# Received size in the target output
return input_array[:, shard_id]
else:
raise ValueError(f"Unknown transform array strategy: {strategy}")
# If the batch is unsharded then we send the same data slice to all other
# shards. We also assume each shard will have the local processed inputs
# sorted to start from index 0. Finally, len(input_array.shape) == 1 since
# there is only one batch shard.
else:
if strategy == TransformStrategy.INPUT_OFFSET:
# The data on each shard always starts at 0.
return jnp.zeros(num_expert_parallelism, dtype=input_array.dtype)
elif strategy == TransformStrategy.SEND_SIZE:
# The send amount is always the amount of data the current expert
# shard needs to process.
return jnp.repeat(input_array[shard_id], num_expert_parallelism)
elif strategy == TransformStrategy.OUTPUT_OFFSET:
# The offset in each shard will just be the start of the group which
# that shard is responsible for.
output_offset = jnp.concatenate((jnp.array([0]), jnp.cumsum(input_array[:-1])))[shard_id]
return jnp.repeat(output_offset, num_expert_parallelism)
# The amount that each shard receives from all other shards is
# equivalent to the group sizes (aka input_array).
elif strategy == TransformStrategy.RECV_SIZE:
# Received size in the target output
return input_array
else:
raise ValueError(f"Unknown transform array strategy: {strategy}")
input_offsets = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.INPUT_OFFSET,
is_batch_sharded,
)
send_sizes = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.SEND_SIZE,
is_batch_sharded,
)
output_offsets = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.OUTPUT_OFFSET,
is_batch_sharded,
)
recv_sizes = transform_array(
all_shards_group_sizes,
shard_id,
TransformStrategy.RECV_SIZE,
is_batch_sharded,
)
return input_offsets, send_sizes, output_offsets, recv_sizes
def transform_bias(self, experts_index, *biases):
  """Indexes every supplied bias tensor with the same expert index.

  Args:
    experts_index: Index applied to each bias (the chosen experts).
    *biases: Any number of indexable bias tensors.

  Returns:
    A tuple holding `bias[experts_index]` for each bias, in argument order.
  """
  picked = [bias_tensor[experts_index] for bias_tensor in biases]
  return tuple(picked)
def sparse_matmul(
self,
inputs,
gate_logits,
pre_bias_logits,
w0_kernel,
w1_kernel,
wo_kernel,
w0_bias,
w1_bias,
wo_bias,
):
"""Perform sparse matrix multiplication of inputs and Experts."""
def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
  """Runs a grouped matmul of token rows against per-expert weights.

  The backend is chosen from config:
    * `tokamax.ragged_dot` when `use_tokamax_gmm` is set without quantization;
    * `mblx.gmm` for quantized tokamax, or when `megablox` is set;
    * `jax.lax.ragged_dot` otherwise.
  Rows of `inputs` are zero-padded up to a multiple of
  `config.wi_tile_fwd_batch_seq`, and the padding is sliced off the result.

  Args:
    inputs: Token activations; `inputs.shape[1]` is the contracting dimension.
      Presumably rows are already ordered to match `group_sizes` -- TODO
      confirm against the caller.
    kernel: Expert weights (`kernel.shape[2]` is the output dimension); may be
      an `aqt.QTensor`.
    tiling: Tile sizes `(m, k, n)` passed to the backend.
    group_sizes: Per-group row counts forwarded to the backend.
    expert_assignments: One expert id per row of `inputs`; on the ragged_dot
      path it selects the per-row scales of an `aqt.QTensor` kernel.
    weight_gather_axes: Forwarded to `mblx.gmm`.

  Returns:
    The matmul output, with as many rows as the unpadded `inputs`.

  Raises:
    ValueError: If `inputs` and `expert_assignments` disagree on row count, or
      if an `aqt.QTensor` kernel on the ragged_dot path has a bias, a sparsity
      mask, or more than one scale.
  """
  pad_length = self.config.wi_tile_fwd_batch_seq
  hs_shape = inputs.shape
  if inputs.shape[0] != expert_assignments.shape[0]:
    raise ValueError("The number of input tokens must match the number of expert" " assignments!")
  # Pad rows up to a multiple of pad_length (the 1st dimension of the tiling
  # used by the gmm call); the extra rows are sliced off before returning.
  padding_amount = 0
  if hs_shape[0] % pad_length:
    padding_amount = pad_length - hs_shape[0] % pad_length
    inputs = jax.lax.pad(inputs, jnp.array(0.0, dtype=inputs.dtype), [(0, padding_amount, 0), (0, 0, 0)])

  inputs = inputs.astype(self.dtype)
  kernel = kernel.astype(self.dtype)

  lhs_quantize_dtype, rhs_quantize_dtype = None, None
  if self.quant is not None:
    quant_dg = self.quant.quant_dg
    lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype()
    rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype()

  m, k, n = inputs.shape[0], inputs.shape[1], kernel.shape[2]
  use_tokamax = self.config.use_tokamax_gmm
  if not self.config.megablox and not use_tokamax:
    # This is the jax.lax.ragged_dot path: keep tiles no larger than the
    # (padded) problem size.
    tiling = (min(tiling[0], m), min(tiling[1], k), min(tiling[2], n))

  if use_tokamax and not self.config.quantization:
    output = tokamax.ragged_dot(
        lhs=inputs,
        rhs=kernel,
        group_sizes=group_sizes,
        precision=jax.lax.Precision.DEFAULT,
        preferred_element_type=self.dtype,
        implementation="mosaic",
    )
  elif use_tokamax or self.config.megablox:
    # Quantized tokamax and plain Megablox share one mblx.gmm call.
    output = mblx.gmm(
        lhs=inputs,
        rhs=kernel,
        group_sizes=group_sizes,
        preferred_element_type=self.dtype,
        tiling=tiling,
        lhs_quantize_dtype=lhs_quantize_dtype,
        rhs_quantize_dtype=rhs_quantize_dtype,
        use_qwix_quantization=self.config.use_qwix_quantization,
        use_tokamax_backend=use_tokamax,
        weight_gather_axes=weight_gather_axes,
    )
  else:
    rhs_inputs = kernel
    if isinstance(kernel, aqt.QTensor):
      if kernel.bias or kernel.sparsity_mask or len(kernel.scale) > 1:
        raise ValueError("Unsupported usecase for ragged_dot with quantized kernel.")
      # Multiply by the raw quantized values here; the scale is applied to
      # the output below.
      rhs_inputs = kernel.qvalue
    if self.config.use_qwix_quantization:
      # Use full contraction for QWIX quantization to allow quantization
      # fusion (max reduce over contracting dimension).
      tiling = (tiling[0], k, tiling[2])
    # Compare the device's platform string: a jax Device object never compares
    # equal to a plain str, so comparing the device itself is always False.
    is_tpu = self.mesh.devices.flat[0].platform == "tpu"
    # TPU needs random mosaic_fusion_group; GPU/CPU needs deterministic ID for autotuner sync
    mosaic_group_id = f"{random.randint(0, 1000000000)}" if is_tpu else "0"
    with set_xla_metadata(
        ragged_dot_tiling=",".join([str(t) for t in tiling]),
        mosaic_fusion_group=mosaic_group_id,
    ):
      output = jax.lax.ragged_dot(
          lhs=inputs,
          rhs=rhs_inputs,
          group_sizes=group_sizes,
          preferred_element_type=self.dtype,
      )
    if isinstance(kernel, aqt.QTensor):
      # Multiply outputs by the kernel scale of each row's assigned expert.
      scales = jnp.take(kernel.scale[0].squeeze(), indices=expert_assignments, axis=0)
      if padding_amount > 0:
        scales = jax.lax.pad(
            scales,
            jnp.array(0.0, dtype=scales.dtype),
            [(0, padding_amount, 0), (0, 0, 0)],
        )
      output *= scales
  if padding_amount > 0:
    output = output[: hs_shape[0]]
  return output
# Currently, we support data, tensor, and expert parallelism with Megablox.
# We all gather the input activations over tensor parallelism to follow
# https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf.
# Check if the batch should be sharded by expert and whether the batch_size
# supports this. For example, for interleaved inference, prefill always has
# batch_size=1 while decode can have batch_size > 1.
try:
is_batch_sharded_by_expert = (
"expert"
in tuple(
filter(
lambda tup: tup[0] == "activation_batch",
self.config.logical_axis_rules,
)
)[
0
][1]
)
except: # pylint: disable=bare-except