Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/reference/core_concepts/moe_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ Dropping:

`mlp_bias`: If enabled, add learnable bias terms for MLP matmul. Originally implemented to support the GPT-OSS model architecture.

`prefuse_moe_weights`: If enabled alongside `sparse_matmul=True`, fuses the two FFN1 grouped GEMMs (wi\_0 and wi\_1) into a single grouped GEMM call. Expert weights are stored in a concatenated `(num_experts, embed_dim, 2 * mlp_dim)` shape, so input activations are loaded from HBM once per forward pass instead of twice. This option is backend-agnostic: it works with Megablox, JAX Ragged Dot, and Tokamax. When used with `attention=vllm_rpa`, the fused weight tensor is passed directly to the vLLM-TPU serving kernel without splitting.

`use_batch_split_schedule` (experimental): If enabled, split the batch into micro-batches to hide communication, which yields performance benefits.

## 2. Sharding
Expand Down
74 changes: 44 additions & 30 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,7 @@ def __init__(
self.wi_0 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wi_1 = jnp.zeros((num_experts, self.moe_expert_input_dim, intermediate_dim))
self.wo = jnp.zeros((num_experts, intermediate_dim, self.moe_expert_input_dim))
elif self.config.prefuse_moe_weights and self.config.attention == "vllm_rpa":
# Pad model dimension in Fused MoE weight kernels for GMM_v2 execution.
moe_intermediate_dim = (
self.config.padded_base_moe_mlp_dim
if self.config.padded_base_moe_mlp_dim is not None
else self.intermediate_dim
)
elif self.config.prefuse_moe_weights:
self.wi = nnx.Param(
self.kernel_init(
self.rngs.params(),
Expand Down Expand Up @@ -1412,29 +1406,44 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
self.config.wo_tile_drhs_embed_dim, # Called n in megablox, and indeed is the RHS batch dim
)

layer_w0 = gmm_fn(
x,
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
if self.config.prefuse_moe_weights:
# Weights are stored as (G,K,2N); w0/w1 are adjacent slices so XLA elides this concat.
w_fused = jnp.concatenate([w0, w1], axis=-1)
Comment thread
abhinavgoel95 marked this conversation as resolved.
out = gmm_fn(x, w_fused, tiling=wi_tile_size, weight_gather_axes=wi_gather_axes)
n = out.shape[-1] // 2
layer_w0, layer_w1 = out[:, :n], out[:, n:]
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w1 = layer_w1 + w1_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
else:
layer_w0 = gmm_fn(
x,
w0,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w1 = gmm_fn(
x,
w1,
tiling=wi_tile_size,
weight_gather_axes=wi_gather_axes,
)
if self.get_tensor_transpose_parallelism_size() > 1:
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

intermediate_output = gmm_fn(
Expand Down Expand Up @@ -2238,6 +2247,11 @@ def __call__(
w1_kernel = None
if cfg.prefuse_moe_weights and cfg.attention == "vllm_rpa":
fused_kernel = jnp.asarray(self.wi[...], self.dtype)
elif cfg.prefuse_moe_weights:
wi = jnp.asarray(self.wi[...], self.dtype)
n = wi.shape[-1] // 2
w0_kernel = wi[..., :n]
w1_kernel = wi[..., n:]
else:
w0_kernel = jnp.asarray(self.wi_0[...], self.dtype)
w1_kernel = jnp.asarray(self.wi_1[...], self.dtype)
Expand Down
144 changes: 102 additions & 42 deletions tests/unit/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,42 +68,38 @@ def setUp(self):
def test_generate_masks(self):
# expert_capacity = (tokens_per_batch / num_experts) * capacity_factor
# expert_capacity_in_batch = (4 * 2 / 8) * 2 = 2
top_k_indices = jnp.array(
top_k_indices = jnp.array([
[[0, 5], [0, 4], [1, 0], [3, 5]],
[[1, 2], [4, 1], [5, 0], [7, 1]],
[[6, 2], [2, 3], [4, 2], [1, 2]],
[[4, 1], [0, 7], [5, 0], [4, 7]],
])
softmax_probs = jnp.array([
[
[[0, 5], [0, 4], [1, 0], [3, 5]],
[[1, 2], [4, 1], [5, 0], [7, 1]],
[[6, 2], [2, 3], [4, 2], [1, 2]],
[[4, 1], [0, 7], [5, 0], [4, 7]],
]
)
softmax_probs = jnp.array(
[0.20, 0, 0, 0, 0, 0.80, 0, 0],
[0.68, 0, 0, 0, 0.32, 0, 0, 0],
[0.22, 0.78, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0.32, 0, 0.68, 0, 0],
],
[
[
[0.20, 0, 0, 0, 0, 0.80, 0, 0],
[0.68, 0, 0, 0, 0.32, 0, 0, 0],
[0.22, 0.78, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0.32, 0, 0.68, 0, 0],
],
[
[0, 0.26, 0.74, 0, 0, 0, 0, 0],
[0, 0.79, 0, 0, 0.21, 0, 0, 0],
[0.89, 0, 0, 0, 0, 0.11, 0, 0],
[0, 0.11, 0, 0, 0, 0, 0, 0.89],
],
[
[0, 0, 0.26, 0, 0, 0, 0.74, 0],
[0, 0, 0.88, 0.12, 0, 0, 0, 0],
[0, 0, 0.17, 0, 0.83, 0, 0, 0],
[0, 0.35, 0.65, 0, 0, 0, 0, 0],
],
[
[0, 0.47, 0, 0, 0.53, 0, 0, 0],
[0.36, 0, 0, 0, 0, 0, 0, 0.64],
[0.15, 0, 0, 0, 0, 0.85, 0, 0],
[0, 0, 0, 0, 0.18, 0, 0, 0.82],
],
]
)
[0, 0.26, 0.74, 0, 0, 0, 0, 0],
[0, 0.79, 0, 0, 0.21, 0, 0, 0],
[0.89, 0, 0, 0, 0, 0.11, 0, 0],
[0, 0.11, 0, 0, 0, 0, 0, 0.89],
],
[
[0, 0, 0.26, 0, 0, 0, 0.74, 0],
[0, 0, 0.88, 0.12, 0, 0, 0, 0],
[0, 0, 0.17, 0, 0.83, 0, 0, 0],
[0, 0.35, 0.65, 0, 0, 0, 0, 0],
],
[
[0, 0.47, 0, 0, 0.53, 0, 0, 0],
[0.36, 0, 0, 0, 0, 0, 0, 0.64],
[0.15, 0, 0, 0, 0, 0.85, 0, 0],
[0, 0, 0, 0, 0.18, 0, 0, 0.82],
],
])

# As expert_capacity_in_batch=2, so updated softmax_probs become (4 tokens were dropped):
# softmax_probs = jnp.array([[[0.20, 0, 0, 0, 0, 0.80, 0, 0],
Expand Down Expand Up @@ -238,14 +234,10 @@ def setUp(self):

def test_deepseek_routing(self):
# shape as [batch, sequence, num_experts] = [1,2,16]
gate_logits = jnp.array(
[
[
[0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10],
[0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02],
]
]
)
gate_logits = jnp.array([[
[0.20, 0.10, 0.05, 0.10, 0.10, 0.60, 0.30, 0.10, 0.80, 0.01, 0.01, 0.01, 0.05, 0.80, 0.20, 0.10],
[0.68, 0.20, 0.06, 0.03, 0.32, 0.10, 0.05, 0.02, 0.65, 0.20, 0.04, 0.01, 0.32, 0.10, 0.05, 0.02],
]])
pre_bias_logits = gate_logits - 0.5

# 4 groups of 1st token:
Expand Down Expand Up @@ -1402,5 +1394,73 @@ def test_prefused_vs_sparse_softmax(self):
self.assertIsNone(bias_updates)


@pytest.mark.tpu_only
class FusedMlpMoETest(unittest.TestCase):
  """Compares MoE outputs with `prefuse_moe_weights` turned on against the unfused reference.

  Both models are built from the same `mixtral-8x7b` test configuration with
  `sparse_matmul=True` and `megablox=True`; only `prefuse_moe_weights` differs.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Used as per_device_batch_size and max_target_length respectively.
    self._B, self._S = 1, 16

  def _make_config(self, run_name, prefuse):
    """Initialize a test config that differs between callers only in run name and fusion flag."""
    overrides = get_decoupled_parallelism_overrides()
    return pyconfig.initialize(
        [None, get_test_config_path()],
        run_name=run_name,
        enable_checkpointing=False,
        model_name="mixtral-8x7b",
        dtype="bfloat16",
        sparse_matmul=True,
        megablox=True,
        prefuse_moe_weights=prefuse,
        ici_expert_parallelism=jax.device_count(),
        max_target_length=self._S,
        per_device_batch_size=self._B,
        **overrides,
    )

  @staticmethod
  def _make_mesh(cfg):
    """Build the device mesh described by `cfg`."""
    device_grid = maxtext_utils.create_device_mesh(cfg)
    return Mesh(device_grid, cfg.mesh_axes)

  def setUp(self):
    super().setUp()
    self.rng = jax.random.PRNGKey(0)
    # Reference (unfused) model, rebuilt before every test.
    self.ref_cfg = self._make_config("fused_mlp_moe_ref", prefuse=False)
    self.ref_mesh = self._make_mesh(self.ref_cfg)
    self.ref_model = make_moe(self.ref_cfg, self.ref_mesh)

  def _inputs(self):
    """Return normally distributed bfloat16 activations of shape (_B, _S, base_emb_dim)."""
    shape = (self._B, self._S, self.ref_cfg.base_emb_dim)
    return jax.random.normal(self.rng, shape, dtype=jnp.bfloat16)

  def test_prefuse_moe_weights_matches_unfused(self):
    """Fused and unfused sparse-matmul (Megablox) paths agree within 1e-2 tolerances."""
    fused_cfg = self._make_config("fused_mlp_moe_fused", prefuse=True)
    fused_model = make_moe(fused_cfg, self._make_mesh(fused_cfg))
    # Presumably transfers the reference weights into the fused layout so both
    # models compute with the same parameters -- confirm against the helper.
    copy_weights_prefused(self.ref_model, fused_model)

    activations = self._inputs()
    expected, _, _ = self.ref_model(activations)
    actual, _, _ = fused_model(activations)

    # Compare in float32; the models themselves run in bfloat16.
    expected_f32 = np.array(expected, dtype=np.float32)
    actual_f32 = np.array(actual, dtype=np.float32)
    np.testing.assert_allclose(expected_f32, actual_f32, rtol=1e-2, atol=1e-2)


if __name__ == "__main__":
unittest.main()