
Add MoE and MLA remat policies#3414

Open
abhinavgoel95 wants to merge 2 commits into AI-Hypercomputer:main from abhinavgoel95:abgoel/add-moe-mla-remat-policies

Conversation

@abhinavgoel95
Contributor

@abhinavgoel95 abhinavgoel95 commented Mar 13, 2026

  • Added moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo for MoE layers
  • Added query_wa_proj, kv_wa_proj for MLA layers
  • Updated base.yml, types.py, and pyconfig_deprecated.py

Description

This PR adds rematerialization policy support for Mixture of Experts (MoE) and Multi-head Latent Attention (MLA) layer tensors.

Previously, MaxText only supported remat policies for standard dense layer tensors. This prevented fine-grained memory optimization for MoE models (like Mixtral, DeepSeek V3) and models using MLA architecture (like DeepSeek V3).

This change adds six new configurable remat tensors:

  • MoE tensors: moe_mlpwi, moe_mlpwi_0, moe_mlpwi_1, moe_mlpwo
  • MLA tensors: query_wa_proj, kv_wa_proj

Users can now configure these tensors with device, offload, or remat policies in their config files, enabling better memory management for large MoE models (e.g., DeepSeek V3 671B).
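As a rough illustration of how such per-tensor settings could be translated into a JAX remat policy, the sketch below maps 'device' to saved names and 'offload' to offloaded names. The config dict and the mapping logic are illustrative assumptions, not MaxText's actual implementation; only the JAX policy API is real.

```python
# Hedged sketch: translating per-tensor remat settings into a JAX policy.
# The tensor_settings dict mirrors config keys from this PR; the mapping
# code itself is an assumption, not MaxText's implementation.
import jax

tensor_settings = {
    "moe_mlpwi": "offload",
    "moe_mlpwo": "remat",
    "query_wa_proj": "device",
}

# 'device' -> keep the activation in device memory; 'offload' -> move it to
# host memory; 'remat' -> recompute it in the backward pass (the default).
saved = [n for n, v in tensor_settings.items() if v == "device"]
offloaded = [n for n, v in tensor_settings.items() if v == "offload"]

policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=saved,
    names_which_can_be_offloaded=offloaded,
    offload_src="device",
    offload_dst="pinned_host",
)
```

The resulting `policy` can then be passed to `jax.checkpoint(fn, policy=policy)` around the layer.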

Files modified:

  • src/maxtext/configs/base.yml - Added default 'remat' values
  • src/maxtext/configs/types.py - Added Field definitions with descriptions
  • src/maxtext/configs/pyconfig_deprecated.py - Added to validation whitelist

All new tensors default to 'remat', maintaining backward compatibility.

Tests

Tested with DeepSeek V3 671B (41 layers) on 128 GPUs with various remat configurations:

  • Baseline with all tensors set to remat - ✅ Works
  • Custom policies with selective offload and device placement - ✅ Works
  • Verified backward compatibility with Llama models (no regression)

Example config usage:

moe_mlpwi: 'offload'
query_wa_proj: 'device'
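A fuller config sketch covering all six new tensors might look like the following (a hypothetical base.yml-style fragment assuming the key names from this PR; the chosen values are illustrative, not recommended defaults):

```yaml
# MoE layer tensors
moe_mlpwi: 'offload'
moe_mlpwi_0: 'remat'
moe_mlpwi_1: 'remat'
moe_mlpwo: 'offload'
# MLA layer tensors
query_wa_proj: 'device'
kv_wa_proj: 'device'
```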

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in https://maxtext.readthedocs.io/en/latest/development.html#adding-new-documentation-files.

layer_w0 = jax.lax.psum(layer_w0, "tensor_transpose")
if self.config.mlp_bias:
    layer_w0 = layer_w0 + w0_bias
layer_w0 = adc.checkpoint_name(layer_w0, "mlpwi_0")
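The `checkpoint_name` tag above is what lets a remat policy refer to this activation by name. A minimal, self-contained sketch of that interaction (the layer shapes and function here are illustrative, not MaxText's code; the JAX APIs are real):

```python
# Hedged sketch of how a checkpoint_name tag interacts with a remat policy.
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def mlp(x, wi, wo):
    h = jnp.dot(x, wi)
    # Tag the activation so a remat policy can refer to it by name.
    h = checkpoint_name(h, "mlpwi_0")
    return jnp.dot(jax.nn.relu(h), wo)

# Save only the tagged tensor for the backward pass; everything else
# is rematerialized (recomputed) instead of stored.
policy = jax.checkpoint_policies.save_only_these_names("mlpwi_0")
mlp_remat = jax.checkpoint(mlp, policy=policy)

x = jnp.ones((2, 4))
wi = jnp.ones((4, 8))
wo = jnp.ones((8, 2))
grads = jax.grad(lambda x, wi, wo: mlp_remat(x, wi, wo).sum())(x, wi, wo)
```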

Heads up: this might affect all legacy TPU recipes/performance for MoE models. We should make an announcement after it gets merged. Thanks!

@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/add-moe-mla-remat-policies branch from a3779fc to d7fd385 Compare March 16, 2026 17:52
@codecov

codecov bot commented Mar 19, 2026

Codecov Report

❌ Patch coverage is 87.50000% with 1 line in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/layers/attention_mla.py 50.00% 1 Missing ⚠️


@abhinavgoel95 abhinavgoel95 force-pushed the abgoel/add-moe-mla-remat-policies branch from 2fc4407 to 5109d24 Compare March 24, 2026 18:48

3 participants