
[JAX] Add an MoE Block (Layer) that composes router, permutation, groupedGEMM and communication #2912

Open

tdophung wants to merge 8 commits into NVIDIA:main from tdophung:teddy/moe_block

Conversation

@tdophung (Collaborator)

Description

Most of the MoE building-block integration work has so far been deeply coupled with Maxtext development. This MoE block isolates that work from Maxtext and leaves more room for experimentation with different orderings of computation and communication, as well as with different sharding rules and data layouts.

This first iteration starts with ring-of-experts EP, sharding on the batch dimension for FSDP, cuBLASLt groupedGEMM, an option to choose between Triton-kernel and pure-JAX permutation implementations, and a fused router kernel.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Create new MoEBlock layer (see the usage sketch below)
  • Add tests for MoEBlock
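
Below is a hedged sketch of how the new block might be invoked. It assumes a Flax Linen module named MoEBlock (exported from transformer_engine.jax.flax per this PR) whose constructor takes the attributes referenced later in the review (num_experts, num_experts_per_tok, aux_loss_coeff); the exact field names and return signature may differ.

import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import MoEBlock

# Hypothetical configuration; field names follow the attributes referenced
# in the review below and may not match the final API exactly.
moe = MoEBlock(
    num_experts=8,
    num_experts_per_tok=2,
    aux_loss_coeff=0.01,
)

x = jnp.ones((2, 128, 1024), dtype=jnp.bfloat16)  # [B, S, H]
params = moe.init(jax.random.PRNGKey(0), x)
# Per the sequence diagram below, the forward produces output [B, S, H]
# plus the auxiliary load-balancing loss.
out, aux_loss = moe.apply(params, x)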

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung tdophung marked this pull request as ready for review May 5, 2026 21:47
@greptile-apps (Contributor)

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR introduces MoEBlock, a new self-contained Flax Linen layer that wires together TE's fused router, two permutation backends (pure-JAX argsort and Triton), grouped-GEMM expert FFN, and optional ragged-all-to-all expert parallelism via jax.shard_map. It also adds the unfused_token_dispatch/unfused_token_combine family, UnfusedPermState, and the compute_ragged_all_to_all_params helpers needed for the A2A EP path.

  • New MoEBlock: Decomposes the MoE forward into _route, _global_permute, _expert_ffn, _global_combine; the _forward_a2a_ep variant wraps the body in shard_map and inserts forward/reverse ragged_all_to_all + local permute around the FFN.
  • Unfused pure-JAX backend: unfused_token_dispatch/unfused_token_combine use jnp.argsort-based gather with a custom VJP (see the first sketch after this list); align_size > 0 is implemented but gated behind xfail tests.
  • EP helpers: compute_ragged_all_to_all_params and compute_reverse_ragged_all_to_all_params translate the gathered [num_ep, num_experts] token-count matrix into the four ragged_all_to_all offset/size arrays (see the second sketch below).
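
The argsort-based dispatch/combine used by the pure-JAX backend can be summarized with a minimal sketch. It assumes top-k expert indices and probabilities are already computed; all names are illustrative and do not reflect TE's actual API:

import jax.numpy as jnp

def dispatch_sketch(x, expert_ids, num_experts):
    """x: [T, H] tokens; expert_ids: [T, k] top-k expert index per token."""
    T, k = expert_ids.shape
    flat_experts = expert_ids.reshape(-1)        # [T*k], token-major order
    sort_idx = jnp.argsort(flat_experts)         # stable sort groups tokens by expert
    x_rep = jnp.repeat(x, k, axis=0)             # one copy per selected expert
    sorted_x = x_rep[sort_idx]                   # [T*k, H], expert-contiguous
    group_sizes = jnp.bincount(flat_experts, length=num_experts)  # [E]
    return sorted_x, group_sizes, sort_idx

def combine_sketch(expert_out, sort_idx, probs):
    """expert_out: [T*k, H] expert FFN outputs; probs: [T, k] routing weights."""
    inv = jnp.argsort(sort_idx)                  # inverse permutation
    unsorted = expert_out[inv].reshape(*probs.shape, -1)  # [T, k, H]
    return jnp.sum(unsorted * probs[..., None], axis=1)   # [T, H]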
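Similarly, the translation of the gathered token-count matrix into ragged_all_to_all parameters can be sketched as follows, assuming experts are contiguously sharded across EP ranks and a sender-major receive-buffer layout (simplifying assumptions; the real helpers may differ):

import jax.numpy as jnp

def ragged_a2a_params_sketch(counts, rank, num_ep):
    """counts: [num_ep, num_experts]; counts[s, x] = tokens rank s routes to expert x.
    Assumes expert x lives on rank x // (num_experts // num_ep)."""
    # sends[s, d] = total tokens rank s sends to rank d
    sends = counts.reshape(num_ep, num_ep, -1).sum(axis=-1)
    send_sizes = sends[rank]                     # this rank's chunk size per peer
    recv_sizes = sends[:, rank]                  # each peer's chunk size to this rank
    # exclusive cumsum: offsets into the locally expert-sorted send buffer
    input_offsets = jnp.cumsum(send_sizes) - send_sizes
    # offset of this rank's chunk inside each receiver's buffer: chunks from
    # lower-ranked senders to the same destination are written first
    output_offsets = jnp.cumsum(sends, axis=0)[rank] - sends[rank]
    return input_offsets, send_sizes, output_offsets, recv_sizes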

Confidence Score: 4/5

Safe to merge for non-grouped-topk configs, but silently produces an incorrect auxiliary training objective for DeepSeek-style (num_groups/group_topk) models that also enable aux_loss_coeff.

The aux loss routing map uses a clean standard top-k instead of the actual grouped-topk routing, making tokens_per_expert inconsistent with real routing decisions when num_groups/group_topk are set with aux_loss_coeff > 0.

transformer_engine/jax/flax/moe.py — specifically _compute_aux_loss and the test_group_topk_deepseek test which does not exercise the aux loss path.

Important Files Changed

Filename: Overview

transformer_engine/jax/flax/moe.py: New 974-line MoEBlock implementing the no-EP and ragged-A2A-EP paths; the aux loss routing map is inconsistent with actual grouped-topk routing when num_groups/group_topk are configured.
transformer_engine/jax/permutation.py: Adds pure-JAX unfused dispatch/combine, ragged-A2A parameter helpers, and local permute utilities; logic is sound and well-documented.
tests/jax/test_moe_block.py: Good single-device coverage; align_size > 0 xfail is intentional; missing a test combining num_groups + aux_loss_coeff that would expose the tokens_per_expert mismatch.
tests/jax/test_distributed_moe_block.py: Single EP2 x FSDP2 test with gradient comparison; tolerances are wide but appropriate for bfloat16 A2A-EP.
transformer_engine/jax/flax/__init__.py: Correctly exports MoEBlock to the public flax API.

Sequence Diagram

sequenceDiagram
    participant Input as Input [B,S,H]
    participant Gate as _gate (einsum)
    participant Router as _route_topk (fused_topk)
    participant AuxLoss as _compute_aux_loss
    participant Perm as _global_permute
    participant A2A_Fwd as ragged_all_to_all (fwd)
    participant LocalPerm as local_permute_after_a2a
    participant FFN as _expert_ffn (grouped_dense x3)
    participant LocalUnperm as local_unpermute_before_a2a
    participant A2A_Rev as ragged_all_to_all (rev)
    participant Combine as _global_combine
    participant Output as Output [B,S,H]

    Input->>Gate: inputs [B,S,H]
    Gate->>Router: gate_logits [B,S,E]
    Router->>AuxLoss: logits_2d (aux branch, parallel)
    Router->>Perm: sparse_probs, routing_map
    Perm->>A2A_Fwd: sorted_inputs [T*k,H] + group_sizes [E]
    A2A_Fwd->>LocalPerm: x_recv [recv_buf, H]
    LocalPerm->>FFN: sorted_x, local_group_sizes [E_local]
    FFN->>LocalUnperm: expert_outputs [recv_buf, H]
    LocalUnperm->>A2A_Rev: x_send_back
    A2A_Rev->>Combine: y_back [T*k, H]
    Combine->>Output: output [B,S,H] + aux_loss

Reviews (3): Last reviewed commit: "address greptile comments"

Comment on lines 876 to 882
  def moe_permute(
      inp: torch.Tensor,
      routing_map: torch.Tensor,
-     num_out_tokens: int = -1,
+     num_out_tokens: int,
      max_token_num: int = -1,
      map_type: str = "mask",
  ) -> Tuple[torch.Tensor, torch.Tensor]:

P1 Making num_out_tokens a required positional argument is a breaking public API change. moe_permute and moe_permute_with_probs are both re-exported from transformer_engine.pytorch, so any downstream caller that relied on the old default of -1 (e.g. moe_permute(inp, routing_map)) will now receive a TypeError at runtime with no deprecation warning. Consider keeping the default and emitting a DeprecationWarning when the caller passes a negative value, or bumping the major version and documenting the change explicitly in the release notes.

Suggested change

Current (this PR):

def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:

Suggested:

def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
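
A minimal sketch of the deprecation path suggested above: keep the old default and warn callers who still rely on it. The permutation body is elided; only the signature handling is shown.

import warnings
from typing import Tuple

import torch

def moe_permute(
    inp: torch.Tensor,
    routing_map: torch.Tensor,
    num_out_tokens: int = -1,
    max_token_num: int = -1,
    map_type: str = "mask",
) -> Tuple[torch.Tensor, torch.Tensor]:
    if num_out_tokens < 0:
        # Preserve the old behavior for now, but signal that the implicit
        # default is going away in a future major release.
        warnings.warn(
            "Calling moe_permute without an explicit num_out_tokens is "
            "deprecated; pass the expected output token count.",
            DeprecationWarning,
            stacklevel=2,
        )
    ...  # actual permutation logic unchanged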

Comment on lines +67 to +77
from ..permutation import (
    _routing_map_to_selected_experts,
    compute_ragged_all_to_all_params,
    compute_reverse_ragged_all_to_all_params,
    local_permute_after_a2a,
    local_unpermute_before_a2a,
    token_combine,
    token_dispatch,
    unfused_token_combine,
    unfused_token_dispatch,
)

P2 Import of private function across module boundary

_routing_map_to_selected_experts is imported from permutation.py by name with a leading underscore, signalling it is a private implementation detail. If this conversion helper is needed externally, it should be renamed (drop the underscore), added to __all__, and documented. As-is, a refactor inside permutation.py can silently break moe.py without any API-violation signal.
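
A sketch of the suggested promotion to a public name, keeping the old symbol as the implementation (hypothetical; the real signature lives in permutation.py):

# transformer_engine/jax/permutation.py (sketch)
def routing_map_to_selected_experts(*args, **kwargs):
    """Public alias for the former private conversion helper."""
    return _routing_map_to_selected_experts(*args, **kwargs)

__all__ = [
    # ... existing exports ...
    "routing_map_to_selected_experts",
]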

tdophung added 6 commits May 5, 2026 16:35
Signed-off-by: tdophung <tdophung@nvidia.com>
…ody single GPU vs. multi GPU

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
…e and single device initial params in the MoEBlock. Tests should pass now

Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung force-pushed the teddy/moe_block branch from 8a838f3 to 6aeb491 on May 5, 2026 23:44
pre-commit-ci Bot and others added 2 commits May 5, 2026 23:45
Signed-off-by: tdophung <tdophung@nvidia.com>
Comment on lines +427 to +457
def _compute_aux_loss(
    self,
    logits_2d: jnp.ndarray,
) -> Optional[jnp.ndarray]:
    """Compute the MoE auxiliary load-balancing loss.

    The score-for-aux kernel has no data dependency on the main
    routing kernel, so XLA can overlap them on the GPU.

    ``logits_2d`` should be the *full* logits tensor over the global
    token batch -- under EP the caller is responsible for
    :func:`jax.lax.all_gather` ing the logits before calling this so
    the aux_loss formula
    ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])``
    sees the global ``T`` and the global ``tokens_per_expert``.
    """
    if self.aux_loss_coeff <= 0.0:
        return None
    aux_scores, aux_routing_map = fused_topk_with_score_function(
        logits_2d.astype(jnp.float32),
        topk=self.num_experts_per_tok,
        score_function=self.score_function,
        compute_aux_scores=True,
    )
    aux_tokens_per_expert = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0)
    return fused_moe_aux_loss(
        aux_scores.astype(jnp.float32),
        aux_tokens_per_expert,
        topk=self.num_experts_per_tok,
        coeff=self.aux_loss_coeff,
    )

P1 Aux loss tokens_per_expert is inconsistent with actual grouped-topk routing

When num_groups > 0 and group_topk > 0 (DeepSeek-style routing), fused_topk_with_score_function(..., compute_aux_scores=True) intentionally ignores those parameters and runs a clean standard top-k. The returned aux_routing_map therefore reflects different expert selections than the actual routing_map produced by _route_topk, causing aux_tokens_per_expert = sum(aux_routing_map, axis=0) to count a different token–expert distribution. Any user who combines num_groups > 0 + group_topk > 0 + aux_loss_coeff > 0 silently trains with a wrong auxiliary objective. The existing test_group_topk_deepseek test does not catch this because it leaves aux_loss_coeff at its default of 0.0.
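
One possible fix, sketched against the code shown above under the assumption that the actual routing_map from _route_topk can be threaded into _compute_aux_loss (illustrative, not a tested patch):

# Hypothetically pass the real routing_map into _compute_aux_loss so the
# token counts match the grouped-topk selections:
aux_scores, _ = fused_topk_with_score_function(
    logits_2d.astype(jnp.float32),
    topk=self.num_experts_per_tok,
    score_function=self.score_function,
    compute_aux_scores=True,
)
# Count tokens per expert from the *actual* routing decisions instead of
# the clean top-k map returned by the aux kernel.
aux_tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0)
return fused_moe_aux_loss(
    aux_scores.astype(jnp.float32),
    aux_tokens_per_expert,
    topk=self.num_experts_per_tok,
    coeff=self.aux_loss_coeff,
)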


Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant