[JAX] Add an MoE Block (Layer) that compounds router, permutation, groupedGEMM and communication #2912
tdophung wants to merge 8 commits into NVIDIA:main
Conversation
Greptile Summary: This PR introduces an MoE block for JAX (transformer_engine/jax/flax/moe.py) that combines the router, token permutation, grouped GEMM, and expert-parallel communication in a single layer.
Confidence Score: 4/5. Safe to merge for non-grouped-topk configs, but it silently produces an incorrect auxiliary training objective for DeepSeek-style (num_groups/group_topk) models that also enable aux_loss_coeff: the aux-loss routing map uses a clean standard top-k instead of the actual grouped-topk routing, so tokens_per_expert is inconsistent with the real routing decisions whenever num_groups/group_topk are set with aux_loss_coeff > 0. Affected code: transformer_engine/jax/flax/moe.py, specifically _compute_aux_loss, and the test_group_topk_deepseek test, which does not exercise the aux-loss path.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input [B,S,H]
    participant Gate as _gate (einsum)
    participant Router as _route_topk (fused_topk)
    participant AuxLoss as _compute_aux_loss
    participant Perm as _global_permute
    participant A2A_Fwd as ragged_all_to_all (fwd)
    participant LocalPerm as local_permute_after_a2a
    participant FFN as _expert_ffn (grouped_dense x3)
    participant LocalUnperm as local_unpermute_before_a2a
    participant A2A_Rev as ragged_all_to_all (rev)
    participant Combine as _global_combine
    participant Output as Output [B,S,H]
    Input->>Gate: inputs [B,S,H]
    Gate->>Router: gate_logits [B,S,E]
    Router->>AuxLoss: logits_2d (aux branch, parallel)
    Router->>Perm: sparse_probs, routing_map
    Perm->>A2A_Fwd: sorted_inputs [T*k,H] + group_sizes [E]
    A2A_Fwd->>LocalPerm: x_recv [recv_buf, H]
    LocalPerm->>FFN: sorted_x, local_group_sizes [E_local]
    FFN->>LocalUnperm: expert_outputs [recv_buf, H]
    LocalUnperm->>A2A_Rev: x_send_back
    A2A_Rev->>Combine: y_back [T*k, H]
    Combine->>Output: output [B,S,H] + aux_loss
```
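To make the dataflow above concrete, here is a minimal pure-JAX sketch of the same pipeline on a single device: gate, top-k routing, per-expert FFN, probability-weighted combine. It deliberately replaces the sparse dispatch path (permutation, ragged_all_to_all, grouped GEMM) with a dense equivalent, and every name in it is illustrative rather than the PR's API.

```python
import jax
import jax.numpy as jnp

def moe_forward_sketch(x, gate_w, expert_w_in, expert_w_out, k):
    B, S, H = x.shape
    E = gate_w.shape[-1]
    tokens = x.reshape(-1, H)                        # [T, H], T = B*S
    logits = tokens @ gate_w                         # [T, E] gating logits (the _gate einsum)
    probs = jax.nn.softmax(logits, axis=-1)
    topk_probs, topk_idx = jax.lax.top_k(probs, k)   # [T, k] routing decision (the top-k router)
    # Dense stand-in for dispatch/permute/grouped-GEMM/combine: build a [T, E]
    # map of routing probabilities and weight every expert's output by it.
    T = tokens.shape[0]
    routing_map = jnp.zeros((T, E)).at[jnp.arange(T)[:, None], topk_idx].set(topk_probs)
    h = jax.nn.gelu(jnp.einsum("th,ehf->tef", tokens, expert_w_in))   # per-expert up-projection
    expert_out = jnp.einsum("tef,efh->teh", h, expert_w_out)          # per-expert down-projection
    out = jnp.einsum("te,teh->th", routing_map, expert_out)           # probability-weighted combine
    return out.reshape(B, S, H)

# Toy shapes just to show the call pattern.
key = jax.random.PRNGKey(0)
B, S, H, F, E, k = 2, 4, 8, 16, 4, 2
x = jax.random.normal(key, (B, S, H))
gate_w = 0.02 * jax.random.normal(key, (H, E))
w_in = 0.02 * jax.random.normal(key, (E, H, F))
w_out = 0.02 * jax.random.normal(key, (E, F, H))
print(moe_forward_sketch(x, gate_w, w_in, w_out, k).shape)   # (2, 4, 8)
```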
Reviews (3): Last reviewed commit: "address greptile comments"
```diff
 def moe_permute(
     inp: torch.Tensor,
     routing_map: torch.Tensor,
-    num_out_tokens: int = -1,
+    num_out_tokens: int,
     max_token_num: int = -1,
     map_type: str = "mask",
 ) -> Tuple[torch.Tensor, torch.Tensor]:
```
Making num_out_tokens a required positional argument is a breaking public API change. moe_permute and moe_permute_with_probs are both re-exported from transformer_engine.pytorch, so any downstream caller that relied on the old default of -1 (e.g. moe_permute(inp, routing_map)) will now receive a TypeError at runtime with no deprecation warning. Consider keeping the default and emitting a DeprecationWarning when the caller passes a negative value (see the sketch after the suggested change below), or bumping the major version and documenting the change explicitly in the release notes.
Suggested change:

```diff
 def moe_permute(
     inp: torch.Tensor,
     routing_map: torch.Tensor,
-    num_out_tokens: int,
+    num_out_tokens: int = -1,
     max_token_num: int = -1,
     map_type: str = "mask",
 ) -> Tuple[torch.Tensor, torch.Tensor]:
```
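A minimal sketch of the first alternative (keep the default, warn on the implicit -1). Only the signature mirrors the diff above; the warning text and the fallback behaviour are assumptions, not part of this PR.

```python
import warnings
from typing import Tuple

import torch

def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    if num_out_tokens < 0:
        warnings.warn(
            "moe_permute: relying on the default num_out_tokens=-1 is deprecated; "
            "pass the expected number of output tokens explicitly.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Assumption: -1 historically meant "no token dropping", i.e. keep every
        # routed (token, expert) pair marked in the boolean routing map.
        num_out_tokens = int(routing_map.sum().item())
    # ...the existing permutation implementation continues unchanged from here...
```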
```python
from ..permutation import (
    _routing_map_to_selected_experts,
    compute_ragged_all_to_all_params,
    compute_reverse_ragged_all_to_all_params,
    local_permute_after_a2a,
    local_unpermute_before_a2a,
    token_combine,
    token_dispatch,
    unfused_token_combine,
    unfused_token_dispatch,
)
```
Import of private function across module boundary
_routing_map_to_selected_experts is imported from permutation.py by name with a leading underscore, signalling it is a private implementation detail. If this conversion helper is needed externally, it should be renamed (drop the underscore), added to __all__, and documented. As-is, a refactor inside permutation.py can silently break moe.py without any API-violation signal.
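One way to resolve this, sketched under the assumption that the helper is meant to be public: give it an underscore-free name in permutation.py, export it via __all__, and keep the old private name as an alias so internal callers are not broken. The names below mirror the import list above; everything else is illustrative.

```python
# permutation.py
__all__ = ["routing_map_to_selected_experts"]

def routing_map_to_selected_experts(routing_map):
    """Public entry point; the body is the existing (unchanged) implementation."""
    ...

# Deprecated alias so existing internal callers keep working during the transition.
_routing_map_to_selected_experts = routing_map_to_selected_experts

# moe.py
# from ..permutation import routing_map_to_selected_experts
```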
Commit messages (all signed off by tdophung <tdophung@nvidia.com>):
- …ody single GPU vs. multi GPU
- …e and single device initial params in the MoEBlock. Tests should pass now
- for more information, see https://pre-commit.ci
```python
    def _compute_aux_loss(
        self,
        logits_2d: jnp.ndarray,
    ) -> Optional[jnp.ndarray]:
        """Compute the MoE auxiliary load-balancing loss.

        The score-for-aux kernel has no data dependency on the main
        routing kernel, so XLA can overlap them on the GPU.

        ``logits_2d`` should be the *full* logits tensor over the global
        token batch -- under EP the caller is responsible for
        :func:`jax.lax.all_gather` ing the logits before calling this so
        the aux_loss formula
        ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
        sees the global ``T`` and the global ``tokens_per_expert``.
        """
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )
```
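For reference, the docstring formula restated as plain JAX. This is only an illustrative re-derivation of what the fused kernel is documented to compute, not the fused_moe_aux_loss implementation.

```python
import jax.numpy as jnp

def aux_loss_reference(probs: jnp.ndarray, tokens_per_expert: jnp.ndarray,
                       topk: int, coeff: float) -> jnp.ndarray:
    """loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t, i]) * tokens[i])."""
    T, E = probs.shape                      # global token count and expert count
    return (E * coeff / (topk * T * T)) * jnp.sum(
        jnp.sum(probs, axis=0) * tokens_per_expert
    )
```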
Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing
When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
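One possible direction for a fix, assuming _compute_aux_loss can be handed the routing_map that _route_topk actually produced (gathered across EP ranks where needed): keep the fused score computation but derive the expert counts from the real routing decisions. The extra routing_map parameter is hypothetical; the two fused calls keep the signatures shown in the diff above.

```python
    def _compute_aux_loss(self, logits_2d, routing_map=None):
        if self.aux_loss_coeff <= 0.0:
            return None
        aux_scores, aux_routing_map = fused_topk_with_score_function(
            logits_2d.astype(jnp.float32),
            topk=self.num_experts_per_tok,
            score_function=self.score_function,
            compute_aux_scores=True,
        )
        # Count tokens per expert from the actual (possibly grouped-topk) routing
        # decisions so the load-balancing target matches what is really dispatched.
        effective_map = routing_map if routing_map is not None else aux_routing_map
        aux_tokens_per_expert = jnp.sum(effective_map.astype(jnp.int32), axis=0)
        return fused_moe_aux_loss(
            aux_scores.astype(jnp.float32),
            aux_tokens_per_expert,
            topk=self.num_experts_per_tok,
            coeff=self.aux_loss_coeff,
        )
```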
Description
Most of the MoE building-block integration work has been deeply coupled with Maxtext development. This MoE block isolates that work from Maxtext and provides more room for experimentation with different orderings of computation and communication, as well as different sharding rules and data layouts.
This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt groupedGEMM, an option to choose between Triton kernels or a pure-JAX permutation implementation, and a fused router kernel.
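A hypothetical usage sketch of the new block. The module path and class name come from the summary and commits above; the constructor arguments (num_experts, num_experts_per_tok, aux_loss_coeff) follow attribute names discussed in the review, and the tuple return mirrors the sequence diagram, but none of this has been verified against the actual PR code.

```python
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax.moe import MoEBlock  # path per the summary above

# Constructor arguments are assumptions based on attribute names seen in the review.
moe = MoEBlock(
    num_experts=8,
    num_experts_per_tok=2,
    aux_loss_coeff=0.01,
    # num_groups=..., group_topk=...   # optional DeepSeek-style grouped routing
)

x = jnp.zeros((2, 128, 1024))                 # [B, S, H]
params = moe.init(jax.random.PRNGKey(0), x)   # standard Flax init
y, aux_loss = moe.apply(params, x)            # output [B, S, H] plus the aux loss
```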
Fixes # (issue)
Type of change
Changes
Checklist: