diff --git a/tests/jax/test_distributed_moe_block.py b/tests/jax/test_distributed_moe_block.py new file mode 100644 index 0000000000..8f08889953 --- /dev/null +++ b/tests/jax/test_distributed_moe_block.py @@ -0,0 +1,181 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Distributed tests for ``transformer_engine.jax.flax.MoEBlock``.""" + +import sys + +import flax.linen as nn +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, PartitionSpec + +from utils import assert_allclose, is_devices_enough + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax import MeshResource, autocast + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MeshResource = MeshResource + mod.autocast = autocast + mod.MoEBlock = MoEBlock + yield + + +DTYPE = jnp.bfloat16 +# Must be divisible by ep*fsdp = 4 so the batch dim can be sharded over +# the full ('ep','fsdp') axis tuple under Experiment 3. +BATCH_SIZE = 4 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs(key: jax.Array) -> jax.Array: + return jax.random.normal(key, (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=DTYPE) + + +def _unwrap_partitioned(x): + return x.value if hasattr(x, "value") else x + + +@pytest.mark.triton +class TestDistributedMoEBlock: + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_ep2_fsdp2_matches_single_device(self, permutation_backend): + if not is_devices_enough(4): + pytest.skip("MoE distributed test requires 4 devices for EP=2 x FSDP=2.") + + key = jax.random.PRNGKey(11) + init_key, data_key = jax.random.split(key) + inputs = _make_inputs(data_key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + + single_block = MoEBlock(**base_kwargs) + + def _make_loss_and_grad(block): + """Build a jitted ``value_and_grad`` over ``(variables, x)``. + + Capturing ``block`` in a closure (so it isn't a jit input) + sidesteps having to mark it as static -- Flax modules are + registered pytrees but they carry Python-level config that + jit treats as part of the trace. + """ + + def loss_fn(variables, x): + output, aux_loss = block.apply(variables, x) + loss = jnp.mean(output.astype(jnp.float32) ** 2) + if aux_loss is not None: + loss = loss + aux_loss.astype(jnp.float32) + return loss, (output, aux_loss) + + return jax.jit(jax.value_and_grad(loss_fn, has_aux=True)) + + with autocast(enabled=False, mesh_resource=MeshResource()): + single_variables = single_block.init(init_key, inputs) + (single_loss, (single_output, single_aux)), single_grads = _make_loss_and_grad( + single_block + )(single_variables, inputs) + + devices = np.asarray(jax.devices()[:4]).reshape(2, 2) + mesh = Mesh(devices, ("ep", "fsdp")) + # FSDP-style sharding: weights are sharded on a *non-contracting* + # weight axis (gathered before the GEMM); activations stay sharded on + # the *batch* axis throughout - the same fsdp mesh axis is reused for + # both. 
The TE primitives' custom_partitioning rules expect activations + # FSDP-sharded on batch, so we declare ("batch", "fsdp") AND pass + # ``input_axes=("batch", None, None)`` to enforce it on the inputs to + # the block. ("embed", "fsdp") shards the weight's hidden dim, which + # is gathered inside grouped_dense's custom_partitioning before GEMM + # (no reshard of activations needed because their layout is unchanged). + logical_axis_rules = ( + ("exp", "ep"), + ("batch", "fsdp"), + ("embed", "fsdp"), + ) + # ``data_parallelism_axes=("fsdp",)`` opts in to the true-FSDP + # behavior: the ``shard_map``'s in_specs/out_specs become + # ``P(("ep","fsdp"), None, None)`` for the batch dim, so each + # device owns ``B/(ep*fsdp)`` unique tokens (no redundant compute + # across fsdp peers within an ep group). + sharded_block = MoEBlock( + expert_parallelism_axis="ep", + data_parallelism_axes=("fsdp",), + mesh=mesh, + input_axes=("batch", None, None), + **base_kwargs, + ) + + with mesh, autocast(enabled=False, mesh_resource=MeshResource(fsdp_resource="fsdp")): + with nn.logical_axis_rules(logical_axis_rules): + # ``MoEBlock`` registers params via ``with_logical_partitioning`` + # which only attaches LogicallyPartitioned metadata; the + # underlying jax.Array stays single-device unless ``init`` + # is run inside ``jax.jit`` with ``out_shardings``. Use the + # canonical Flax-Linen pattern (mirrors + # ``examples/jax/encoder/test_model_parallel_encoder.py``): + # 1. ``jax.eval_shape`` to trace abstract variables (keeps + # the LogicallyPartitioned wrappers; only the inner + # arrays become ShapeDtypeStruct); + # 2. ``nn.get_partition_spec`` to extract a tree of logical + # PartitionSpecs from those wrappers (treats + # LogicallyPartitioned as a leaf); + # 3. ``nn.logical_to_mesh_sharding`` to resolve those + # logical specs to NamedShardings via the active rules; + # 4. ``jax.jit(init, out_shardings=...)`` to actually + # place the params on-device with those shardings. 
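+                # The PartitionSpec asserts further below double-check that the
+                # params coming out of the jitted init really carry the resolved
+                # NamedShardings (e.g. P('ep', 'fsdp', None) for wi_0).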
+ abstract_variables = jax.eval_shape(sharded_block.init, init_key, inputs) + logical_partition_spec = nn.get_partition_spec(abstract_variables) + out_shardings = nn.logical_to_mesh_sharding( + logical_partition_spec, mesh, logical_axis_rules + ) + sharded_variables = jax.jit(sharded_block.init, out_shardings=out_shardings)( + init_key, inputs + ) + (sharded_loss, (sharded_output, sharded_aux)), sharded_grads = _make_loss_and_grad( + sharded_block + )(sharded_variables, inputs) + + wi_0 = _unwrap_partitioned(sharded_variables["params"]["wi_0"]) + wi_1 = _unwrap_partitioned(sharded_variables["params"]["wi_1"]) + wo = _unwrap_partitioned(sharded_variables["params"]["wo"]) + assert wi_0.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wi_1.sharding.spec == PartitionSpec("ep", "fsdp", None) + assert wo.sharding.spec == PartitionSpec("ep", None, "fsdp") + + assert_allclose(sharded_output, single_output, dtype=DTYPE, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_loss, single_loss, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + assert_allclose(sharded_aux, single_aux, dtype=jnp.float32, atol=5e-2, rtol=5e-2) + + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + grad_single = _unwrap_partitioned(single_grads["params"][name]) + grad_sharded = _unwrap_partitioned(sharded_grads["params"][name]) + assert_allclose( + grad_sharded, + grad_single, + dtype=DTYPE, + atol=1e-1, + rtol=1e-1, + err_msg=f"Distributed gradient mismatch for {name}", + ) diff --git a/tests/jax/test_moe_block.py b/tests/jax/test_moe_block.py new file mode 100644 index 0000000000..e87593c9d4 --- /dev/null +++ b/tests/jax/test_moe_block.py @@ -0,0 +1,453 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Basic tests for ``transformer_engine.jax.flax.MoEBlock``. + +These tests exercise the MoEBlock on a single device (no expert parallelism) +and verify: + +* Forward pass runs end-to-end and produces the expected output shape. +* Backward pass yields finite, non-trivial parameter gradients. +* The two permutation backends (``"pure_jax"`` and ``"triton"``) produce + numerically equivalent outputs and gradients when given the same routing + decisions. +* Auxiliary load-balancing loss is returned when ``aux_loss_coeff > 0``. +* DeepSeek-style grouped top-k (``num_groups`` / ``group_topk``) runs. +* ``align_size > 0`` produces numerically-equivalent outputs to ``align_size = 0`` + for the pure-JAX backend (padding must not change the result). +""" + +import sys +from typing import Tuple + +import jax +import jax.numpy as jnp +import pytest + + +# The MoEBlock pulls in both the fused-router CUDA kernel and the Triton +# permutation kernels, so it can only run in the environment where those are +# available. We gate the test on the ``triton`` marker (the Triton permutation +# backend is stricter than the CUDA router). See ``conftest.py``. + + +@pytest.fixture(autouse=True, scope="function") +def _inject_moe(request): + """Lazy-load ``MoEBlock`` only for tests marked ``triton``.""" + if not request.node.get_closest_marker("triton"): + yield + return + + from transformer_engine.jax.flax import MoEBlock + + mod = sys.modules[__name__] + mod.MoEBlock = MoEBlock + yield + + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- +# +# Keep shapes small so the tests are cheap but still exercise every code path. 
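+# NUM_EXPERTS stays divisible by every ``num_groups`` value used in the
+# grouped-top-k tests below, and BATCH_SIZE * SEQUENCE_LENGTH *
+# NUM_EXPERTS_PER_TOK (2 * 16 * 2 = 64) is the number of routed token
+# copies each dispatch backend permutes.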
+ +DTYPE = jnp.bfloat16 +BATCH_SIZE = 2 +SEQUENCE_LENGTH = 16 +HIDDEN_SIZE = 64 +INTERMEDIATE_SIZE = 128 +NUM_EXPERTS = 8 +NUM_EXPERTS_PER_TOK = 2 + + +def _make_inputs( + key: jax.Array, batch_size: int = BATCH_SIZE, sequence_length: int = SEQUENCE_LENGTH +) -> jax.Array: + return jax.random.normal(key, (batch_size, sequence_length, HIDDEN_SIZE), dtype=DTYPE) + + +def _init_and_apply( + block, + inputs: jax.Array, + init_key: jax.Array, +) -> Tuple[dict, jax.Array, jax.Array]: + variables = block.init(init_key, inputs) + output, aux_loss = block.apply(variables, inputs) + return variables, output, aux_loss + + +def _unwrap_partitioned(x): + """Strip Flax logical-partition wrappers for numeric assertions.""" + return x.value if hasattr(x, "value") else x + + +# ----------------------------------------------------------------------------- +# Tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.triton +class TestMoEBlockSingleDevice: + """Single-device smoke tests for :class:`MoEBlock`.""" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_forward_shape_and_finite(self, permutation_backend): + key = jax.random.PRNGKey(0) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert ( + output.shape == inputs.shape + ), f"Unexpected output shape {output.shape} for backend {permutation_backend}" + assert output.dtype == inputs.dtype + assert jnp.all(jnp.isfinite(output)), "Output contains NaN/Inf" + assert aux_loss is None, "aux_loss should be None when aux_loss_coeff=0" + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_backward_grad_is_finite_and_nonzero(self, permutation_backend): + key = jax.random.PRNGKey(1) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + def loss_fn(variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2) + + grads = jax.grad(loss_fn)(variables, inputs) + # All trainable kernels should receive a non-trivial gradient. + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g = _unwrap_partitioned(grads["params"][name]) + assert jnp.all(jnp.isfinite(g)), f"{name} gradient has NaN/Inf" + assert jnp.any(g != 0.0), f"{name} gradient is identically zero" + + def test_pure_jax_triton_equivalence(self): + """Both permutation backends must produce the same forward + grads + under identical routing decisions. + + Since the two backends share the same routing path (TE's fused + top-k), fixing the gate kernel gives both the same routing decisions + and the remainder of the network is identical modulo the permutation + implementation, whose semantics are equivalent. 
+ """ + key = jax.random.PRNGKey(2) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + dtype=DTYPE, + ) + pure_block = MoEBlock(permutation_backend="pure_jax", **base_kwargs) + triton_block = MoEBlock(permutation_backend="triton", **base_kwargs) + inputs = _make_inputs(data_key) + + # Share a single parameter tree so routing decisions and expert + # weights are identical for both backends. + variables = pure_block.init(init_key, inputs) + + def loss_fn(block, variables, inputs): + output, _ = block.apply(variables, inputs) + return jnp.mean(output.astype(jnp.float32) ** 2), output + + (loss_pj, out_pj), grads_pj = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + pure_block, variables, inputs + ) + (loss_tr, out_tr), grads_tr = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)( + triton_block, variables, inputs + ) + + # BF16 tolerances: outputs come out of the grouped-GEMM + weighted + # sum so they accumulate error; we use ~2 ULPs worth of slack. + atol_out, rtol_out = 5e-2, 5e-2 + assert jnp.allclose( + out_pj, out_tr, atol=atol_out, rtol=rtol_out + ), f"Forward outputs differ across backends: max diff {jnp.max(jnp.abs(out_pj - out_tr))}" + assert jnp.allclose(loss_pj, loss_tr, atol=atol_out, rtol=rtol_out) + + # The two backends share the routing path (same fused top-k) and + # the same expert FFN; the only difference is the order of the + # gather + scatter ops in dispatch/combine. Under bf16 with these + # small shapes, observed grad max-abs-diff is on the order of a + # few-units-of-bf16-eps (~1e-2). 5e-2 / 5e-2 leaves headroom for + # accumulation jitter without masking real divergence. If this + # tightens too far on a particular GPU, print + # ``jnp.max(jnp.abs(g_pj - g_tr))`` from the failing assertion + # and bump to the next safe value with a comment recording the + # measured gap. + atol_grad, rtol_grad = 5e-2, 5e-2 + for name in ("gate_kernel", "wi_0", "wi_1", "wo"): + g_pj = _unwrap_partitioned(grads_pj["params"][name]) + g_tr = _unwrap_partitioned(grads_tr["params"][name]) + assert jnp.allclose(g_pj, g_tr, atol=atol_grad, rtol=rtol_grad), ( + f"Gradient for {name} differs across backends: max diff" + f" {jnp.max(jnp.abs(g_pj - g_tr))} (atol={atol_grad}," + f" rtol={rtol_grad})" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_aux_loss_returned(self, permutation_backend): + key = jax.random.PRNGKey(3) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + aux_loss_coeff=1e-2, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert aux_loss is not None, "aux_loss should be returned when coeff > 0" + assert aux_loss.shape == (), "aux_loss should be a scalar" + assert jnp.isfinite(aux_loss) + # With uniform-ish routing the loss should be small-positive, not huge. + assert jnp.abs(aux_loss) < 1e2 + + def test_aux_loss_uses_real_routing_under_group_topk(self): + """Regression test for PR #2912 review (greptile P1). 
+ + Under DeepSeek-style ``num_groups`` / ``group_topk`` routing, + the auxiliary load-balancing loss must be computed using the + per-expert token counts from the *real* routing_map (post + grouping), not from the clean top-k that the + ``compute_aux_scores=True`` kernel returns. Otherwise the aux + objective trains against the wrong distribution. + + We compute three values: + * ``corrected_ref`` -- ``fused_moe_aux_loss(aux_scores, + tokens_from_real_routing_map, ...)`` (what the block + should produce after the fix). + * ``buggy_ref`` -- ``fused_moe_aux_loss(aux_scores, + tokens_from_aux_routing_map, ...)`` (what the block used + to produce before the fix). + * ``block_aux_loss`` -- what the block actually produces. + + Block must match the corrected reference. We also assert that + the corrected and buggy references differ for this config so + the test is not vacuously satisfied by them coinciding. + """ + from transformer_engine.jax.router import ( + fused_moe_aux_loss, + fused_topk_with_score_function, + ) + + key = jax.random.PRNGKey(7) + init_key, data_key = jax.random.split(key) + + # Pick a config that *reliably* exercises grouped-vs-clean + # divergence: with ``group_topk=1`` only ONE group's experts + # can be selected by grouped routing, so the routing diverges + # from a plain top-k whenever the global top-K experts are + # spread across multiple groups (which is almost always the + # case for random init + ``num_experts_per_tok > 1``). + num_groups = 2 + group_topk = 1 + aux_loss_coeff = 1e-2 + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + aux_loss_coeff=aux_loss_coeff, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + _output, block_aux_loss = block.apply(variables, inputs) + + assert block_aux_loss is not None + + # Reproduce the gating GEMM and routing externally so we can + # build the references against the same logits the block sees. + gate_kernel = _unwrap_partitioned(variables["params"]["gate_kernel"]) + gate_kernel = gate_kernel.astype(inputs.dtype) + logits = jnp.einsum("bsh,he->bse", inputs, gate_kernel) + logits_2d = logits.reshape(-1, NUM_EXPERTS) + + # Real routing (with grouping). This is what _route_topk + # would produce inside the block. + _, real_routing_map = fused_topk_with_score_function( + logits_2d, + topk=NUM_EXPERTS_PER_TOK, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + ) + real_tokens = jnp.sum(real_routing_map.astype(jnp.int32), axis=0) + + # Aux scores + the (clean topk) aux_routing_map that the old + # buggy code used for tokens_per_expert. + aux_scores, aux_routing_map = fused_topk_with_score_function( + logits_2d.astype(jnp.float32), + topk=NUM_EXPERTS_PER_TOK, + score_function="sigmoid", + compute_aux_scores=True, + ) + buggy_tokens = jnp.sum(aux_routing_map.astype(jnp.int32), axis=0) + + corrected_ref = fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + real_tokens, + topk=NUM_EXPERTS_PER_TOK, + coeff=aux_loss_coeff, + ) + buggy_ref = fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + buggy_tokens, + topk=NUM_EXPERTS_PER_TOK, + coeff=aux_loss_coeff, + ) + + # Sanity: the test config must actually exercise the bug + # (otherwise both references coincide and the assertion below + # would silently pass even with the old code). 
+ assert not jnp.allclose(real_tokens, buggy_tokens), ( + "Test config does not exercise grouped-topk vs clean-topk" + " divergence; pick a config where they differ" + ) + + assert jnp.allclose( + block_aux_loss, corrected_ref, atol=1e-5, rtol=1e-5 + ), f"Block aux_loss {block_aux_loss} does not match real-routing reference {corrected_ref}" + # The corrected and buggy refs can be numerically close + # (only the mis-routed tokens contribute to the difference), + # so assert that the block is *strictly closer* to the + # corrected ref than to the buggy one. This catches the + # regression robustly even when the absolute gap between + # corrected_ref and buggy_ref is sub-tolerance. + diff_to_corrected = jnp.abs(block_aux_loss - corrected_ref) + diff_to_buggy = jnp.abs(block_aux_loss - buggy_ref) + gap = jnp.abs(corrected_ref - buggy_ref) + assert diff_to_corrected < diff_to_buggy, ( + f"Block aux_loss {block_aux_loss} is closer to the *old" + f" buggy* reference ({buggy_ref}, diff={diff_to_buggy})" + f" than to the corrected reference ({corrected_ref}," + f" diff={diff_to_corrected}); the regression has" + f" reappeared. corrected-buggy gap = {gap}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_group_topk_deepseek(self, permutation_backend): + """Exercise DeepSeek-style grouped top-k routing.""" + key = jax.random.PRNGKey(4) + init_key, data_key = jax.random.split(key) + + # num_groups must divide num_experts. + num_groups = 4 + group_topk = 2 + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + score_function="sigmoid", + num_groups=num_groups, + group_topk=group_topk, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + _variables, output, _aux_loss = _init_and_apply(block, inputs, init_key) + + assert output.shape == inputs.shape + assert jnp.all(jnp.isfinite(output)) + + def test_align_size_equivalence_pure_jax(self, monkeypatch): + """For the pure-JAX backend, ``align_size > 0`` must not change the + numerical output of the forward pass: padding tokens contribute zero + to every expert GEMM output (their input rows are zeros) and are + stripped before the weighted sum. + + Why the env knob: the V1 TE grouped GEMM FFI asserts + ``sum(group_sizes) == M`` at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``. With + ``align_size > 0`` the pure-JAX backend produces a buffer where + ``M >= sum(group_sizes)`` (the slack is structural padding for + JIT). The V2 grouped GEMM relaxes that assertion to + ``M >= sum(group_sizes)`` and is selected when + ``NVTE_JAX_ENFORCE_V2_GROUPED_GEMM=1``. If V2 isn't supported on + this hardware / for this dtype, the dispatch raises a + ``RuntimeError`` whose message is matched here so the test + ``skip``-s instead of failing. 
+ """ + monkeypatch.setenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "1") + + key = jax.random.PRNGKey(5) + init_key, data_key = jax.random.split(key) + + base_kwargs = dict( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend="pure_jax", + dtype=DTYPE, + ) + block_no_pad = MoEBlock(align_size=0, **base_kwargs) + block_pad = MoEBlock(align_size=16, **base_kwargs) + inputs = _make_inputs(data_key) + + try: + variables = block_no_pad.init(init_key, inputs) + out_no_pad, _ = block_no_pad.apply(variables, inputs) + out_pad, _ = block_pad.apply(variables, inputs) + except RuntimeError as exc: + if "V2 grouped GEMM is not supported" in str(exc): + pytest.skip(f"V2 grouped GEMM unavailable on this hardware: {exc}") + raise + + assert jnp.allclose(out_no_pad, out_pad, atol=5e-2, rtol=5e-2), ( + "align_size > 0 must not change pure_jax forward output; max diff" + f" {jnp.max(jnp.abs(out_no_pad - out_pad))}" + ) + + @pytest.mark.parametrize("permutation_backend", ["pure_jax", "triton"]) + def test_jit_and_determinism(self, permutation_backend): + """The block must be JIT-compilable and produce a deterministic + forward pass across repeat calls with the same params.""" + key = jax.random.PRNGKey(6) + init_key, data_key = jax.random.split(key) + + block = MoEBlock( + num_experts=NUM_EXPERTS, + num_experts_per_tok=NUM_EXPERTS_PER_TOK, + intermediate_size=INTERMEDIATE_SIZE, + permutation_backend=permutation_backend, + dtype=DTYPE, + ) + inputs = _make_inputs(data_key) + variables = block.init(init_key, inputs) + + @jax.jit + def forward(variables, inputs): + return block.apply(variables, inputs)[0] + + out_a = forward(variables, inputs) + out_b = forward(variables, inputs) + assert jnp.array_equal(out_a, out_b), "JITted forward is non-deterministic" diff --git a/transformer_engine/common/util/multi_stream.cpp b/transformer_engine/common/util/multi_stream.cpp index 6b19f36741..ec341abc68 100644 --- a/transformer_engine/common/util/multi_stream.cpp +++ b/transformer_engine/common/util/multi_stream.cpp @@ -12,6 +12,7 @@ #include #include +#include #include #include "cuda_runtime.h" @@ -19,18 +20,54 @@ namespace transformer_engine::detail { +namespace { + +// CUDA streams and events are device-bound: a stream / event created +// on device A cannot be recorded into / waited on from device B +// (CUDA returns ``cudaErrorInvalidResourceHandle``). The previous +// implementation used ``std::call_once`` to lazily create one +// process-global vector of streams + one of events, which works for +// the single-device case (PyTorch eager / single-host single-device +// JAX) but breaks for single-process *multi*-device JAX: the first +// worker thread to win the ``call_once`` would create streams / +// events on its own device, and subsequent calls from other devices +// would receive those same handles and fail at ``cudaEventRecord``. +// +// We now key the cache on the active CUDA device. Each device gets +// its own ``num_compute_streams`` streams and events, created lazily +// the first time a thread on that device asks for one. 
+template +auto& per_device_pool(CreateFn&& create) { + static std::mutex mu; + using PoolT = decltype(std::vector{create()}); + static std::unordered_map pools; + int device; + NVTE_CHECK_CUDA(cudaGetDevice(&device)); + std::lock_guard lock(mu); + auto it = pools.find(device); + if (it == pools.end()) { + const size_t num_streams = nvte_get_num_compute_streams(); + PoolT v; + v.reserve(num_streams); + for (size_t i = 0; i < num_streams; i++) { + v.push_back(create()); + } + it = pools.emplace(device, std::move(v)).first; + } + return it->second; +} + +} // namespace + cudaStream_t get_compute_stream(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - static std::vector streams(num_streams); - static std::once_flag stream_init_flag; - auto init = [&]() { - for (size_t i = 0; i < num_streams; i++) { - NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&streams[i], cudaStreamNonBlocking, -1)); - } - }; - std::call_once(stream_init_flag, init); + auto& streams = per_device_pool([] { + cudaStream_t s; + NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&s, cudaStreamNonBlocking, -1)); + return s; + }); return streams[idx]; } @@ -38,14 +75,11 @@ cudaEvent_t get_compute_stream_event(int idx) { const size_t num_streams = nvte_get_num_compute_streams(); NVTE_CHECK(0 <= idx && idx < num_streams, "Invalid compute stream (requested idx ", idx, ", but there are ", num_streams, " streams)"); - static std::vector events(num_streams); - static std::once_flag event_init_flag; - auto init = [&]() { - for (size_t i = 0; i < num_streams; i++) { - NVTE_CHECK_CUDA(cudaEventCreate(&events[i])); - } - }; - std::call_once(event_init_flag, init); + auto& events = per_device_pool([] { + cudaEvent_t e; + NVTE_CHECK_CUDA(cudaEventCreate(&e)); + return e; + }); return events[idx]; } diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..94b2de9573 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2024,9 +2024,14 @@ def grouped_gemm_copy_group_sizes( return out -@cache def _should_enforce_v2_grouped_gemm() -> bool: - """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM once per process (cached).""" + """Read NVTE_JAX_ENFORCE_V2_GROUPED_GEMM. + + Not cached so tests can flip the env var with ``monkeypatch.setenv`` + and have it picked up on the next call. This is called only on + grouped-GEMM dispatch (not in any tight loop), so the per-call + ``getenv`` cost is negligible. + """ val = os.getenv("NVTE_JAX_ENFORCE_V2_GROUPED_GEMM", "0") try: return bool(int(val)) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 6ca907032c..8a807cbdcc 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1157,12 +1157,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type cudaStreamSynchronize(stream); } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // Allow callers to pass an LHS/RHS that is at least as large as the active + // ragged region (sum_group_sizes). 
This supports ragged-all-to-all flows + // where the recv buffer is over-allocated to a worst-case size and only + // the first sum_group_sizes rows along the ragged dim are populated; the + // trailing slack rows are not consumed by the per-group GEMMs (which key + // off group_sizes). if (!is_rhs_ragged) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); + NVTE_CHECK(sum_group_sizes <= m, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, + " must be <= M = ", m); } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + NVTE_CHECK(sum_group_sizes <= k, "Unexpected group_sizes! sum(group_sizes)=", sum_group_sizes, + " must be <= K = ", k); } } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 650139a61c..871abb5634 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -383,9 +383,17 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty cudaStreamSynchronize(stream); size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - NVTE_CHECK(m == sum_group_sizes || input_dims[0] == sum_group_sizes, - "Unexpected group_sizes! Got %zu (M=%zu, input_dims[0] = %zu)", sum_group_sizes, m, - input_dims[0]); + // Allow callers to pass an input that is at least as large as the active + // ragged region (sum_group_sizes). This supports ragged-all-to-all flows + // where the recv buffer is over-allocated to a worst-case size and only the + // first sum_group_sizes rows are populated; the trailing slack rows are + // simply not quantized (and not consumed by the downstream grouped GEMM + // which is also keyed on group_sizes). + // For flatten_axis==1, m == input_dims[0]; for flatten_axis>1, the per-group + // tile is dim_list_host[i] * non_group_m, so the binding dim is input_dims[0]. + NVTE_CHECK(sum_group_sizes <= input_dims[0], + "Unexpected group_sizes! sum(group_sizes)=%zu must be <= input_dims[0]=%zu (M=%zu)", + sum_group_sizes, input_dims[0], m); if (is_delayed_scaling) { NVTE_CHECK(amaxs->dimensions()[0] == num_groups, "Unexpected amax size, Expected ", num_groups, diff --git a/transformer_engine/jax/flax/__init__.py b/transformer_engine/jax/flax/__init__.py index 92a968f061..0cd7835bcf 100644 --- a/transformer_engine/jax/flax/__init__.py +++ b/transformer_engine/jax/flax/__init__.py @@ -9,6 +9,7 @@ make_dot_general_cls, make_grouped_dense_cls, ) +from .moe import MoEBlock from .transformer import extend_logical_axis_rules from .transformer import DotProductAttention, MultiHeadAttention, RelativePositionBiases from .transformer import TransformerLayer, TransformerLayerType @@ -18,6 +19,7 @@ "LayerNorm", "LayerNormDenseGeneral", "LayerNormMLP", + "MoEBlock", "wrap_function_in_te_state_module", "make_dot_general_cls", "make_grouped_dense_cls", diff --git a/transformer_engine/jax/flax/moe.py b/transformer_engine/jax/flax/moe.py new file mode 100644 index 0000000000..712499c2cd --- /dev/null +++ b/transformer_engine/jax/flax/moe.py @@ -0,0 +1,1135 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Flax Linen MoEBlock for TransformerEngine JAX. 
+ +This module exposes :class:`MoEBlock`, a self-contained Flax Linen MoE layer +that wires together TE's fused router, a selectable token-dispatch backend +(``pure_jax`` or ``triton``), TE's ``grouped_dense``, and an +optional ragged-all-to-all (A2A / A2Av) expert-parallelism strategy. + +Architecture +------------ + +The MoEBlock is decomposed into orthogonal stages so the EP wrapper can +inject collectives between them: + +* ``_route``: gate logits -> top-k routing decisions (+ aux loss). +* ``_global_permute``: scatter tokens to experts; produces + ``[num_tokens*topk + maybe_padding, hidden]`` and + per-expert ``group_sizes`` of length ``num_experts``. +* ``_expert_ffn``: three ``grouped_dense`` calls + activation. Operates + on whatever ``(rows, group_sizes, n_groups)`` it is + handed -- agnostic to whether ``n_groups`` is the + global expert count (no-EP) or the local expert + count (A2A-EP). +* ``_global_combine``: inverse of ``_global_permute`` -- gather + weighted + sum across top-k experts. + +Two top-level forward variants compose those stages: + +* ``_forward_no_ep``: route -> permute -> ffn -> combine. Each TE + primitive's ``custom_partitioning`` rule handles + DP / FSDP / TP automatically. +* ``_forward_a2a_ep``: wraps the body in :func:`jax.shard_map` and inserts + ``all_gather(group_sizes)`` + forward + ``ragged_all_to_all`` + local permute around the + FFN, plus their inverses afterwards. This is the + only place ``shard_map`` is used; A2A is the + canonical EP strategy because the in-flight NCCL + EP component will require this same data layout. + +Note on ``align_size > 0`` +-------------------------- + +Both permutation backends pad each expert's group to a multiple of +``align_size`` when requested, which is what CUBLASLt's grouped GEMM wants +for FP8 shape selection. The pure-JAX backend additionally appends a +zero-input padding tail to keep the buffer statically sized for JIT, so +``sum(group_sizes) <= sorted_inputs.shape[0]`` strictly. TE's +``grouped_dense`` FFI today asserts ``m == sum(group_sizes)`` at +``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``; relaxing that +check to ``m >= sum(group_sizes)`` (the kernel itself only iterates over +``sum(group_sizes)`` rows via ``nvte_multi_tensor_gemm``) is the cleanest +way to support ``align_size > 0`` end-to-end. Until that lands the +``align_size > 0`` tests stay xfail. +""" + +from functools import partial +from typing import Any, Callable, NewType, Optional, Tuple, Union + +import jax +import jax.numpy as jnp +from flax import linen as nn, struct as flax_struct +from jax.sharding import PartitionSpec as P + +from ..dense import grouped_dense +from ..permutation import ( + routing_map_to_selected_experts, + compute_ragged_all_to_all_params, + compute_reverse_ragged_all_to_all_params, + local_permute_after_a2a, + local_unpermute_before_a2a, + PureJaxPermState, + pure_jax_token_combine, + pure_jax_token_dispatch, + token_combine, + token_dispatch, +) +from ..quantize import noop_quantizer_set +from ..router import ScoreFunction, fused_moe_aux_loss, fused_topk_with_score_function +from ..sharding import with_sharding_constraint_by_logical_axes +from .module import TransformerEngineBase, _convert_to_activation_function + +PRNGKey = Any +Shape = Tuple[int, ...] 
+DType = NewType("DType", jnp.dtype) +Array = NewType("Array", jnp.ndarray) +Initializer = Callable[[PRNGKey, Shape, DType], Array] + + +__all__ = ["GlobalPermuteResult", "MoEBlock"] + + +# ============================================================================= +# GlobalPermuteResult +# ============================================================================= +# +# Output of :meth:`MoEBlock._global_permute`. Carried as a pytree (so it +# crosses ``jax.shard_map`` / ``jax.value_and_grad`` boundaries +# transparently) and consumed by :meth:`MoEBlock._global_combine`. The +# fields populated depend on the permutation backend; the unused fields +# stay ``None``. +# +# Per-backend payloads (anything else is ``None``): +# pure_jax: ``perm_state``, ``routing_weights`` +# triton: ``row_id_map``, ``pad_offsets``, ``merging_probs`` + + +@flax_struct.dataclass +class GlobalPermuteResult: + """Result of :meth:`MoEBlock._global_permute`.""" + + sorted_inputs: jnp.ndarray + group_sizes: jnp.ndarray + perm_state: Optional[PureJaxPermState] = None + routing_weights: Optional[jnp.ndarray] = None + row_id_map: Optional[jnp.ndarray] = None + pad_offsets: Optional[jnp.ndarray] = None + merging_probs: Optional[jnp.ndarray] = None + backend: str = flax_struct.field(pytree_node=False, default="pure_jax") + + +# ============================================================================= +# MoEBlock +# ============================================================================= + + +class MoEBlock(TransformerEngineBase): + """Mixture-of-Experts Flax Linen block. + + Encapsulates the full MoE forward pass: gate projection, fused top-k + routing, optional auxiliary load-balancing loss, token dispatch, + per-expert two-layer FFN via grouped GEMMs, activation, token combine, + and optional ragged-all-to-all expert parallelism. + + Two permutation backends are pluggable via ``permutation_backend``: + + * ``"pure_jax"`` (default) -- argsort-based + :func:`~transformer_engine.jax.permutation.pure_jax_token_dispatch` / + :func:`~transformer_engine.jax.permutation.pure_jax_token_combine`. + Faster than Triton in profiling for DeepSeek-style configs. + * ``"triton"`` -- TE's fused + :func:`~transformer_engine.jax.permutation.token_dispatch` / + :func:`~transformer_engine.jax.permutation.token_combine` Triton + kernels. + + Expert parallelism (``expert_parallelism_axis is not None``) uses the + **ragged-all-to-all** EP strategy (a.k.a. A2Av): each shard routes its + own tokens globally over all experts, then a forward + ``ragged_all_to_all`` exchanges per-expert chunks so each shard ends up + holding only the tokens for its local experts; after the FFN a reverse + ``ragged_all_to_all`` returns each shard's outputs to it. This matches + the layout the in-flight NCCL EP component expects. + + Parameters + ---------- + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k value (number of experts each token is routed to). + intermediate_size : int + Per-expert FFN hidden dim. + + activation_type : str + FFN activation applied to the gate projection. Paired with the up + projection in the SwiGLU-style ``act(wi_0) * wi_1`` product. + Resolved via :func:`flax.linen.` (``"silu"``, ``"gelu"``, + ``"relu"``, ``"swish"``, ...) plus ``"linear"`` for identity. + + score_function : str or ScoreFunction + ``"softmax"`` (default) or ``"sigmoid"`` for + :func:`fused_topk_with_score_function`. + use_pre_softmax : bool + Apply softmax before top-k when ``score_function="softmax"``. 
+ num_groups : int + Number of routing groups for grouped top-k (DeepSeek). ``<=0`` + disables. + group_topk : int + Top-k at the group level. ``<=0`` disables. + scaling_factor : float + Scaling factor applied to output probs. + use_expert_bias : bool + If ``True``, registers a learnable ``expert_bias`` parameter of + shape ``[num_experts]`` and passes it to the fused router. The + router primitive validates that this is paired with + ``score_function="sigmoid"``. + aux_loss_coeff : float + If ``> 0``, compute and return the MoE auxiliary load-balancing + loss scalar via :func:`fused_moe_aux_loss`. ``0`` disables. + + gate_kernel_axes : tuple[str, ...] + Logical partitioning axes for the gate kernel of shape + ``[hidden, num_experts]``. + wi_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wi_0`` and ``wi_1`` kernels of + shape ``[num_experts, hidden, intermediate]``. Default + ``("exp", "embed", "mlp")``. + wo_kernel_axes : tuple[str, ...] + Logical partitioning axes for the ``wo`` kernel of shape + ``[num_experts, intermediate, hidden]``. Default + ``("exp", "mlp", "embed")``. + input_axes : tuple[str, ...] + Logical axes used to constrain the input activation sharding at the + block boundary. ``()`` (default) means no constraint. + + expert_parallelism_axis : Optional[str] + Mesh axis along which experts are split. When set, the forward + pass is wrapped in :func:`jax.shard_map` that implements the + ragged-all-to-all EP strategy. When ``None`` (default), no + ``shard_map`` wrapper is used; each TE primitive's + ``custom_partitioning`` rule handles DP / FSDP / TP automatically. + data_parallelism_axes : tuple[str, ...] + Additional mesh axes that the input *batch* dim is sharded over + IN ADDITION to ``expert_parallelism_axis``. Setting this to e.g. + ``("fsdp",)`` makes the ``shard_map`` ``in_specs`` for the batch + dim become ``P(("ep", "fsdp"), None, None)`` -- giving each + device a unique slice of the batch (true FSDP) instead of + replicating the per-ep-shard batch across fsdp peers. + Routing is unaffected: ``axis_index("ep")`` still controls the + ragged-all-to-all; the extra fsdp peers within an ep group send + and receive their own batch slices in lockstep. Default ``()`` + preserves legacy ZeRO-1-style behavior (activations replicated + on fsdp within an ep group). + tensor_parallelism_axis : Optional[str] + Mesh axis for tensor parallelism on the FFN intermediate dim. When + set, the output of the ``wo`` grouped GEMM is ``psum_scatter`` ed + along this axis. + + permutation_backend : str + ``"pure_jax"`` (default) or ``"triton"``. + align_size : int + Alignment for per-expert group sizes after padding. ``0`` disables + padding (the only supported configuration end-to-end today). ``>0`` + is required for quantized TE grouped GEMM whose recipe-specific + alignment must divide ``align_size``; see the module docstring for + the FFI assertion that currently blocks ``>0`` for both backends. + + dtype : jnp.dtype + Compute and parameter dtype. + kernel_init : Initializer + Initializer for all kernels (gate + per-expert FFN). Defaults to + ``variance_scaling(1.0, 'fan_in', 'truncated_normal')`` (Flax + convention). + use_bias : bool + If ``True``, registers per-expert FFN biases ``wi_0_bias``, + ``wi_1_bias``, ``wo_bias``. 
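+
+    Example
+    -------
+    A minimal single-device usage sketch (hidden size and input data are
+    illustrative; the defaults are spelled out for clarity)::
+
+        import jax
+        import jax.numpy as jnp
+        from transformer_engine.jax.flax import MoEBlock
+
+        block = MoEBlock(num_experts=8, num_experts_per_tok=2,
+                         intermediate_size=2048, dtype=jnp.bfloat16)
+        x = jnp.zeros((2, 16, 1024), dtype=jnp.bfloat16)
+        variables = block.init(jax.random.PRNGKey(0), x)
+        output, aux_loss = block.apply(variables, x)
+        # output: [2, 16, 1024]; aux_loss is None because aux_loss_coeff
+        # defaults to 0.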
+ """ + + # Architecture + num_experts: int = 8 + num_experts_per_tok: int = 2 + intermediate_size: int = 2048 + activation_type: str = "silu" + + # Routing + score_function: Union[str, ScoreFunction] = "softmax" + use_pre_softmax: bool = False + num_groups: int = -1 + group_topk: int = -1 + scaling_factor: float = 1.0 + use_expert_bias: bool = False + aux_loss_coeff: float = 0.0 + + # Sharding + gate_kernel_axes: Tuple[Optional[str], ...] = () + wi_kernel_axes: Tuple[Optional[str], ...] = ("exp", "embed", "mlp") + wo_kernel_axes: Tuple[Optional[str], ...] = ("exp", "mlp", "embed") + input_axes: Tuple[Optional[str], ...] = () + + # Parallelism + expert_parallelism_axis: Optional[str] = None + data_parallelism_axes: Tuple[str, ...] = () + tensor_parallelism_axis: Optional[str] = None + # ``jax.sharding.Mesh`` to use when ``expert_parallelism_axis`` is set. + # Required for the ``shard_map`` wrapper; ignored otherwise. + mesh: Optional[Any] = None + + # Permutation + permutation_backend: str = "pure_jax" + align_size: int = 0 + + # Dtypes / init / misc + dtype: DType = jnp.float32 + kernel_init: Optional[Initializer] = None + bias_init: Initializer = nn.initializers.zeros + expert_bias_init: Initializer = nn.initializers.zeros + use_bias: bool = False + + def __post_init__(self): + if self.kernel_init is None: + object.__setattr__( + self, + "kernel_init", + nn.initializers.variance_scaling( + 1.0, "fan_in", "truncated_normal", dtype=self.dtype + ), + ) + if self.permutation_backend not in ("pure_jax", "triton"): + raise ValueError( + "permutation_backend must be 'pure_jax' or 'triton'," + f" got {self.permutation_backend!r}" + ) + super().__post_init__() + + # ------------------------------------------------------------------ + # Entry point + # ------------------------------------------------------------------ + + @nn.compact + def __call__(self, inputs: Array) -> Tuple[Array, Optional[Array]]: + """Run the MoE forward pass. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[batch, sequence, hidden]``. + + Returns + ------- + output : jnp.ndarray + Output tensor of shape ``[batch, sequence, hidden]``. + aux_loss : Optional[jnp.ndarray] + Scalar auxiliary load-balancing loss when + ``aux_loss_coeff > 0``, else ``None``. + """ + assert ( + inputs.ndim == 3 + ), f"MoEBlock expects [batch, sequence, hidden] input, got shape {inputs.shape}" + inputs = with_sharding_constraint_by_logical_axes(inputs, self.input_axes) + + _, _, hidden_size = inputs.shape + + # Param registrations are inlined here (not in a helper) so each + # ``self.param`` lives close to the rest of the entry point. + # Note: under EP the FFN weights and ``expert_bias`` are + # consumed *inside* a ``shard_map`` body. Flax's ``self.param`` + # must run OUTSIDE any JAX transform that would alter the + # variable scope (``shard_map`` does), so the registrations stay + # here in ``__call__`` and the values are passed down explicitly + # via ``in_specs``. ``_gate`` is called outside ``shard_map`` in + # both paths, so its kernel is registered inline inside + # ``_gate`` itself rather than here. 
+ + gate_logits = self._gate(inputs) + + wi_0 = self.param( + "wi_0", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wi_1 = self.param( + "wi_1", + nn.with_logical_partitioning(self.kernel_init, self.wi_kernel_axes), + (self.num_experts, hidden_size, self.intermediate_size), + self.dtype, + ) + wo = self.param( + "wo", + nn.with_logical_partitioning(self.kernel_init, self.wo_kernel_axes), + (self.num_experts, self.intermediate_size, hidden_size), + self.dtype, + ) + wi_0_bias = wi_1_bias = wo_bias = None + if self.use_bias: + wi_0_bias = self.param( + "wi_0_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + wi_1_bias = self.param( + "wi_1_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "mlp")), + (self.num_experts, self.intermediate_size), + self.dtype, + ) + wo_bias = self.param( + "wo_bias", + nn.with_logical_partitioning(self.bias_init, ("exp", "embed")), + (self.num_experts, hidden_size), + self.dtype, + ) + expert_bias = None + if self.use_expert_bias: + expert_bias = self.param( + "expert_bias", + nn.with_logical_partitioning(self.expert_bias_init, ("exp",)), + (self.num_experts,), + self.dtype, + ) + + if self.expert_parallelism_axis is None: + output, aux_loss = self._forward_no_ep( + inputs, + gate_logits, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, + expert_bias=expert_bias, + ) + else: + output, aux_loss = self._forward_a2a_ep( + inputs, + gate_logits, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, + expert_bias=expert_bias, + ) + + if self.aux_loss_coeff <= 0.0: + aux_loss = None + return output, aux_loss + + # ------------------------------------------------------------------ + # Gate + # ------------------------------------------------------------------ + + def _gate(self, inputs: jnp.ndarray) -> jnp.ndarray: + """Linear gate projection ``inputs @ gate_kernel``. + + Kept as a plain ``einsum`` (not ``DenseGeneral``) so it composes + cleanly with the EP shard_map: the gate runs in the outer + (pre-shard_map) scope and its output passes through the + ``shard_map`` boundary unchanged. Because the gate runs outside + any ``shard_map`` body in both EP and no-EP forwards, the + ``gate_kernel`` parameter is registered inline here. + + The gating GEMM is intentionally kept in ``self.dtype`` (typically + ``bfloat16``) and is **not** autocast to FP8 even when the caller + wraps the block in :func:`transformer_engine.jax.autocast`. Two + reasons: (1) the GEMM is tiny (``H * E`` with ``E`` small) and + contributes well under 1% of the block's compute, so quantization + savings are marginal; (2) the resulting logits feed a top-k + + softmax (or sigmoid) routing decision that is sensitive to + quantization noise -- routing flips at low-confidence tokens + could materially hurt model quality. To override, wrap the call + site in your own ``autocast`` and manually replace this method. 
+ """ + hidden_size = inputs.shape[-1] + gate_kernel = self.param( + "gate_kernel", + nn.with_logical_partitioning(self.kernel_init, self.gate_kernel_axes), + (hidden_size, self.num_experts), + self.dtype, + ) + kernel = gate_kernel.astype(inputs.dtype) + return jnp.einsum("bsh,he->bse", inputs, kernel) + + # ------------------------------------------------------------------ + # Route + # ------------------------------------------------------------------ + # + # The router is split into two pieces so the EP path can compute + # aux_loss over global (cross-shard) statistics without re-running + # the main top-k path. ``_route_topk`` returns the per-token routing + # decisions (used by ``_global_permute``) and ``_compute_aux_loss`` + # returns the scalar load-balancing loss given the (possibly + # gathered) logits. + + def _route_topk( + self, + logits_2d: jnp.ndarray, + expert_bias: Optional[jnp.ndarray], + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Run the fused router top-k selection.""" + sparse_probs, routing_map = fused_topk_with_score_function( + logits_2d, + topk=self.num_experts_per_tok, + use_pre_softmax=self.use_pre_softmax, + num_groups=self.num_groups, + group_topk=self.group_topk, + scaling_factor=self.scaling_factor, + score_function=self.score_function, + expert_bias=expert_bias, + ) + sparse_probs = sparse_probs.astype(self.dtype) + return sparse_probs, routing_map + + def _compute_aux_loss( + self, + logits_2d: jnp.ndarray, + tokens_per_expert: jnp.ndarray, + ) -> Optional[jnp.ndarray]: + """Compute the MoE auxiliary load-balancing loss. + + The score-for-aux kernel reads only ``logits_2d`` and the final + reduction reads only the (already-computed) ``tokens_per_expert``, + so the aux scores can run concurrently with the main routing + path on the GPU. + + ``logits_2d`` should be the *full* logits tensor over the global + token batch -- under EP the caller is responsible for + :func:`jax.lax.all_gather` ing the logits before calling this so + the aux_loss formula + ``loss = (E * coeff / (k * T^2)) * sum_i(sum_t(probs[t,i]) * tokens[i])`` + sees the global ``T``. + + ``tokens_per_expert`` must be the per-expert token-assignment + count from the *actual* routing decision -- i.e. derived from + ``_route_topk``'s ``routing_map``, not recomputed from a clean + top-k. This matters under DeepSeek-style routing + (``num_groups > 0`` / ``group_topk > 0``) where the + post-grouping routing differs from a plain top-k. Under EP the + caller is responsible for summing over all (ep + dp) shards + first so the count is global. + """ + if self.aux_loss_coeff <= 0.0: + return None + # The "compute_aux_scores=True" kernel intentionally ignores + # num_groups/group_topk/expert_bias and returns the dense + # post-score-function scores over all experts. Those scores are + # what the aux-loss formula expects (raw scoring, no grouping + # bias); the routing decisions used for ``tokens_per_expert`` + # come from the caller-supplied real ``routing_map``. 
+ aux_scores, _ = fused_topk_with_score_function( + logits_2d.astype(jnp.float32), + topk=self.num_experts_per_tok, + score_function=self.score_function, + compute_aux_scores=True, + ) + return fused_moe_aux_loss( + aux_scores.astype(jnp.float32), + tokens_per_expert.astype(jnp.int32), + topk=self.num_experts_per_tok, + coeff=self.aux_loss_coeff, + ) + + # ------------------------------------------------------------------ + # Global permute (route -> token dispatch) + # ------------------------------------------------------------------ + + def _global_permute( + self, + inputs_2d: jnp.ndarray, + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + ) -> GlobalPermuteResult: + """Dispatch tokens to the global expert axis. + + Returns a :class:`GlobalPermuteResult` suitable both for the + no-EP forward (where the same buffer feeds ``_expert_ffn`` + directly) and for the A2A-EP path (where the buffer is sliced + + sent over the EP axis before the FFN). The result carries the + per-backend opaque state needed to invert the dispatch in + :meth:`_global_combine`. + """ + num_tokens = inputs_2d.shape[0] + topk = self.num_experts_per_tok + + if self.permutation_backend == "pure_jax": + selected_experts, routing_weights = routing_map_to_selected_experts( + sparse_probs, routing_map, topk + ) + sorted_inputs, perm_state, group_sizes = pure_jax_token_dispatch( + inputs_2d, + selected_experts, + num_experts=self.num_experts, + num_experts_per_tok=topk, + align_size=self.align_size, + ) + return GlobalPermuteResult( + backend="pure_jax", + sorted_inputs=sorted_inputs, + group_sizes=group_sizes, + perm_state=perm_state, + routing_weights=routing_weights, + ) + + # triton + num_out_tokens = num_tokens * topk + align_size_arg = self.align_size if self.align_size > 0 else None + ( + sorted_inputs, + _permuted_probs, + row_id_map, + pad_offsets, + group_sizes, + ) = token_dispatch( + inputs_2d, + routing_map, + num_out_tokens=num_out_tokens, + probs=sparse_probs, + align_size=align_size_arg, + ) + return GlobalPermuteResult( + backend="triton", + sorted_inputs=sorted_inputs, + group_sizes=group_sizes, + row_id_map=row_id_map, + pad_offsets=pad_offsets, + merging_probs=sparse_probs, + ) + + # ------------------------------------------------------------------ + # Expert FFN (three grouped_dense calls + activation) + # ------------------------------------------------------------------ + + def _expert_ffn( + self, + sorted_inputs: jnp.ndarray, + group_sizes: jnp.ndarray, + n_groups: int, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + ) -> jnp.ndarray: + """Run the per-expert SwiGLU-style FFN over a permuted buffer. + + All ``wi_*`` / ``wo`` weights and the optional biases are passed + in as explicit args (rather than registered inline here) because + in the EP path this method runs *inside* a ``shard_map`` body + and Flax param registration must happen outside that scope. + + Parameters + ---------- + sorted_inputs : jnp.ndarray + Permuted tokens of shape ``[buffer_size, hidden]`` (rows + grouped by expert). + group_sizes : jnp.ndarray + Per-group token counts of shape ``[n_groups]``. + ``sum(group_sizes)`` must equal ``buffer_size`` (TE + ``grouped_dense`` FFI assertion at + ``transformer_engine/jax/csrc/extensions/gemm.cpp:1029``). + n_groups : int + Number of expert groups. 
Equals ``self.num_experts`` for the + no-EP path and ``num_experts // num_ep`` for the A2A-EP path. + Used to size the per-call quantizer set so the FP8 metadata + tensors match ``group_sizes``. + wi_0, wi_1, wo : jnp.ndarray + Expert weight tensors. Shapes (no-EP): + ``(num_experts, hidden, intermediate)`` for wi_*, + ``(num_experts, intermediate, hidden)`` for wo. Under EP + the leading expert dim is sliced to ``num_experts // num_ep``. + wi_0_bias, wi_1_bias, wo_bias : Optional[jnp.ndarray] + Optional per-expert biases (shape ``(num_experts, N)``); + ``grouped_dense`` adds ``bias[i]`` to the rows belonging to + expert ``i`` in the permuted layout. + + Returns + ------- + expert_outputs : jnp.ndarray + ``[buffer_size, hidden]``. + """ + # Each grouped_dense call gets its own quantizer_set with + # n_groups matching ``group_sizes``; this keeps the FP8 meta + # tensors correctly sized in both no-EP and A2A-EP cases. + q_set_w0 = self.generate_quantizer_set(postfix="_w0", n_groups=n_groups) + q_set_w1 = self.generate_quantizer_set(postfix="_w1", n_groups=n_groups) + q_set_wo = self.generate_quantizer_set(postfix="_wo", n_groups=n_groups) + + # Cast kernels to the activation dtype when no FP8 quantization + # is active (mirrors DenseGeneral). + if q_set_w0 == noop_quantizer_set: + wi_0 = wi_0.astype(sorted_inputs.dtype) + if q_set_w1 == noop_quantizer_set: + wi_1 = wi_1.astype(sorted_inputs.dtype) + if q_set_wo == noop_quantizer_set: + wo = wo.astype(sorted_inputs.dtype) + + layer_w0 = grouped_dense( + sorted_inputs, + wi_0, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_0_bias, + quantizer_set=q_set_w0, + ) + layer_w1 = grouped_dense( + sorted_inputs, + wi_1, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wi_1_bias, + quantizer_set=q_set_w1, + ) + + act_fn = _convert_to_activation_function(self.activation_type) + intermediate = act_fn(layer_w0) * layer_w1 + + expert_outputs = grouped_dense( + intermediate, + wo, + group_sizes, + contracting_dims=((1,), (1,)), + bias=wo_bias, + quantizer_set=q_set_wo, + ) + return expert_outputs + + # ------------------------------------------------------------------ + # Global combine (token combine -> back to [B, S, H]) + # ------------------------------------------------------------------ + + def _global_combine( + self, + expert_outputs: jnp.ndarray, + perm_result: GlobalPermuteResult, + batch_size: int, + sequence_length: int, + ) -> jnp.ndarray: + """Inverse of :meth:`_global_permute`. + + Gathers per-expert outputs back into ``[batch, sequence, hidden]`` + and applies the per-token weighted sum across the top-k experts. 
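+
+        Returns
+        -------
+        output : jnp.ndarray
+            Combined tensor of shape ``[batch, sequence, hidden]``.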
+ """ + if perm_result.backend == "pure_jax": + return pure_jax_token_combine( + expert_outputs, + perm_result.perm_state, + perm_result.routing_weights, + num_experts_per_tok=self.num_experts_per_tok, + batch_size=batch_size, + sequence_length=sequence_length, + ) + # triton + out_2d = token_combine( + expert_outputs, + perm_result.row_id_map, + merging_probs=perm_result.merging_probs, + pad_offsets=perm_result.pad_offsets, + ) + hidden_size = out_2d.shape[-1] + return out_2d.reshape(batch_size, sequence_length, hidden_size).astype(self.dtype) + + # ------------------------------------------------------------------ + # No-EP forward + # ------------------------------------------------------------------ + + def _forward_no_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + *, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + expert_bias: Optional[jnp.ndarray] = None, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Single-shard or DP/FSDP/TP forward (no shard_map wrapper). + + DP / FSDP / TP all flow through each TE primitive's + ``custom_partitioning`` rule -- there is no cross-primitive + collective that the rules cannot express on their own, so a + ``shard_map`` is unnecessary here. + + Sharding contract for callers + ----------------------------- + + On this no-EP path the grouped quantize and grouped GEMMs run + in the caller's outer SPMD context (no ``shard_map`` boundary). + Their custom_partitioning rules read sharding from each input's + ``NamedSharding`` and propagate consistent shardings on outputs. + Concretely: + + * ``inputs`` should be FSDP/DP-sharded on the batch dim + (``input_axes`` in :class:`MoEBlock` enforces this via a + logical ``with_sharding_constraint``). + * ``wi_*`` / ``wo`` weights should carry the logical axes + ``wi_kernel_axes`` / ``wo_kernel_axes`` so FSDP shards a + weight non-contracting dim, gathered inside ``grouped_dense`` + before the GEMM. + * The wgrad reduce-scatter (when FSDP is active) is emitted by + ``grouped_dense_bwd``'s partitioning rule; no explicit + collective is needed here. + + Without those shardings the grouped GEMM falls back to + replicated-everywhere semantics (legal but defeats FSDP/DP). + Tested in ``tests/jax/test_distributed_moe_block.py`` for the + EP=2 + FSDP=2 case; the no-EP + FSDP-only case shares the same + infra and is covered when ``expert_parallelism_axis`` is left + ``None`` in that test. + """ + batch_size, sequence_length, hidden_size = inputs.shape + inputs_2d = inputs.reshape(-1, hidden_size) + logits_2d = gate_logits.reshape(-1, self.num_experts) + + sparse_probs, routing_map = self._route_topk(logits_2d, expert_bias) + # ``tokens_per_expert`` MUST come from the real routing_map so the + # aux-loss objective matches actual routing decisions under + # DeepSeek-style num_groups/group_topk routing. 
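+        # For reference, the aux objective computed below is (up to backend
+        # details) E * coeff / (k * T^2) * sum_i( sum_t(probs[t, i]) *
+        # tokens_per_expert[i] ), with T tokens, E experts and k = topk; see
+        # the matching comment in ``_a2a_body`` for the sharded variant.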
+ tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0) + aux_loss = self._compute_aux_loss(logits_2d, tokens_per_expert) + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + expert_outputs = self._expert_ffn( + perm.sorted_inputs, + perm.group_sizes, + n_groups=self.num_experts, + wi_0=wi_0, + wi_1=wi_1, + wo=wo, + wi_0_bias=wi_0_bias, + wi_1_bias=wi_1_bias, + wo_bias=wo_bias, + ) + output = self._global_combine(expert_outputs, perm, batch_size, sequence_length) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + return output, aux_loss + + # ------------------------------------------------------------------ + # A2A (ragged-all-to-all) EP forward + # ------------------------------------------------------------------ + + def _forward_a2a_ep( + self, + inputs: jnp.ndarray, + gate_logits: jnp.ndarray, + *, + wi_0: jnp.ndarray, + wi_1: jnp.ndarray, + wo: jnp.ndarray, + wi_0_bias: Optional[jnp.ndarray] = None, + wi_1_bias: Optional[jnp.ndarray] = None, + wo_bias: Optional[jnp.ndarray] = None, + expert_bias: Optional[jnp.ndarray] = None, + ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]: + """Wrap the body in a ``shard_map`` that runs a forward + ``ragged_all_to_all`` (A2A / A2Av) around the FFN. + + For each EP shard the wrapper: + + 1. Routes the shard's local tokens **globally** over all + ``num_experts`` experts (no roll, no local-mask -- every shard + sees the full expert axis). + 2. ``all_gather`` s its per-expert ``group_sizes`` so all shards + know the complete ``[num_ep, num_experts]`` token-count matrix. + 3. Forward ``ragged_all_to_all`` over the EP axis: each shard + sends per-expert chunks to the shard that owns those experts, + and receives chunks for its own ``num_experts // num_ep`` + local experts from every other shard. + 4. Reorders the received buffer from ``(source_shard, expert)`` + to ``(expert, source_shard)`` ordering so each local expert's + tokens are contiguous. + 5. Runs the three ``grouped_dense`` calls + activation over the + ``E_local``-group buffer. + 6. Reverses the local reorder. + 7. Reverse ``ragged_all_to_all`` over EP returns each shard's + token outputs to it. + 8. Inverts the global permute and applies the top-k weighted sum. + """ + from jax.experimental.shard_map import shard_map + + ep_axis = self.expert_parallelism_axis + if self.mesh is None: + raise ValueError( + "MoEBlock.expert_parallelism_axis is set; `mesh` must also" + " be provided so the EP shard_map can be built." + ) + mesh = self.mesh + num_ep = mesh.shape[ep_axis] + assert ( + self.num_experts % num_ep == 0 + ), f"num_experts={self.num_experts} must be divisible by EP size={num_ep}" + num_experts_local = self.num_experts // num_ep + + # Compose the BATCH sharding axis tuple. ``ep`` is always part of + # the batch axis (so ragged_all_to_all has data to route); any + # ``data_parallelism_axes`` are added on top so the per-device + # batch slice is genuinely unique (true FSDP / DP). + # Examples: + # data_parallelism_axes=() -> P('ep', None, None) + # data_parallelism_axes=('fsdp',) -> P(('ep','fsdp'), None, None) + # data_parallelism_axes=('fsdp','data') -> P(('ep','fsdp','data'), ...) 
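+        # Illustrative sizing: with num_ep=2 and data_parallelism_axes=("fsdp",)
+        # on a 2-way fsdp axis, a global batch of 8 yields per-device slices of
+        # 8 / (2 * 2) = 2 rows, each made of unique tokens.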
+ for ax in self.data_parallelism_axes: + if ax not in mesh.shape: + raise ValueError( + f"data_parallelism_axes contains {ax!r} but mesh has" + f" axes {tuple(mesh.shape.keys())}" + ) + if len(self.data_parallelism_axes) == 0: + batch_pspec_axis: Any = ep_axis + else: + batch_pspec_axis = (ep_axis, *self.data_parallelism_axes) + # The size by which the per-device batch is divided BEYOND ep. + # Used to tighten the worst-case ragged_all_to_all recv buffer: + # at most ``num_ep`` peers each send their entire local + # ``B/(num_ep*dp_size)*S*topk`` token-expert pairs, so the worst + # recv per device is ``num_ep * B/(num_ep*dp_size)*S*topk + # = B/dp_size * S * topk``. + dp_size = 1 + for ax in self.data_parallelism_axes: + dp_size *= mesh.shape[ax] + + global_batch_size, sequence_length, _hidden = inputs.shape + topk = self.num_experts_per_tok + # The shard_map's ``in_specs=P((ep, *dp_axes), ...)`` requires the + # batch dim to be divisible by ``num_ep * dp_size``; check upfront + # here for a clearer error than the one shard_map would raise at + # trace time. + batch_divisor = num_ep * dp_size + if global_batch_size % batch_divisor != 0: + raise ValueError( + f"batch={global_batch_size} not divisible by prod(data_parallelism_axes)={dp_size}" + ) + recv_buffer_rows = (global_batch_size // dp_size) * sequence_length * topk + + # Pack everything that crosses the shard_map boundary into a dict + # pytree. shard_map fully supports pytrees: ``in_specs`` must + # structurally match ``captured`` and we build them in lockstep + # so adding/removing an optional bias is one ``dict[name] = ...``. + # Params must be packed here (rather than passed inline by + # ``self.param`` inside the body) because Flax variable scopes + # must not be entered from inside a JAX transform's body. + captured: dict = { + "inputs": inputs, + "gate_logits": gate_logits, + "wi_0": wi_0, + "wi_1": wi_1, + "wo": wo, + } + in_specs: dict = { + "inputs": P(batch_pspec_axis, None, None), + "gate_logits": P(batch_pspec_axis, None, None), + "wi_0": P(ep_axis, None, None), + "wi_1": P(ep_axis, None, None), + "wo": P(ep_axis, None, None), + } + if expert_bias is not None: + captured["expert_bias"] = expert_bias + in_specs["expert_bias"] = P(ep_axis) + if wi_0_bias is not None: + captured["wi_0_bias"] = wi_0_bias + captured["wi_1_bias"] = wi_1_bias + captured["wo_bias"] = wo_bias + for name in ("wi_0_bias", "wi_1_bias", "wo_bias"): + in_specs[name] = P(ep_axis, None) + + a2a_body = partial( + self._a2a_body, + ep_axis=ep_axis, + num_ep=num_ep, + num_experts_local=num_experts_local, + recv_buffer_rows=recv_buffer_rows, + ) + + # ``check_rep=False`` disables shard_map's invariant that any + # output declared as ``P()`` is replicated across ``ep_axis``. + # We use ``axis_index(ep_axis)`` inside ``_a2a_body`` so the + # body is genuinely non-replicated, which would otherwise + # (correctly) fail the check. ``ragged_all_to_all`` already + # produces the right cross-shard semantics; this is the standard + # JAX escape hatch when collectives + per-shard logic coexist. + return shard_map( + a2a_body, + mesh=mesh, + in_specs=(in_specs,), + out_specs=(P(batch_pspec_axis, None, None), P()), + check_rep=False, + )(captured) + + # ------------------------------------------------------------------ + # Body of the per-shard A2A-EP forward (extracted from + # :meth:`_forward_a2a_ep` for readability). 
Runs *inside* the + # ``shard_map`` and is therefore in EP-manual mode: collectives over + # ``ep_axis`` are explicit, the rest of the mesh stays in auto mode. + # ------------------------------------------------------------------ + + def _a2a_body( + self, + local: dict, + *, + ep_axis: str, + num_ep: int, + num_experts_local: int, + recv_buffer_rows: int, + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + shard_id = jax.lax.axis_index(ep_axis) + + # -- Stage 1: per-shard route + global permute over all E -- + # Inside the shard_map body each input has its EP axis already + # consumed, so ``local_inputs.shape == [B/num_ep, S, H]``. + local_inputs = local["inputs"] + local_logits = local["gate_logits"] + local_b, local_s, local_h = local_inputs.shape + inputs_2d = local_inputs.reshape(-1, local_h) + logits_2d = local_logits.reshape(-1, self.num_experts) + + # The router operates over the full expert axis, so the + # EP-sharded ``expert_bias`` (in_spec ``P(ep_axis)``) must be + # all-gathered before being passed in. + if "expert_bias" in local: + full_expert_bias = jax.lax.all_gather( + local["expert_bias"], axis_name=ep_axis, tiled=True + ) + else: + full_expert_bias = None + sparse_probs, routing_map = self._route_topk(logits_2d, full_expert_bias) + + # aux_loss must see the global token batch and the global + # tokens_per_expert: its formula ``E*coeff/(k*T^2) * sum_i( + # sum_t(probs[t,i]) * tokens[i])`` is not shard-decomposable + # (the sum_t * tokens product is data-dependent across + # shards). We need a *single* collective: + # * ``all_gather`` logits over (ep + any DP axes) so both + # (a) the score-for-aux kernel and (b) a re-run of + # ``_route_topk`` see the full token batch. The re-run + # gives us the global per-expert token count directly, + # avoiding a separate ``psum``. Two consecutive global + # collectives over the same replica group at the very + # start of the program have been observed to deadlock + # under FP8 autocast on some XLA + NCCL combinations, + # so we keep this branch to one collective. + # The aux branch has no data dependency on the main routing + # path beyond what is already gathered, so XLA can overlap + # the two routings on the GPU. + if self.aux_loss_coeff > 0.0: + # ``axis_name`` accepts a tuple ⇒ a single collective + # over the cartesian product of axes; XLA may lower + # this to one multi-axis op or split it. + if len(self.data_parallelism_axes) == 0: + aux_collective_axes: Any = ep_axis + else: + aux_collective_axes = (ep_axis, *self.data_parallelism_axes) + global_logits_2d = jax.lax.all_gather( + logits_2d, axis_name=aux_collective_axes, axis=0, tiled=True + ) + # Re-run topk on the gathered logits to obtain the + # *global* routing_map post-grouping (respects + # num_groups/group_topk/expert_bias just like the local + # routing). Summing over the global token dim gives the + # exact same counts as ``psum(local_tokens_per_expert)`` + # without an extra collective. The duplicate topk + # compute is small relative to the FFNs. 
+ _, global_routing_map = self._route_topk(global_logits_2d, full_expert_bias) + global_tokens_per_expert = jnp.sum(global_routing_map.astype(jnp.int32), axis=0) + aux_loss = self._compute_aux_loss(global_logits_2d, global_tokens_per_expert) + else: + aux_loss = None + + perm = self._global_permute(inputs_2d, sparse_probs, routing_map) + global_group_sizes = perm.group_sizes # [E] + + # -- Stage 2: gather per-expert counts across the EP axis -- + all_shards_tokens_per_expert = jax.lax.all_gather( + global_group_sizes[None, :], + axis_name=ep_axis, + axis=0, + tiled=True, + ) # [num_ep, num_experts] + + # -- Stage 3: forward ragged_all_to_all over EP -- + in_off, send_sz, out_off, recv_sz = compute_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + recv_buf = jnp.zeros( + (recv_buffer_rows, local_h), + dtype=perm.sorted_inputs.dtype, + ) + x_recv = jax.lax.ragged_all_to_all( + perm.sorted_inputs, + recv_buf, + in_off, + send_sz, + out_off, + recv_sz, + axis_name=ep_axis, + ) + + # -- Stage 4: local permute (source_shard, expert) -> (expert, shard) + sorted_x, local_group_sizes, local_perm_state = local_permute_after_a2a( + x_recv, + all_shards_tokens_per_expert, + shard_id, + num_ep, + ) + + # -- Stage 5: per-expert FFN (E_local groups) -- + expert_outputs = self._expert_ffn( + sorted_x, + local_group_sizes, + n_groups=num_experts_local, + wi_0=local["wi_0"], + wi_1=local["wi_1"], + wo=local["wo"], + wi_0_bias=local.get("wi_0_bias"), + wi_1_bias=local.get("wi_1_bias"), + wo_bias=local.get("wo_bias"), + ) + + # -- Stage 6: invert local permute -- + x_send_back = local_unpermute_before_a2a(expert_outputs, local_perm_state) + + # -- Stage 7: reverse ragged_all_to_all over EP -- + in_off_r, send_sz_r, out_off_r, recv_sz_r = compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert, shard_id, num_ep + ) + send_back_buf = jnp.zeros_like(perm.sorted_inputs) + y_back = jax.lax.ragged_all_to_all( + x_send_back, + send_back_buf, + in_off_r, + send_sz_r, + out_off_r, + recv_sz_r, + axis_name=ep_axis, + ) + + # -- Stage 8: invert global permute, weighted sum over top-k -- + output = self._global_combine(y_back, perm, batch_size=local_b, sequence_length=local_s) + + if self.tensor_parallelism_axis is not None: + output = jax.lax.psum_scatter( + output, + self.tensor_parallelism_axis, + scatter_dimension=2, + tiled=True, + ) + + # ``out_specs`` must match the returned pytree structurally, + # so always emit a real scalar for aux_loss; the outer + # ``__call__`` re-strips it to None when aux_loss_coeff <= 0. + if aux_loss is None: + aux_loss = jnp.zeros((), dtype=self.dtype) + return output, aux_loss diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 81972aac0f..9fbaf64736 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -7,6 +7,19 @@ This module provides high-level token dispatch and combine operations for Mixture of Experts (MoE) models with proper automatic differentiation support. +Two backends are offered: + +* Triton-backed ``token_dispatch`` / ``token_combine`` - uses the + Triton kernels in ``transformer_engine.jax.triton_extensions.permutation``. +* Pure-JAX ``pure_jax_token_dispatch`` / ``pure_jax_token_combine`` - uses + only ``jnp.argsort`` + gather and is therefore compiled as plain XLA. + Despite the name, this path is often *faster* than the Triton kernels in + current testing because XLA can fuse the ops with surrounding work. 
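+
+A minimal round trip with the pure-JAX backend looks like the following
+(``run_experts`` and all shapes below are illustrative placeholders, not part
+of this module)::
+
+    sorted_x, state, group_sizes = pure_jax_token_dispatch(
+        x, selected_experts, num_experts=8, num_experts_per_tok=2, align_size=16
+    )
+    expert_out = run_experts(sorted_x, group_sizes)  # e.g. grouped_dense calls
+    out = pure_jax_token_combine(
+        expert_out, state, routing_weights, num_experts_per_tok=2,
+        batch_size=batch, sequence_length=seq,
+    )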
+ +Both backends support optional alignment padding (``align_size > 0``) so each +expert's group size is a multiple of ``align_size``, which is required for +quantized grouped GEMMs. + Token Dispatch (Permute): - Forward: Permute tokens according to routing map (scatter to experts) - Backward: Unpermute gradients (gather from experts) @@ -17,7 +30,7 @@ """ from functools import partial -from typing import Optional, Tuple +from typing import NamedTuple, Optional, Tuple import jax import jax.numpy as jnp @@ -38,6 +51,15 @@ "token_dispatch", "token_combine", "sort_chunks_by_index", + "pure_jax_token_dispatch", + "pure_jax_token_combine", + "PureJaxPermState", + # Ragged-all-to-all expert-parallelism helpers + "compute_ragged_all_to_all_params", + "compute_reverse_ragged_all_to_all_params", + "local_permute_after_a2a", + "local_unpermute_before_a2a", + "routing_map_to_selected_experts", ] @@ -655,3 +677,642 @@ def _sort_chunks_by_index_bwd_rule( _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule) + + +# ============================================================================= +# Pure-JAX token dispatch / combine +# ============================================================================= +# +# The following implementations use only ``jnp.argsort`` + gather and compile +# to plain XLA. They are a drop-in alternative to ``token_dispatch`` / +# ``token_combine`` above, differing only in input/output conventions (the +# Triton path takes ``routing_map`` and ``sparse_probs`` over all experts; the +# pure-JAX path takes dense ``selected_experts`` and per-token ``weights`` of +# shape ``[..., topk]``). +# +# Note: despite Triton being fused and pure-JAX being a sequence of XLA ops, +# the pure-JAX backend is often *faster* in current testing because XLA can +# fuse these ops into the surrounding work. + + +# ----------------------------------------------------------------------------- +# Custom-VJP argsort-based gather. +# +# ``inputs[sort_indices]`` has a known inverse: ``output[argsort(sort_indices)]``. +# Using a custom VJP lets the backward pass exploit that inverse instead of +# relying on the compiler to discover it from the scatter-style default +# gradient of a gather, which is typically less efficient. + + +@jax.custom_vjp +def _sort_activations(inputs: jax.Array, sort_indices: jax.Array) -> jax.Array: + """Sort ``inputs`` along the leading dim by ``sort_indices``.""" + assert ( + inputs.shape[0] == sort_indices.shape[0] + ), f"inputs.shape[0]={inputs.shape[0]} must match sort_indices.shape[0]={sort_indices.shape[0]}" + with jax.named_scope("pure_jax_sort_activations"): + return inputs[sort_indices, ...] + + +def _sort_activations_fwd( + inputs: jax.Array, sort_indices: jax.Array +) -> Tuple[jax.Array, jax.Array]: + return _sort_activations(inputs, sort_indices), sort_indices + + +def _sort_activations_bwd(residuals: jax.Array, grads: jax.Array) -> Tuple[jax.Array, None]: + sort_indices = residuals + # Inverse permutation: gather-by-argsort undoes the forward gather. + return _sort_activations(grads, jnp.argsort(sort_indices)), None + + +_sort_activations.defvjp(_sort_activations_fwd, _sort_activations_bwd) + + +def routing_map_to_selected_experts( + sparse_probs: jnp.ndarray, + routing_map: jnp.ndarray, + topk: int, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Convert ``(sparse_probs, routing_map)`` from TE's fused router to the + ``(selected_experts, weights)`` format consumed by + :func:`pure_jax_token_dispatch`. 
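+
+    For example (illustrative), a token whose ``routing_map`` row is
+    ``[True, False, True, False]`` with ``topk=2`` yields the
+    ``selected_experts`` row ``[0, 2]``, and ``weights`` gathers the matching
+    two entries of ``sparse_probs``.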
+ + ``routing_map`` is a boolean mask of shape ``[num_tokens, num_experts]`` + with exactly ``topk`` ``True`` positions per row. + """ + # Argsort on a bool tensor places ``True`` rows last (False=0 < True=1), + # so the last ``topk`` indices are the selected expert IDs. + selected_experts = jnp.argsort(routing_map, axis=-1)[..., -topk:] + weights = jnp.take_along_axis(sparse_probs, selected_experts, axis=-1) + return selected_experts, weights + + +# ----------------------------------------------------------------------------- +# Permutation state carried from dispatch to combine. + + +class PureJaxPermState(NamedTuple): + """Opaque state produced by :func:`pure_jax_token_dispatch`. + + Attributes + ---------- + sorted_indices : jnp.ndarray + The argsort indices used in the forward sort. Needed to reverse the + permutation in :func:`pure_jax_token_combine`. Shape + ``[num_real_tokens + padding_size]``. + num_real_tokens : int + Number of real (non-padding) permuted tokens, i.e. + ``batch_size * sequence_length * num_experts_per_tok``. Compile-time + constant. + padding_size : int + Number of alignment-padding tokens appended to the sort buffer. Equals + ``num_experts * (align_size - 1)`` when ``align_size > 0``, else ``0``. + Compile-time constant. + """ + + sorted_indices: jax.Array + num_real_tokens: int + padding_size: int + + +# ----------------------------------------------------------------------------- +# Dispatch (permute) + + +def pure_jax_token_dispatch( + inputs: jnp.ndarray, + selected_experts: jnp.ndarray, + num_experts: int, + num_experts_per_tok: int, + align_size: int = 0, + roll_to_expert_id: Optional[int] = None, +) -> Tuple[jnp.ndarray, PureJaxPermState, jnp.ndarray]: + """Pure-JAX ``argsort``-based token dispatch. + + Parameters + ---------- + inputs : jnp.ndarray + Input tensor of shape ``[num_tokens, hidden_size]`` (or + ``[batch, seq, hidden]``; it will be flattened). + selected_experts : jnp.ndarray + Per-token expert IDs, shape ``[num_tokens, num_experts_per_tok]`` (or + ``[batch, seq, num_experts_per_tok]``). Integer dtype. + num_experts : int + Total number of experts. + num_experts_per_tok : int + Top-k. Must equal ``selected_experts.shape[-1]``. + align_size : int, default 0 + Alignment for each expert's group size. ``0`` disables padding; a value + ``> 0`` appends a static-size padding buffer so each resulting group + size is a multiple of ``align_size`` (required for quantized grouped + GEMM). + roll_to_expert_id : Optional[int] + If provided, rotates expert IDs by ``-roll_to_expert_id`` modulo + ``num_experts`` before the sort (ring-of-experts EP). The returned + ``group_sizes`` is rolled to match. + + Returns + ------- + sorted_inputs : jnp.ndarray + Permuted tokens grouped by expert, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : PureJaxPermState + State needed by :func:`pure_jax_token_combine`. + group_sizes : jnp.ndarray + Token count per expert, shape ``[num_experts]``. Each entry is a + multiple of ``align_size`` when ``align_size > 0``. 
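+
+    For example (illustrative numbers), with ``align_size=4`` and per-expert
+    token counts ``[5, 3, 4]``, the returned ``group_sizes`` is ``[8, 4, 4]``:
+    each expert is padded up to the next multiple of 4, and an already-aligned
+    expert receives no padding.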
+ """ + assert num_experts_per_tok == selected_experts.shape[-1], ( + f"num_experts_per_tok={num_experts_per_tok} must match" + f" selected_experts.shape[-1]={selected_experts.shape[-1]}" + ) + assert align_size >= 0, f"align_size must be >= 0, got {align_size}" + + hidden_size = inputs.shape[-1] + inputs_2d = inputs.reshape(-1, hidden_size) + num_tokens = inputs_2d.shape[0] + num_real_tokens = num_tokens * num_experts_per_tok + + flatten_selected_experts = jnp.ravel(selected_experts) + + if align_size > 0: + # Per-expert token count, and how many extra tokens each expert needs + # to become aligned to ``align_size``. Using + # ``(align - count % align) % align`` gives 0 (not ``align``) when + # already aligned, so we never exceed the per-expert slot capacity of + # ``align_size - 1``. + token_count_per_expert = jnp.bincount(flatten_selected_experts, length=num_experts) + padding_tokens_required_per_expert = ( + align_size - (token_count_per_expert % align_size) + ) % align_size + + # Build a static-size padding buffer of shape + # ``[num_experts * (align_size - 1)]``. Each expert ``i`` owns a slot + # of ``align_size - 1`` positions (worst-case padding, which occurs + # when ``token_count[i] % align_size == 1``). Within slot ``i``, + # positions ``[0, padding_needed)`` are assigned expert ``i`` and act + # as real padding; the rest are assigned to ``num_experts - 1`` as + # overflow placeholders that keep the buffer statically sized for JIT. + max_padding_per_expert = align_size - 1 + max_total_padding_size = num_experts * max_padding_per_expert + positions = jnp.arange(max_total_padding_size) + expert_for_pos = positions // max_padding_per_expert + offset_in_slot = positions % max_padding_per_expert + padding_needed = padding_tokens_required_per_expert[expert_for_pos] + flatten_padding_selected_experts = jnp.where( + offset_in_slot < padding_needed, + expert_for_pos, + num_experts - 1, + ) + + flatten_selected_experts = jnp.concatenate( + [flatten_selected_experts, flatten_padding_selected_experts], axis=0 + ) + + if roll_to_expert_id is not None: + flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + # Pad inputs with zeros so the sort operand shape matches the expanded + # selected-experts vector. + replicated_inputs_2d = jnp.pad( + replicated_inputs_2d, + pad_width=((0, max_total_padding_size), (0, 0)), + mode="constant", + constant_values=0.0, + ) + + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + # Compute ``group_sizes`` directly from counts rather than via + # ``bincount(flatten_selected_experts)``: the overflow placeholder + # tokens would inflate ``group_sizes[num_experts - 1]``, breaking the + # alignment guarantee. Direct computation gives each expert exactly + # ``ceil(count / align) * align`` tokens. 
+ group_sizes = token_count_per_expert + padding_tokens_required_per_expert + + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = max_total_padding_size + else: + if roll_to_expert_id is not None: + flatten_selected_experts = (flatten_selected_experts - roll_to_expert_id) % num_experts + + sorted_selected_experts = jnp.argsort(flatten_selected_experts) + + replicated_inputs_2d = jnp.repeat(inputs_2d, num_experts_per_tok, axis=0) + sorted_inputs = _sort_activations(replicated_inputs_2d, sorted_selected_experts) + + group_sizes = jnp.bincount(flatten_selected_experts, length=num_experts) + if roll_to_expert_id is not None: + group_sizes = jnp.roll(group_sizes, -roll_to_expert_id) + + padding_size = 0 + + perm_state = PureJaxPermState( + sorted_indices=sorted_selected_experts, + num_real_tokens=num_real_tokens, + padding_size=padding_size, + ) + return sorted_inputs, perm_state, group_sizes + + +# ----------------------------------------------------------------------------- +# Combine (unpermute + weighted sum) + + +def pure_jax_token_combine( + expert_outputs: jnp.ndarray, + perm_state: PureJaxPermState, + routing_weights: jnp.ndarray, + num_experts_per_tok: int, + batch_size: int, + sequence_length: int, +) -> jnp.ndarray: + """Pure-JAX ``argsort``-based token combine. + + Reverses the permutation performed by :func:`pure_jax_token_dispatch`, + strips any alignment-padding rows appended during dispatch, and applies a + per-token weighted sum across the top-k experts. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the expert FFN, shape + ``[num_real_tokens + padding_size, hidden_size]``. + perm_state : PureJaxPermState + State returned by :func:`pure_jax_token_dispatch`. + routing_weights : jnp.ndarray + Top-k routing weights, shape ``[batch*seq, num_experts_per_tok]`` + (or broadcastable to it after a ``reshape``). + num_experts_per_tok : int + Top-k. + batch_size : int + Original batch size. + sequence_length : int + Original sequence length. + + Returns + ------- + output : jnp.ndarray + Combined output tensor of shape ``[batch_size, sequence_length, hidden_size]``. + """ + # Reverse the permutation: ``output[argsort(sorted_indices)]`` undoes + # ``input[sorted_indices]``. + unsort_intermediate = _sort_activations( + expert_outputs, + jnp.argsort(perm_state.sorted_indices), + ) + + # Strip alignment padding tokens appended during dispatch. After unsorting, + # the first ``num_real_tokens`` rows hold the real per-(token, top-k) + # outputs; any trailing rows are padding placeholders (zeros) and must be + # discarded before the reshape below. + if perm_state.padding_size > 0: + unsort_intermediate = unsort_intermediate[: perm_state.num_real_tokens] + + hidden_size = unsort_intermediate.shape[-1] + reshaped_weights = jnp.reshape(routing_weights, (-1, num_experts_per_tok)) + reshaped_intermediate = jnp.reshape( + unsort_intermediate, (reshaped_weights.shape[0], num_experts_per_tok, hidden_size) + ) + + # Cast weights to match intermediate dtype (weighted sum happens in + # intermediate dtype; callers can upcast before calling if higher + # precision weight-sum is desired). 
+ reshaped_weights = reshaped_weights.astype(reshaped_intermediate.dtype) + with jax.named_scope("pure_jax_weight_sum"): + output = jnp.einsum( + "BKE,BK -> BE", + reshaped_intermediate, + reshaped_weights, + ) + return output.reshape(batch_size, sequence_length, hidden_size) + + +# ============================================================================= +# Ragged-all-to-all expert-parallelism helpers +# ============================================================================= +# +# These helpers support the ragged-all-to-all (A2A / A2Av) EP strategy used by +# :class:`transformer_engine.jax.flax.MoEBlock`. The forward EP path looks +# like:: +# +# route -> global_permute -> AG(group_sizes, ep) +# -> ragged_all_to_all(fwd, ep) +# -> local_permute_after_a2a +# -> grouped_dense x3 + activation +# -> local_unpermute_before_a2a +# -> ragged_all_to_all(reverse, ep) +# -> global_combine +# +# The two ``compute_*_ragged_all_to_all_params`` functions translate +# ``all_shards_tokens_per_expert`` (an EP-axis ``all_gather`` of each shard's +# global ``group_sizes``) into the four ``ragged_all_to_all`` arguments +# (``input_offsets``, ``send_sizes``, ``output_offsets``, ``recv_sizes``). +# ``shard_id`` may be a traced value (e.g. from :func:`jax.lax.axis_index`), +# which is why every slice into ``all_shards_tokens_per_expert`` uses +# :func:`jax.lax.dynamic_slice`. +# +# These functions are pure JAX (no MaxText / TE dependencies) and equivalent +# to :func:`maxtext.layers.te_permutation.compute_ragged_all_to_all_params` +# / :func:`compute_reverse_ragged_all_to_all_params`. + + +def compute_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Forward-direction ragged_all_to_all parameters. + + Computes the four index/size arrays that :func:`jax.lax.ragged_all_to_all` + consumes for the **forward** EP shuffle, where each shard sends its + expert-grouped tokens to the shard that owns those experts. + + Parameters + ---------- + all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts gathered across the EP axis. Shape + ``[num_expert_shards, num_experts]`` and integer dtype. + shard_id : jnp.ndarray + Index of the current shard along the EP axis (typically + :func:`jax.lax.axis_index` of the EP axis). Must be a 0-d integer. + num_expert_shards : int + Static EP-axis size. Must match + ``all_shards_tokens_per_expert.shape[0]``. + + Returns + ------- + input_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. Cumulative ``send_sizes`` (with a + leading 0) -- where in the local source buffer each destination + shard's chunk begins. + send_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``send_sizes[i]`` is the number of + tokens this shard sends to shard ``i`` (= the sum of token counts + for the experts owned by shard ``i``). + output_offsets : jnp.ndarray + Shape ``[num_expert_shards]``. ``output_offsets[i]`` is the row in + shard ``i``'s receive buffer where this shard's contribution should + land. Sender-side semantics, per :func:`jax.lax.ragged_all_to_all`. + recv_sizes : jnp.ndarray + Shape ``[num_expert_shards]``. ``recv_sizes[i]`` is the number of + tokens shard ``i`` sends to this shard. 
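+
+    Example (illustrative numbers): with ``num_expert_shards=2``, 4 experts,
+    and ``all_shards_tokens_per_expert = [[3, 1, 2, 4], [2, 2, 1, 1]]``,
+    shard 0 owns experts {0, 1} and shard 1 owns experts {2, 3}. For
+    ``shard_id=0`` this gives ``send_sizes=[4, 6]`` (3+1 rows stay local,
+    2+4 go to shard 1), ``input_offsets=[0, 4]``, ``recv_sizes=[4, 4]``
+    (3+1 from itself, 2+2 from shard 1) and ``output_offsets=[0, 0]``
+    (shard 0's chunks land first in both receive buffers).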
+ """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" + local_expert_size = num_experts // num_expert_shards + + # This shard's row of the gathered table, reshaped so axis 0 indexes the + # destination shard and axis 1 indexes its local experts. + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape(num_expert_shards, local_expert_size) + + # send_sizes[i] = sum of token counts for shard i's experts in our buffer. + send_sizes = jnp.sum(local_reshaped, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens shard i sends to this shard, i.e. the + # sum across our local-expert columns of shard i's row. + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + recv_sizes = jnp.sum(local_expert_columns, axis=1) + + # output_offsets uses sender-side semantics for ragged_all_to_all: + # output_offsets[j] = row in shard j's buffer where THIS shard's chunk + # should be placed. That's the cumulative sum (over source shards 0..j-1) + # of how many tokens those earlier source shards already sent to shard j. + sends_to_target = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # [src_shard, dst_shard] + zero_row = jnp.zeros((1, num_expert_shards), dtype=sends_to_target.dtype) + cumulated = jnp.cumsum( + jnp.concatenate([zero_row, sends_to_target], axis=0), + axis=0, + dtype=sends_to_target.dtype, + ) # [src_shard + 1, dst_shard]; row r = total sent by sources 0..r-1 + output_offsets = jax.lax.dynamic_slice( + cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +def compute_reverse_ragged_all_to_all_params( + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Reverse-direction ragged_all_to_all parameters. + + Mirror of :func:`compute_ragged_all_to_all_params` for the **reverse** + EP shuffle that returns expert outputs to their source shards. The + sender / receiver roles are swapped: what we received in the forward + shuffle we now send back, and vice versa. + + Parameters and shapes are identical to + :func:`compute_ragged_all_to_all_params`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" + local_expert_size = num_experts // num_expert_shards + + local_expert_start = shard_id * local_expert_size + + # In reverse, what we received becomes what we send. send_sizes[i] is how + # many tokens we send back to source shard i (= what shard i originally + # sent us, summed across our local experts). 
+ local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + send_sizes = jnp.sum(local_expert_columns, axis=1) + input_offsets = jnp.concatenate( + [ + jnp.array([0], dtype=send_sizes.dtype), + jnp.cumsum(send_sizes)[:-1], + ] + ) + + # recv_sizes[i] = how many tokens we receive back from shard i (= what + # we originally sent to shard i in the forward). + local_tokens_per_expert = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(shard_id, 0), + slice_sizes=(1, num_experts), + ).squeeze(0) + local_reshaped = local_tokens_per_expert.reshape(num_expert_shards, local_expert_size) + recv_sizes = jnp.sum(local_reshaped, axis=1) + + # output_offsets: the reverse sends-to-target matrix is the transpose of + # the forward one (row i = what shard i sends in reverse = what shard i + # received in forward). Cumsum down source-shard axis, then index our row. + fwd_sends_to = jnp.sum( + all_shards_tokens_per_expert.reshape( + num_expert_shards, num_expert_shards, local_expert_size + ), + axis=2, + ) # forward: [src, dst] + rev_sends_to = jnp.transpose(fwd_sends_to) # reverse: [src, dst] + zero_row = jnp.zeros((1, num_expert_shards), dtype=rev_sends_to.dtype) + rev_cumulated = jnp.cumsum( + jnp.concatenate([zero_row, rev_sends_to], axis=0), + axis=0, + dtype=rev_sends_to.dtype, + ) + output_offsets = jax.lax.dynamic_slice( + rev_cumulated, + start_indices=(shard_id, 0), + slice_sizes=(1, num_expert_shards), + ).squeeze(0) + + return input_offsets, send_sizes, output_offsets, recv_sizes + + +# ----------------------------------------------------------------------------- +# Local permute / unpermute +# ----------------------------------------------------------------------------- +# +# After the forward ragged_all_to_all the receive buffer is laid out as +# ``[from_shard_0_chunk | from_shard_1_chunk | ... ]`` and within each chunk +# tokens are sorted by local-expert id. To feed ``grouped_dense`` we want +# ``[expert_0_block | expert_1_block | ... ]`` where each expert's block +# contains tokens from every source shard. ``local_permute_after_a2a`` +# performs that reorder; ``local_unpermute_before_a2a`` undoes it before the +# reverse ragged_all_to_all. +# +# Implementation uses :func:`sort_chunks_by_index`, which is Triton-backed +# (see ``transformer_engine.jax.triton_extensions.permutation``) and has a +# paired custom-VJP backward. There is no pure-JAX alternative here -- the +# global :func:`pure_jax_token_dispatch` / :func:`token_dispatch` choice is +# unaffected by this; only the (small) post-A2A chunk reorder uses Triton +# unconditionally. + + +def local_permute_after_a2a( + x_recv: jnp.ndarray, + all_shards_tokens_per_expert: jnp.ndarray, + shard_id: jnp.ndarray, + num_expert_shards: int, +) -> Tuple[jnp.ndarray, jnp.ndarray, dict]: + """Reorder tokens received via ragged_all_to_all so each local expert's + tokens are contiguous. + + This is the EP-side complement to the global :func:`token_dispatch` / + :func:`pure_jax_token_dispatch`. Internally uses + :func:`sort_chunks_by_index` (Triton-backed) for both the forward sort + and -- via :func:`local_unpermute_before_a2a` -- the inverse. + + Parameters + ---------- + x_recv : jnp.ndarray + Output of the forward ``ragged_all_to_all`` of shape + ``[buffer_size, hidden_size]``. Layout: source-shard major, then + local-expert id within each source chunk. 
+ all_shards_tokens_per_expert : jnp.ndarray + Per-shard, per-expert token counts of shape + ``[num_expert_shards, num_experts]``. + shard_id : jnp.ndarray + Current EP shard index (typically a traced + :func:`jax.lax.axis_index`). + num_expert_shards : int + Static EP-axis size. + + Returns + ------- + sorted_x : jnp.ndarray + Tokens reordered into expert-major layout. Same shape as ``x_recv``. + local_group_sizes : jnp.ndarray + Per-local-expert token counts of shape ``[local_expert_size]``. + state : dict + Opaque state for :func:`local_unpermute_before_a2a`. + """ + num_experts = all_shards_tokens_per_expert.shape[1] + assert ( + num_experts % num_expert_shards == 0 + ), f"num_experts={num_experts} must be divisible by num_expert_shards={num_expert_shards}" + local_expert_size = num_experts // num_expert_shards + local_expert_start = shard_id * local_expert_size + local_expert_columns = jax.lax.dynamic_slice( + all_shards_tokens_per_expert, + start_indices=(0, local_expert_start), + slice_sizes=(num_expert_shards, local_expert_size), + ) + + # Flat sizes in source-major order, matching the receive buffer layout: + # [(s0,e0), (s0,e1), ..., (s1,e0), (s1,e1), ...] + split_sizes = local_expert_columns.reshape(-1) + + # Permutation that maps source-major -> expert-major: + # original index = s * E_local + e + # target index = e * num_shards + s + indices_matrix = jnp.arange(num_expert_shards * local_expert_size, dtype=jnp.int32).reshape( + num_expert_shards, local_expert_size + ) + sorted_chunk_indices = indices_matrix.T.reshape(-1) + + sorted_x, _ = sort_chunks_by_index(x_recv, split_sizes, sorted_chunk_indices) + sorted_split_sizes = split_sizes[sorted_chunk_indices] + inverse_chunk_indices = jnp.argsort(sorted_chunk_indices) + local_group_sizes = jnp.sum(local_expert_columns, axis=0) + state = { + "sorted_split_sizes": sorted_split_sizes, + "inverse_chunk_indices": inverse_chunk_indices, + } + return sorted_x, local_group_sizes, state + + +def local_unpermute_before_a2a( + expert_outputs: jnp.ndarray, + state: dict, +) -> jnp.ndarray: + """Inverse of :func:`local_permute_after_a2a`. + + Parameters + ---------- + expert_outputs : jnp.ndarray + Output of the local expert FFN of shape ``[buffer_size, hidden_size]``, + in expert-major layout. + state : dict + Opaque state returned by :func:`local_permute_after_a2a`. + + Returns + ------- + unsorted_x : jnp.ndarray + Tokens reordered back into source-shard-major layout, ready for the + reverse ``ragged_all_to_all``. Same shape as ``expert_outputs``. + """ + out, _ = sort_chunks_by_index( + expert_outputs, + state["sorted_split_sizes"], + state["inverse_chunk_indices"], + ) + return out
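+
+
+# Illustrative flow of the helpers above inside a shard_map body over an "ep"
+# axis (mirrors MoEBlock's A2A-EP forward; ``ffn`` and ``recv_buf`` are
+# placeholders for the per-expert grouped GEMMs and a preallocated zero buffer):
+#
+#     shard_id = jax.lax.axis_index("ep")
+#     counts = jax.lax.all_gather(group_sizes[None, :], axis_name="ep", axis=0, tiled=True)
+#     fwd = compute_ragged_all_to_all_params(counts, shard_id, num_ep)
+#     x_recv = jax.lax.ragged_all_to_all(sorted_inputs, recv_buf, *fwd, axis_name="ep")
+#     sorted_x, local_sizes, st = local_permute_after_a2a(x_recv, counts, shard_id, num_ep)
+#     y = local_unpermute_before_a2a(ffn(sorted_x, local_sizes), st)
+#     rev = compute_reverse_ragged_all_to_all_params(counts, shard_id, num_ep)
+#     out = jax.lax.ragged_all_to_all(y, jnp.zeros_like(sorted_inputs), *rev, axis_name="ep")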