6 changes: 6 additions & 0 deletions src/maxtext/configs/base.yml
@@ -317,11 +317,17 @@ mlpwi: 'remat'
 mlpwi_0: 'remat'
 mlpwi_1: 'remat'
 mlpwo: 'remat'
+moe_mlpwi: 'remat'
+moe_mlpwi_0: 'remat'
+moe_mlpwi_1: 'remat'
+moe_mlpwo: 'remat'
 query_proj: 'remat'
 key_proj: 'remat'
 value_proj: 'remat'
 qkv_proj: 'remat'
 out_proj: 'remat'
+query_wa_proj: 'remat'
+kv_wa_proj: 'remat'
 mla_q: 'remat'
 mla_kv: 'remat'
 attention_out: 'remat'
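For context: each of these YAML keys names an activation that the model code tags with JAX's checkpoint_name, and the key's value tells the remat policy what to do with that tensor. A minimal sketch of the mechanism, assuming nothing from MaxText itself (the layer function, shapes, and weights below are invented for illustration; only checkpoint_name, jax.checkpoint, and jax.checkpoint_policies are real JAX APIs):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w0, wo):
  # Tag intermediates with the same names the YAML keys refer to.
  h = checkpoint_name(jax.nn.relu(x @ w0), "moe_mlpwi_0")
  return checkpoint_name(h @ wo, "moe_mlpwo")

# Keep the tensor tagged "moe_mlpwo" resident for the backward pass;
# recompute everything else, including "moe_mlpwi_0" (i.e. its key is 'remat').
policy = jax.checkpoint_policies.save_only_these_names("moe_mlpwo")
layer_remat = jax.checkpoint(layer, policy=policy)

def loss(x, w0, wo):
  return jnp.sum(layer_remat(x, w0, wo) ** 2)

x, w0, wo = jnp.ones((4, 8)), jnp.ones((8, 16)), jnp.ones((16, 8))
grads = jax.grad(loss, argnums=(1, 2))(x, w0, wo)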
6 changes: 6 additions & 0 deletions src/maxtext/configs/pyconfig_deprecated.py
@@ -516,9 +516,15 @@ def validate_and_assign_remat_tensors(keys):
       "mlpwi_0",
       "mlpwi_1",
       "mlpwo",
+      "moe_mlpwi",
+      "moe_mlpwi_0",
+      "moe_mlpwi_1",
+      "moe_mlpwo",
       "query_proj",
       "key_proj",
       "value_proj",
+      "query_wa_proj",
+      "kv_wa_proj",
       "out_proj",
   ]
   assert keys["decoder_layer_input"] != "remat", "Cannot rematerialize this tensor with scan_layers=True"
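A condensed sketch of the validation pattern this file uses, with the six new names folded in. Everything here is hypothetical shorthand: the helper name, the exact set of valid location strings, and the error messages are illustrative stand-ins for what pyconfig_deprecated.py actually does.

VALID_LOCATIONS = ("remat", "device", "offload")  # illustrative value set
KNOWN_TENSORS = [
    "mlpwi_0", "mlpwi_1", "mlpwo",
    "moe_mlpwi", "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo",
    "query_proj", "key_proj", "value_proj",
    "query_wa_proj", "kv_wa_proj", "out_proj",
]

def validate_remat_tensors(keys: dict) -> None:
  # Each configurable tensor must carry a recognized remat location.
  for name in KNOWN_TENSORS:
    assert keys.get(name, "remat") in VALID_LOCATIONS, f"bad remat location for {name}"
  # Mirrors the assertion in the diff above.
  assert keys["decoder_layer_input"] != "remat", "Cannot rematerialize this tensor with scan_layers=True"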
24 changes: 24 additions & 0 deletions src/maxtext/configs/types.py
@@ -896,9 +896,33 @@ class RematAndOffload(BaseModel):
       RematLocation.REMAT,
       description="Remat policy for the second MLP layer's output.",
   )
+  moe_mlpwi: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the first MoE layer's intermediate output.",
+  )
+  moe_mlpwi_0: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the first part of a gated MoE's output.",
+  )
+  moe_mlpwi_1: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second part of a gated MoE's output.",
+  )
+  moe_mlpwo: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the second MoE layer's output.",
+  )
   query_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the query projection.")
   key_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the key projection.")
   value_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for the value projection.")
+  query_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA query weighted attention projection.",
+  )
+  kv_wa_proj: RematLocation = Field(
+      RematLocation.REMAT,
+      description="Remat policy for the MLA key and value weighted attention projection.",
+  )
   qkv_proj: RematLocation = Field(RematLocation.REMAT, description="Remat policy for fused QKV projection.")
   out_proj: RematLocation = Field(
       RematLocation.REMAT,
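A usage sketch for the new fields, assuming only the RematAndOffload model and RematLocation enum declared above, standard Pydantic behavior, and that the model's remaining fields also carry defaults:

from pydantic import ValidationError
from maxtext.configs.types import RematAndOffload, RematLocation

cfg = RematAndOffload()  # every new field defaults to RematLocation.REMAT
assert cfg.moe_mlpwi_0 is RematLocation.REMAT
assert cfg.kv_wa_proj is RematLocation.REMAT

try:
  RematAndOffload(moe_mlpwo="not_a_location")  # enum-typed field rejects unknown values
except ValidationError as err:
  print(err)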
2 changes: 2 additions & 0 deletions src/maxtext/layers/attention_mla.py
@@ -794,6 +794,7 @@ def mla_query_projection(
     else:
       # LoRA path
       low_rank_q = self.wq_a(inputs_q, out_sharding=wqa_out_sharding)  # [B, L, q_lora_rank]
+      low_rank_q = checkpoint_name(low_rank_q, "query_wa_proj")
       low_rank_q = self.q_norm(low_rank_q)  # RMSNorm on low rank
       low_rank_q = checkpoint_name(low_rank_q, "mla_q")
       q = self.wq_b(low_rank_q, out_sharding=query_sharding)  # [B, L, n_heads, qk_head_dim]
@@ -933,6 +934,7 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm
     wka_logical_name = (KV_BATCH, LENGTH_NO_EXP, KV_LORA_UP_PROJ)
     wkva_out_sharding = create_sharding(self.mesh, wka_logical_name)
     low_rank = self.wkv_a(inputs, out_sharding=wkva_out_sharding)
+    low_rank = checkpoint_name(low_rank, "kv_wa_proj")
     low_rank_main, low_rank_rope = jnp.split(low_rank, [self.kv_lora_rank], axis=-1)
     low_rank_main = self.kv_norm(low_rank_main)
     low_rank_main = checkpoint_name(low_rank_main, "mla_kv")
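To see where the two new tags sit relative to the existing mla_q / mla_kv ones, here is a stripped-down sketch of the LoRA query path. Plain matmuls and an unscaled RMSNorm stand in for MaxText's sharded linears and q_norm; shapes and weight names are illustrative:

import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def mla_query(inputs_q, wq_a, wq_b):
  # inputs_q: [B, L, D], wq_a: [D, r], wq_b: [r, H*K]
  # New tag: the pre-norm low-rank activation gets its own remat policy.
  low_rank_q = checkpoint_name(inputs_q @ wq_a, "query_wa_proj")
  rms = jnp.sqrt(jnp.mean(low_rank_q**2, axis=-1, keepdims=True) + 1e-6)
  # Existing tag: the post-norm tensor, as before.
  low_rank_q = checkpoint_name(low_rank_q / rms, "mla_q")
  return low_rank_q @ wq_b

Tagging before and after the norm means the down-projection output and the normalized tensor can be saved, offloaded, or recomputed independently.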
18 changes: 9 additions & 9 deletions src/maxtext/layers/moe.py
@@ -1274,7 +1274,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w0 = layer_w0 + w0_bias
-    layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+    layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")

     layer_w1 = gmm_fn(
         x,

Collaborator review comment on the renamed checkpoint name: Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!
@@ -1288,7 +1288,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
       layer_w1 = jax.lax.psum(layer_w1, "tensor_transpose")
     if self.config.mlp_bias:
       layer_w1 = layer_w1 + w1_bias
-    layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+    layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
     intermediate_layer = self.apply_ffn_activation(layer_w0, layer_w1)

     intermediate_output = gmm_fn(
@@ -1305,7 +1305,7 @@ def get_active_sharding_axes(pspec_dim_axes, tensor_dim_index):
     )
     if self.config.mlp_bias:
       intermediate_output = intermediate_output + wo_bias
-    intermediate_output = adc.checkpoint_name(intermediate_output, "mlpwo")
+    intermediate_output = adc.checkpoint_name(intermediate_output, "moe_mlpwo")

     if self.config.use_ring_of_experts:
       # Set the outputs of tokens which were not processed to 0.
@@ -1860,7 +1860,7 @@ def dense_matmul(
           layer_w0,
           mlp_axis,
       )
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       w1_kernel_axes = ("exp", None, "mlp")
       w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
@@ -1876,7 +1876,7 @@ def dense_matmul(
           layer_w1,
           mlp_axis,
       )
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
       layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)
     with jax.named_scope("wo"):
       wo_kernel_axes = ("exp", "mlp", None)
@@ -1902,7 +1902,7 @@ def dense_matmul(
             "activation_embed",
         ),
       )
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("combine"):
       # Matmul & element wise operation
       output = self.get_einsum(rhs_mesh_axes=mask_axes, einsum_name=COMBINE)(
@@ -1931,7 +1931,7 @@ def dense_matmul(
         layer_w0 = layer_w0 + w0_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w0 = layer_w0.astype(jnp.float32)
-      layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
+      layer_w0 = adc.checkpoint_name(layer_w0, "moe_mlpwi_0")
     with jax.named_scope("wi_1"):
       layer_w1 = self.get_einsum(rhs_mesh_axes=self.wi_kernel_axes)(
           "BSM,EMH -> BSEH", inputs, w1_kernel, precision=matmul_precision
@@ -1940,7 +1940,7 @@ def dense_matmul(
         layer_w1 = layer_w1 + w1_bias[None, None, :, :]
       if self.config.activations_in_float32:
         layer_w1 = layer_w1.astype(jnp.float32)
-      layer_w1 = adc.checkpoint_name(layer_w1, "mlpwi_1")
+      layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1")
       layer_multiply = self.apply_ffn_activation(layer_w0, layer_w1)

     with jax.named_scope("wo"):
@@ -1954,7 +1954,7 @@ def dense_matmul(
         intermediate_layer = intermediate_layer + wo_bias[None, None, :, :]
       if self.config.activations_in_float32:
         intermediate_layer = intermediate_layer.astype(jnp.float32)
-      intermediate_layer = adc.checkpoint_name(intermediate_layer, "mlpwo")
+      intermediate_layer = adc.checkpoint_name(intermediate_layer, "moe_mlpwo")
     with jax.named_scope("weight_sum"):
       if is_llama4_decoder_layer:
         weights = self.reshape_and_update_weights(jnp.ones_like(top_k_weights), top_k_indices)
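The collaborator's caution above follows directly from these renames: a remat policy keyed on the old shared names no longer matches anything inside MoE blocks. A before/after sketch using JAX's named-policy helper (the policy contents are illustrative, not an official MaxText recipe):

import jax

# Before this PR, MoE and dense MLP activations shared "mlpwi_0"/"mlpwi_1"/"mlpwo",
# so one policy covered both.
legacy_policy = jax.checkpoint_policies.save_only_these_names("mlpwi_0", "mlpwi_1", "mlpwo")

# After the rename, MoE activations are tagged "moe_*"; a recipe that wants the
# old behavior must list both families explicitly.
updated_policy = jax.checkpoint_policies.save_only_these_names(
    "mlpwi_0", "mlpwi_1", "mlpwo",
    "moe_mlpwi", "moe_mlpwi_0", "moe_mlpwi_1", "moe_mlpwo",
)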