From d7fd385efdbcb165c96d6bfafacfacf300eae434 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Thu, 5 Mar 2026 15:24:30 -0800 Subject: [PATCH 1/3] Add MoE and MLA remat policies --- src/maxtext/configs/base.yml | 6 ++++++ src/maxtext/configs/pyconfig_deprecated.py | 6 ++++++ src/maxtext/configs/types.py | 24 ++++++++++++++++++++++ src/maxtext/layers/attention_mla.py | 2 ++ src/maxtext/layers/moe.py | 18 ++++++++-------- 5 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 6772678dfc..b43861bbd6 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -317,11 +317,17 @@ mlpwi: 'remat' mlpwi_0: 'remat' mlpwi_1: 'remat' mlpwo: 'remat' +moe_mlpwi: 'remat' +moe_mlpwi_0: 'remat' +moe_mlpwi_1: 'remat' +moe_mlpwo: 'remat' query_proj: 'remat' key_proj: 'remat' value_proj: 'remat' qkv_proj: 'remat' out_proj: 'remat' +query_wa_proj: 'remat' +kv_wa_proj: 'remat' mla_q: 'remat' mla_kv: 'remat' attention_out: 'remat' diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index 2fc7ab82a7..d3183b98f7 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -516,9 +516,15 @@ def validate_and_assign_remat_tensors(keys): "mlpwi_0", "mlpwi_1", "mlpwo", + "moe_mlpwi", + "moe_mlpwi_0", + "moe_mlpwi_1", + "moe_mlpwo", "query_proj", "key_proj", "value_proj", + "query_wa_proj", + "kv_wa_proj", "out_proj", ] assert keys["decoder_layer_input"] != "remat", "Cannot remeterialize this tensor with scan_layers=True" diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index c9e9475883..57f604c1ee 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -896,9 +896,33 @@ class RematAndOffload(BaseModel): RematLocation.REMAT, description="Remat policy for the second MLP layer's output.", ) + moe_mlpwi: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the first MoE layer's intermediate output.", + ) + moe_mlpwi_0: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the first part of a gated MoE's output.", + ) + moe_mlpwi_1: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the second part of a gated MoE's output.", + ) + moe_mlpwo: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the second MoE layer's output.", + ) query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.") key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.") value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.") + query_wa_proj: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the MLA query weighted attention projection.", + ) + kv_wa_proj: RematLocation = Field( + RematLocation.REMAT, + description="Remat policy for the MLA key and value weighted attention projection.", + ) qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.") out_proj: RematLocation = Field( RematLocation.REMAT, diff --git a/src/maxtext/layers/attention_mla.py b/src/maxtext/layers/attention_mla.py index e0d6e4e9f1..dd6b58f68c 100644 --- a/src/maxtext/layers/attention_mla.py +++ b/src/maxtext/layers/attention_mla.py @@ -794,6 +794,7 @@ def mla_query_projection( else: # LoRA path low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank] + low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj") low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank low_rank_q = checkpoint_name(low_rank_q, "mla_q") q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim] @@ -933,6 +934,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ) wkva_out_sharding = create_sharding(self.mesh, wka_logical_name) low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding) + low_rank = checkpoint_name(low_rank, "kv_wa_proj") low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1) low_rank_main = self.kv_norm(low_rank_main) low_rank_main = checkpoint_name(low_rank_main, "mla_kv") diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index e7f548c847..a887cc384b 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose") if self.config.mlp_bias: layer_w0 = layer_w0 + w0_bias - layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") layer_w1 = gmm_fn( x, @@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose") if self.config.mlp_bias: layer_w1 = layer_w1 + w1_bias - layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1) intermediate_output = gmm_fn( @@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index): ) if self.config.mlp_bias: intermediate_output = intermediate_output + wo_bias - intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo") + intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo") if self.config.use_ring_of_experts: # Set the outputs of tokens which were not processed to 0. @@ -1860,7 +1860,7 @@ def dense_matmul( layer_w0, mlp_axis, ) - layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") with jax.named_scope("wi_1"): w1_kernel_axes = ("exp", None, "mlp") w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes) @@ -1876,7 +1876,7 @@ def dense_matmul( layer_w1, mlp_axis, ) - layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): wo_kernel_axes = ("exp", "mlp", None) @@ -1902,7 +1902,7 @@ def dense_matmul( "activation_embed", ), ) - intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") + intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo") with jax.named_scope("combine"): # Matmul & element wise operation output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)( @@ -1931,7 +1931,7 @@ def dense_matmul( layer_w0 = layer_w0 + w0_bias[None, None, :, :] if self.config.activations_in_float32: layer_w0 = layer_w0.astype(jnp.float32) - layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0") + layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0") with jax.named_scope("wi_1"): layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)( "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision @@ -1940,7 +1940,7 @@ def dense_matmul( layer_w1 = layer_w1 + w1_bias[None, None, :, :] if self.config.activations_in_float32: layer_w1 = layer_w1.astype(jnp.float32) - layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1") + layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1) with jax.named_scope("wo"): @@ -1954,7 +1954,7 @@ def dense_matmul( intermediate_layer = intermediate_layer + wo_bias[None, None, :, :] if self.config.activations_in_float32: intermediate_layer = intermediate_layer.astype(jnp.float32) - intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo") + intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo") with jax.named_scope("weight_sum"): if is_llama4_decoder_layer: weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices) From 5109d24fe4c3ced648f40700e909f76c40538c23 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Tue, 24 Mar 2026 11:48:24 -0700 Subject: [PATCH 2/3] Fix pyink formatting in attention_test.py --- tests/unit/attention_test.py | 122 +++++++++++++++-------------------- 1 file changed, 51 insertions(+), 71 deletions(-) diff --git a/tests/unit/attention_test.py b/tests/unit/attention_test.py index fc2c3c2d24..487b57e41d 100644 --- a/tests/unit/attention_test.py +++ b/tests/unit/attention_test.py @@ -55,62 +55,52 @@ def test_one_block_mask(self): bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ] - ] - ) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ]]) np.testing.assert_array_equal(block_mask, expected_mask) def test_two_blocks_mask(self): bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( - [ - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ] - ] - ) + expected_mask = np.asarray([[ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ]]) np.testing.assert_array_equal(block_mask, expected_mask) def test_batch_block_masks(self): bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]]) # pylint: disable=protected-access block_mask = _make_bidirectional_block_mask(bidirectional_mask) - expected_mask = np.asarray( + expected_mask = np.asarray([ [ - [ - [False, False, False, False, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, True, True, True, False, False], - [False, False, False, False, False, False], - [False, False, False, False, False, False], - ], - [ - [False, False, False, False, False, False], - [False, True, True, False, False, False], - [False, True, True, False, False, False], - [False, False, False, False, False, False], - [False, False, False, False, True, True], - [False, False, False, False, True, True], - ], - ] - ) + [False, False, False, False, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, True, True, True, False, False], + [False, False, False, False, False, False], + [False, False, False, False, False, False], + ], + [ + [False, False, False, False, False, False], + [False, True, True, False, False, False], + [False, True, True, False, False, False], + [False, False, False, False, False, False], + [False, False, False, False, True, True], + [False, False, False, False, True, True], + ], + ]) np.testing.assert_array_equal(block_mask, expected_mask) def test_empty_block_mask(self): @@ -140,34 +130,24 @@ def test_combine_with_causal_mask(self): # pylint: disable=protected-access image_mask = _make_bidirectional_block_mask(bidirectional_mask) combined_mask = causal_mask | image_mask[:, None, None, ...] - expected_mask = np.asarray( - [ - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, False], - [True, True, True, True, True, True], - ] - ] - ], - [ - [ - [ - [True, False, False, False, False, False], - [True, True, True, False, False, False], - [True, True, True, False, False, False], - [True, True, True, True, False, False], - [True, True, True, True, True, True], - [True, True, True, True, True, True], - ] - ] - ], - ] - ) + expected_mask = np.asarray([ + [[[ + [True, False, False, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, False], + [True, True, True, True, True, True], + ]]], + [[[ + [True, False, False, False, False, False], + [True, True, True, False, False, False], + [True, True, True, False, False, False], + [True, True, True, True, False, False], + [True, True, True, True, True, True], + [True, True, True, True, True, True], + ]]], + ]) np.testing.assert_array_equal(combined_mask, expected_mask) From 0ac53b5ce37b82ab7fa0b68dc6811e5e2c2435c0 Mon Sep 17 00:00:00 2001 From: Abhinav Goel Date: Tue, 24 Mar 2026 22:47:51 +0000 Subject: [PATCH 3/3] Remove moe_mlpwi for now --- src/maxtext/configs/base.yml | 1 - src/maxtext/configs/pyconfig_deprecated.py | 1 - src/maxtext/configs/types.py | 4 ---- 3 files changed, 6 deletions(-) diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index b43861bbd6..91b28968ac 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -317,7 +317,6 @@ mlpwi: 'remat' mlpwi_0: 'remat' mlpwi_1: 'remat' mlpwo: 'remat' -moe_mlpwi: 'remat' moe_mlpwi_0: 'remat' moe_mlpwi_1: 'remat' moe_mlpwo: 'remat' diff --git a/src/maxtext/configs/pyconfig_deprecated.py b/src/maxtext/configs/pyconfig_deprecated.py index d3183b98f7..eb452e4739 100644 --- a/src/maxtext/configs/pyconfig_deprecated.py +++ b/src/maxtext/configs/pyconfig_deprecated.py @@ -516,7 +516,6 @@ def validate_and_assign_remat_tensors(keys): "mlpwi_0", "mlpwi_1", "mlpwo", - "moe_mlpwi", "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo", diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index 57f604c1ee..b3660427c9 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -896,10 +896,6 @@ class RematAndOffload(BaseModel): RematLocation.REMAT, description="Remat policy for the second MLP layer's output.", ) - moe_mlpwi: RematLocation = Field( - RematLocation.REMAT, - description="Remat policy for the first MoE layer's intermediate output.", - ) moe_mlpwi_0: RematLocation = Field( RematLocation.REMAT, description="Remat policy for the first part of a gated MoE's output.",