Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/maxtext/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,16 @@ mlpwi: 'remat'
mlpwi_0: 'remat'
mlpwi_1: 'remat'
mlpwo: 'remat'
moe_mlpwi_0: 'remat'
moe_mlpwi_1: 'remat'
moe_mlpwo: 'remat'
query_proj: 'remat'
key_proj: 'remat'
value_proj: 'remat'
qkv_proj: 'remat'
out_proj: 'remat'
query_wa_proj: 'remat'
kv_wa_proj: 'remat'
mla_q: 'remat'
mla_kv: 'remat'
attention_out: 'remat'
Expand Down
5 changes: 5 additions & 0 deletions src/maxtext/configs/pyconfig_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,14 @@ def validate_and_assign_remat_tensors(keys):
"mlpwi_0",
"mlpwi_1",
"mlpwo",
"moe_mlpwi_0",
"moe_mlpwi_1",
"moe_mlpwo",
"query_proj",
"key_proj",
"value_proj",
"query_wa_proj",
"kv_wa_proj",
"out_proj",
]
assert keys["decoder_layer_input"] != "remat", "Cannot remeterialize this tensor with scan_layers=True"
Expand Down
20 changes: 20 additions & 0 deletions src/maxtext/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,29 @@ class RematAndOffload(BaseModel):
RematLocation.REMAT,
description="Remat policy for the second MLP layer's output.",
)
moe_mlpwi_0: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the first part of a gated MoE's output.",
)
moe_mlpwi_1: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the second part of a gated MoE's output.",
)
moe_mlpwo: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the second MoE layer's output.",
)
query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.")
key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.")
value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.")
query_wa_proj: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the MLA query weighted attention projection.",
)
kv_wa_proj: RematLocation = Field(
RematLocation.REMAT,
description="Remat policy for the MLA key and value weighted attention projection.",
)
qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.")
out_proj: RematLocation = Field(
RematLocation.REMAT,
Expand Down
2 changes: 2 additions & 0 deletions src/maxtext/layers/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,6 +794,7 @@ def mla_query_projection(
else:
# LoRA path
low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding) # [B, L, q_lora_rank]
low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj")
low_rank_q = self.q_norm(low_rank_q) # RMSNorm on low rank
low_rank_q = checkpoint_name(low_rank_q, "mla_q")
q = self.wq_b(low_rank_q, out_sharding=query_sharding) # [B, L, n_heads, qk_head_dim]
Expand Down Expand Up @@ -933,6 +934,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
low_rank = checkpoint_name(low_rank, "kv_wa_proj")
low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
low_rank_main = self.kv_norm(low_rank_main)
low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
Expand Down
18 changes: 9 additions & 9 deletions src/maxtext/layers/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!

layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

layer_w1 = gmm_fn(
x,
Expand All @@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
if self.config.mlp_bias:
layer_w1 = layer_w1 + w1_bias
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

intermediate_output = gmm_fn(
Expand All @@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
)
if self.config.mlp_bias:
intermediate_output = intermediate_output + wo_bias
intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

if self.config.use_ring_of_experts:
# Set the outputs of tokens which were not processed to 0.
Expand Down Expand Up @@ -1860,7 +1860,7 @@ def dense_matmul(
layer_w0,
mlp_axis,
)
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
with jax.named_scope("wi_1"):
w1_kernel_axes = ("exp", None, "mlp")
w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
Expand All @@ -1876,7 +1876,7 @@ def dense_matmul(
layer_w1,
mlp_axis,
)
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
with jax.named_scope("wo"):
wo_kernel_axes = ("exp", "mlp", None)
Expand All @@ -1902,7 +1902,7 @@ def dense_matmul(
"activation_embed",
),
)
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
with jax.named_scope("combine"):
# Matmul & element wise operation
output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
Expand Down Expand Up @@ -1931,7 +1931,7 @@ def dense_matmul(
layer_w0 = layer_w0 + w0_bias[None, None, :, :]
if self.config.activations_in_float32:
layer_w0 = layer_w0.astype(jnp.float32)
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
with jax.named_scope("wi_1"):
layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
"BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
Expand All @@ -1940,7 +1940,7 @@ def dense_matmul(
layer_w1 = layer_w1 + w1_bias[None, None, :, :]
if self.config.activations_in_float32:
layer_w1 = layer_w1.astype(jnp.float32)
layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

with jax.named_scope("wo"):
Expand All @@ -1954,7 +1954,7 @@ def dense_matmul(
intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
if self.config.activations_in_float32:
intermediate_layer = intermediate_layer.astype(jnp.float32)
intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
with jax.named_scope("weight_sum"):
if is_llama4_decoder_layer:
weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
Expand Down
122 changes: 51 additions & 71 deletions tests/unit/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,62 +55,52 @@ def test_one_block_mask(self):
bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0]])
# pylint: disable=protected-access
block_mask = _make_bidirectional_block_mask(bidirectional_mask)
expected_mask = np.asarray(
[
[
[False, False, False, False, False, False],
[False, True, True, True, False, False],
[False, True, True, True, False, False],
[False, True, True, True, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False],
]
]
)
expected_mask = np.asarray([[
[False, False, False, False, False, False],
[False, True, True, True, False, False],
[False, True, True, True, False, False],
[False, True, True, True, False, False],
[False, False, False, False, False, False],
[False, False, False, False, False, False],
]])
np.testing.assert_array_equal(block_mask, expected_mask)

def test_two_blocks_mask(self):
    """Each contiguous run of 1s in the bidirectional mask becomes its own
    fully-connected attention block; 0 (causal) positions stay masked.

    The scraped diff left both the pre- and post-change `expected_mask`
    assignments in this method (a duplicate dead assignment); only the
    final, post-change version is kept here.
    """
    bidirectional_mask = np.asarray([[0, 1, 1, 0, 1, 1]])
    # pylint: disable=protected-access
    block_mask = _make_bidirectional_block_mask(bidirectional_mask)
    # Positions 1-2 form one block and positions 4-5 another; the two
    # blocks must not attend to each other.
    expected_mask = np.asarray([[
        [False, False, False, False, False, False],
        [False, True, True, False, False, False],
        [False, True, True, False, False, False],
        [False, False, False, False, False, False],
        [False, False, False, False, True, True],
        [False, False, False, False, True, True],
    ]])
    np.testing.assert_array_equal(block_mask, expected_mask)

def test_batch_block_masks(self):
    """Block-mask construction is applied independently to each batch row.

    The scraped diff interleaved the pre- and post-change `expected_mask`
    assignments here (two `np.asarray(` openers back to back, which is not
    valid Python); only the final, post-change version is kept.
    """
    bidirectional_mask = np.asarray([[0, 1, 1, 1, 0, 0], [0, 1, 1, 0, 1, 1]])
    # pylint: disable=protected-access
    block_mask = _make_bidirectional_block_mask(bidirectional_mask)
    # Row 0 has one run of 1s (positions 1-3); row 1 has two runs
    # (positions 1-2 and 4-5), each forming an independent block.
    expected_mask = np.asarray([
        [
            [False, False, False, False, False, False],
            [False, True, True, True, False, False],
            [False, True, True, True, False, False],
            [False, True, True, True, False, False],
            [False, False, False, False, False, False],
            [False, False, False, False, False, False],
        ],
        [
            [False, False, False, False, False, False],
            [False, True, True, False, False, False],
            [False, True, True, False, False, False],
            [False, False, False, False, False, False],
            [False, False, False, False, True, True],
            [False, False, False, False, True, True],
        ],
    ])
    np.testing.assert_array_equal(block_mask, expected_mask)

def test_empty_block_mask(self):
Expand Down Expand Up @@ -140,34 +130,24 @@ def test_combine_with_causal_mask(self):
# pylint: disable=protected-access
image_mask = _make_bidirectional_block_mask(bidirectional_mask)
combined_mask = causal_mask | image_mask[:, None, None, ...]
expected_mask = np.asarray(
[
[
[
[
[True, False, False, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, False, False],
[True, True, True, True, False, False],
[True, True, True, True, True, False],
[True, True, True, True, True, True],
]
]
],
[
[
[
[True, False, False, False, False, False],
[True, True, True, False, False, False],
[True, True, True, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, True, True],
[True, True, True, True, True, True],
]
]
],
]
)
expected_mask = np.asarray([
[[[
[True, False, False, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, False, False],
[True, True, True, True, False, False],
[True, True, True, True, True, False],
[True, True, True, True, True, True],
]]],
[[[
[True, False, False, False, False, False],
[True, True, True, False, False, False],
[True, True, True, False, False, False],
[True, True, True, True, False, False],
[True, True, True, True, True, True],
[True, True, True, True, True, True],
]]],
])
np.testing.assert_array_equal(combined_mask, expected_mask)


Expand Down