
checkpoint utility: optimize to_maxtext, add deepseek #3184

Open
shuningjin wants to merge 1 commit into main from shuningjin-ckpt-opt3

Conversation

@shuningjin
Collaborator

@shuningjin shuningjin commented Feb 18, 2026

Description

Optimize to_maxtext loading and saving

  • The main goal is to optimize the to_maxtext eager loading and saving pipelines.
  • The optimization unblocks large-scale models. We additionally onboard the DeepSeek model family (deepseek2-16b, deepseek3-671b, deepseek3.2-671b).
  • See this document for detailed logic and tests (http://shortn/_KlHIRwUxvI).
  • Fix: b/452391831 (memory), b/477316979 (speed), ds3.2 (b/469550012, b/469550011), ds2 (b/459536844), ds3 (b/457820372, b/457820735), ds (b/452392346)

What Changed
Previously, eager load defaulted to transformers_class.from_pretrained(...), which loaded, converted, and saved checkpoints in float32. (With the latest transformers version, the e2e run would fail.)

This PR introduces two optimized loading methods and adds the ability to save in bfloat16:

  • Method 1: transformers_class.from_pretrained(..., dtype="auto") to load the original tensor type.
  • Method 2: safetensors.safe_open(..., framework="pt") to load natively from safetensors. Like Method 1, this can process either a remote repo or a local path.
  • Save Dtype: Added bfloat16 as the recommended save option (with float32 retained as a backup).
# to_maxtext flag
--eager_load_method=<transformers (default) | safetensors>
--save_dtype=<bfloat16 (default) | float32>
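For illustration, the two eager-load paths might be dispatched roughly like this. This is a minimal sketch with assumed names, not the PR's actual code; from_pretrained(..., dtype="auto") and safetensors.safe_open are the real entry points described above:

```python
from pathlib import Path

def eager_load(model_id_or_path, method="transformers"):
  """Sketch of the two eager-load methods (hypothetical helper)."""
  if method == "transformers":
    # Method 1: dtype="auto" keeps the checkpoint's original tensor dtype
    # instead of transformers' default float32 upcast.
    from transformers import AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(model_id_or_path, dtype="auto")
    return dict(model.state_dict())
  if method == "safetensors":
    # Method 2: read tensors straight from the *.safetensors shards; this
    # also picks up weights the model class skips (e.g. MTP layers.61).
    from safetensors import safe_open
    state = {}
    for shard in sorted(Path(model_id_or_path).glob("*.safetensors")):
      with safe_open(shard, framework="pt") as f:
        for key in f.keys():
          state[key] = f.get_tensor(key)
    return state
  raise ValueError(f"unknown eager_load_method: {method}")
```

Method 2 reads the shard files one tensor at a time, never materializing a full model object, which is where the extra generality comes from.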

Why It Matters (Impact & Benefits)

  1. 2x Memory Reduction: Peak memory usage is cut in half across the board.
  • gpt-oss-120b: 1009.72 GB -> 511.47 GB
  • approximately 8y GB -> 4y GB, where y is the parameter count in billions
  2. Speedup for Loading Alone: Loading time is drastically reduced by avoiding the default float32 cast and NumPy bottlenecks.
  • gpt-oss-120b: 78 min -> 1 s
  • deepseek3-671b: 7.5 hr -> 4 min
  3. Speedup for Total Conversion: Reduces the total conversion lifecycle for large models.
  • gpt-oss-120b: 134.86 min -> 96.22 min
  4. Unblocked Scalability: End-to-end conversions of massive models are now practical. deepseek3-671b previously OOM'd on 3.7 TB RAM; it is now feasible with a peak of 2854.90 GB and a total conversion time of ~9.5 hours.
  5. Reduced Storage: Checkpoint sizes are smaller (e.g., gpt-oss-120b dropped from 100.17 GiB to 74.23 GiB).
  6. Increased Flexibility: Method 2 (safetensors) allows us to
  • convert models even if the HuggingFace code isn't fully available yet (e.g., deepseek-ai/DeepSeek-V3.2).
  • convert weights omitted by the Transformers class (e.g., the Multi-Token-Prediction weights layers.61 are not loaded by deepseek-ai/DeepSeek-V3).
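The back-of-envelope memory rule above can be sketched as follows. This is a hypothetical helper; the GB figures in this PR are measured, and this formula is only the 8y-vs-4y approximation:

```python
def approx_peak_gb(params_billion: float, save_dtype: str = "bfloat16") -> float:
  """Rough peak host memory for eager conversion: ~2 copies of the weights
  resident at once, at 4 bytes/param (float32) or 2 bytes/param (bfloat16)."""
  bytes_per_param = {"float32": 4, "bfloat16": 2}[save_dtype]
  return 2 * bytes_per_param * params_billion  # 1e9 params * bytes -> GB

print(approx_peak_gb(120, "float32"))   # ~960 GB  (measured: 1009.72 GB)
print(approx_peak_gb(120, "bfloat16"))  # ~480 GB  (measured: 511.47 GB)
```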

Other changes

  • to_maxtext: Reuse HF_MODEL_CONFIGS rather than transformers.AutoConfig. This accommodates models without full HuggingFace code support (e.g., deepseek3.2) and aligns with how to_huggingface uses config.
  • to_huggingface: Initially, MaxText weights were loaded via set_decode_state, which uses config.weight_dtype. This was subsequently changed to an Orbax restore, which loads the weights as-is. To control the save dtype, we now explicitly cast to config.weight_type in utils._process.
# to_huggingface flag
config.weight_type=<bfloat16 | float32 (default)>
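The save-dtype cast itself might look like this minimal sketch (hypothetical helper name; assumes numpy and ml_dtypes are available, mirroring the torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16 chain used when saving):

```python
import numpy as np
import ml_dtypes

def cast_for_save(arr: np.ndarray, save_dtype: str) -> np.ndarray:
  """Cast a weight array to the requested save dtype (sketch).

  NumPy has no native bfloat16, so ml_dtypes supplies the dtype; torch
  tensors are first routed through float32 before landing here.
  """
  if save_dtype == "float32":
    return arr.astype(np.float32)
  if save_dtype == "bfloat16":
    return arr.astype(np.float32).astype(ml_dtypes.bfloat16)
  raise ValueError(f"unsupported save_dtype: {save_dtype}")
```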

Tests

Test details in doc.

1. Performance (gpt-oss-120b)

  • Converted hf-bf16 to maxtext scanned.
  • Compared the previous method against Method 2 + bfloat16 save.
  • Result: Verified the 2x memory reduction, a significant total conversion time reduction (134.86 min -> 96.22 min), a smaller checkpoint size, and confirmed that logit precision remains very close.

2. Functionality (qwen3-0.6b)

  • Converted hf-bf16 to maxtext scanned.
  • Tested the matrix of {lazy, method1, method2} x {bfloat16, float32}.
  • Result: Verified that all load/save mode combinations are functional and yield correct logits via logit checks.
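A logit check of this kind can be sketched as follows (hypothetical helper, not the repo's forward_pass_logit_checker; tolerances are loose because a bfloat16 round-trip perturbs low-order bits):

```python
import numpy as np

def logits_close(a, b, rtol=1e-2, atol=1e-2) -> bool:
  """Compare two logit arrays in float32 with loose tolerances (sketch)."""
  a32 = np.asarray(a, dtype=np.float32)
  b32 = np.asarray(b, dtype=np.float32)
  return bool(np.allclose(a32, b32, rtol=rtol, atol=atol))
```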

3. Scalability (deepseek3-671b)

  • Tested {to_maxtext} x {scanned}.
  • Result: Successfully scaled to_maxtext for this model class. (Previously, only to_huggingface was feasible due to OOM constraints).

4. New DeepSeek Mappings

  • deepseek2-16b: {to_maxtext, to_huggingface} x {scanned, unscanned}
  • deepseek3-671b: {to_maxtext} x {unscanned}
  • deepseek3.2: {to_maxtext} x {scanned, unscanned}. Note: to_huggingface is not enabled, as DeepseekV32ForCausalLM is not supported yet.
  • Result: Verified new mappings are fully functional and correct via logit checks.

Examples:

# to_maxtext, eager load with transformers (default), save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek2-16b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=transformers --save_dtype=bfloat16

# to_maxtext, eager load with safetensors, save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek2-16b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=safetensors --save_dtype=bfloat16

# to_maxtext for deepseek3.2-671b, eager load with safetensors, save as bfloat16
BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
python3 -m maxtext.checkpoint_conversion.to_maxtext \
src/maxtext/configs/base.yml model_name=deepseek3.2-671b scan_layers=true attention=dot_product \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
--eager_load_method=safetensors --save_dtype=bfloat16 \
--hf_model_path=$CUSTOM_PATH_HF_BF16 # original fp8 checkpoint dequantized to bf16

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 12.19512% with 36 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...xtext/checkpoint_conversion/utils/param_mapping.py 0.00% 33 Missing ⚠️
...xt/checkpoint_conversion/utils/hf_model_configs.py 71.42% 1 Missing and 1 partial ⚠️
src/maxtext/utils/muon_utils.py 0.00% 1 Missing ⚠️


@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 0798438 to 5702326 Compare March 24, 2026 21:36
@shuningjin shuningjin changed the title Checkpoint conversion tool: Optimize to_maxtext & Onboard deepseek2/3/3.2 checkpoint utility: optimize to_maxtext, add deepseek Mar 24, 2026
@shuningjin shuningjin marked this pull request as ready for review March 24, 2026 21:53

@github-actions github-actions bot left a comment


## 📋 Review Summary

This Pull Request introduces significant improvements and optimizations to the checkpoint conversion process in MaxText, specifically focusing on DeepSeek model support (V2-16B, V3-671B, and V3.2-671B). The implementation of LazyHFLoader and the adoption of dtype="auto" in Hugging Face loading are excellent additions that substantially reduce memory overhead, making the conversion of extremely large models more feasible.

🔍 General Feedback

  • Efficiency: The shift towards memory-efficient loading strategies is a major highlight. Using safetensors on-demand avoids redundant memory consumption during the to_maxtext conversion.
  • Support: Comprehensive support for DeepSeek's MLA architecture and MoE experts is well-integrated into both hf_shape.py and param_mapping.py.
  • Maintainability: The refactoring of forward_pass_logit_checker.py and the grouping of reshape hooks in param_mapping.py significantly improve code clarity and ease of future extension.


hf_model = model_class.from_pretrained(model_id, token=token, revision=revision)
return hf_model
hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, dtype=dtype)


🟠 The `AutoModelForCausalLM.from_pretrained` method in `transformers` uses the `torch_dtype` parameter to specify the precision (e.g., `torch_dtype="auto"`). Using `dtype` as a keyword argument will likely be passed into the model's `__init__` via `**kwargs` and may be ignored or cause unexpected behavior depending on the model's implementation. To ensure the optimization goal of loading in the original tensor type works as intended, use `torch_dtype`.
Suggested change
hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, dtype=dtype)
hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, torch_dtype=dtype)

Collaborator


Could you cross check?

Collaborator Author


We should stick to "dtype". I am seeing a deprecation warning when using "torch_dtype" in model loading.

  • Added a comment
# Note: transformers deprecates `torch_dtype` in favor of standard `dtype` in model loading
  • test
! pip install transformers==4.57.3
import transformers
print(transformers.__version__)
from transformers import AutoModelForCausalLM
AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6b", torch_dtype="bfloat16")

# Output:
# 4.57.3
#   warnings.warn(
# `torch_dtype` is deprecated! Use `dtype` instead!

Comment on lines 1757 to 1762
def reshape_audio_attn_out(input_tensor, target_shape):
"""Reshape audio attention output projection.
F
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)

HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
"""


🟢 There is a typo (extra F) on a single line in the docstring for reshape_audio_attn_out.

Suggested change
def reshape_audio_attn_out(input_tensor, target_shape):
"""Reshape audio attention output projection.
F
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
"""
def reshape_audio_attn_out(input_tensor, target_shape):
"""Reshape audio attention output projection.
HF: (hidden_size, hidden_size)
MaxText: (num_heads, head_dim, hidden_size)
"""

Collaborator


Nice catch!

Collaborator Author


This was removed previously.

Collaborator

@RissyRan RissyRan left a comment


LGTM at high level! A few minor comments.

| **DeepSeek3** | 671B | - | - | √ | - |
| **DeepSeek2** | 16B | √ | √ | √ | √ |
| **DeepSeek3** | 671B | √ | √ | √ | √ |
| **DeepSeek3.2** | 671B | √ | √ | - | - |
Collaborator


Is Orbax (scan/unscan) → HF a follow-up PR? This could be helpful for RL if needed.

Collaborator Author


For to_huggingface:

  • the current deepseek3.2 param/hook mapping should work out of the box (expecting no code change)
  • currently, saving config.json errors out, as DeepseekV32ForCausalLM is missing from the transformers library (still in PR).

Next step:

  • Verify the to_huggingface conversion after PR 41251 is merged in transformers. The logit check might also become easier.
  • Created b/496411531 to follow-up on the verification

# Sharding axis 0: Omit log for brevity per the summary log above.
return jax.device_put(arr, device=s1)
elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
max_logging.log("sharding axis 1")
Collaborator


Are these logs actually helpful? Since they don't specify which tensor is being sharded and only show '0, 1, etc', they add a lot of noise. Should we downgrade them to DEBUG with tensor shape?

Collaborator Author


TLDR: I added tensor-shape logging. I prefer to keep the current logging level, as it is lightweight.

The current logging omits "Axis 0 sharding" (which covers most cases); previously, every tensor was logged. As a result, the sharding log is relatively lightweight. Examples:

deepseek3.2 (most other models look like this)

INFO:absl:Note: Axis 0 sharding is the default and will not be logged individually.        
100%|██████████████████████████████| 42/42 [30:13<00:00, 43.18s/it]
INFO:absl:Elapse for checkpoint sharding: 30.24 min                                        

gpt-oss-20b scanned

INFO:absl:Note: Axis 0 sharding is the default and will not be logged individually.
  2%|▍                   | 1/41 [00:00<00:25,  1.57it/s]
INFO:absl:Not sharding. Shape (8, 12, 64)
 20%|████                | 8/41 [00:04<00:20,  1.63it/s]
INFO:absl:Not sharding. Shape (8, 12, 64)
 49%|██████████          | 20/41 [00:43<00:36,  1.73s/it]
INFO:absl:Not sharding. Shape (8, 12, 64)
 66%|█████████████       | 27/41 [00:47<00:09,  1.41it/s]
INFO:absl:Not sharding. Shape (8, 12, 64)
100%|████████████████████| 41/41 [01:27<00:00,  2.14s/it]
INFO:absl:Elapse for checkpoint sharding: 1.47 min
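The axis-selection logic under discussion can be sketched standalone as follows (hypothetical function; the real code additionally calls jax.device_put and max_logging):

```python
def pick_shard_axis(shape: tuple, device_count: int):
  """Return 0 or 1 for the first leading axis divisible by device_count,
  or None for 'not sharding' (the tensor is replicated). Sketch only."""
  for axis in (0, 1):
    if axis < len(shape) and shape[axis] % device_count == 0:
      return axis
  return None

# e.g. the (8, 12, 64) tensors in the gpt-oss-20b log above stay unsharded
# when neither 8 nor 12 divides evenly by the device count.
```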

Collaborator

@RissyRan RissyRan Mar 27, 2026


Yeah, I noticed the axis 0 handling. My concern is more about axis 1 and the other cases logged with max_logging.log("sharding axis 1").

Could you update it to something like max_logging.log(f"tensor shape with {arr.shape} and sharding axis 1")?

"v_head_dim": 128,
"vocab_size": 129280,
}
# TODO(shuningjin): replace with DeepseekV32Config when available on HF
Collaborator


Collaborator Author


transformers.DeepseekV3Config exists. By contrast, transformers.DeepseekV32Config does not exist yet, as the deepseek3.2 code is still in PR, so I am using the base class transformers.PretrainedConfig for now.

Updated the comment to be clearer:

when available on HF -> when available in transformers library



def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace DeepseekV3 weights path and their shape.
Collaborator


Why do we remove this comment? Could we update it if it's not correct?

Collaborator Author


Added the comment back for the mapping check, with the example changed from deepseek3 to deepseek2-lite for practicality.

return v.to(torch.float32).numpy()
# target dtype is bfloat16
elif save_dtype == "bfloat16":
# torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16
Collaborator


Shall we add some explanation here? i.e., NumPy doesn't accept bfloat16 directly, so upcast to float32.

Collaborator Author


Thanks, updated.

return hf_state_dict_numpy[key]
v = hf_state_dict_numpy[key]
# target dtype is float32
if save_dtype == "float32":
Collaborator


Shall we leverage this to avoid potential typos?

class DType(str, Enum):
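The suggested pattern might look like this minimal sketch (the str mixin lets members compare equal to plain flag strings):

```python
from enum import Enum

class DType(str, Enum):
  """Valid save dtypes; DType("typo") raises ValueError at parse time."""
  FLOAT32 = "float32"
  BFLOAT16 = "bfloat16"

assert DType("bfloat16") is DType.BFLOAT16  # parse a flag value
assert DType.FLOAT32 == "float32"           # still compares as a string
```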

Collaborator Author


Thanks, updated.

del hf_model
max_logging.log("HuggingFace model loaded and converted to NumPy.")

if eager_load_method == "transformers":
Collaborator


Could you add some recommendations/comments on which option is recommended under which use case? Are the measured speedups coming from one strategy?

Collaborator Author

@shuningjin shuningjin Mar 26, 2026


Method 1 (transformers with auto dtype) vs. Method 2 (safetensors with pt):

  1. The optimization boost comes from moving off the previous Method 0 (transformers with no specified dtype, which defaults to float32); the current methods (1 or 2) should be equivalent.
  2. For most models, 1 & 2 have the same behavior: the loaded states have the same dtype and model structure (e.g., deepseek2-16b).
  3. One exception is gemma3-4b (Method 2 does not have the "model" prefix); the current gemma3 mapping uses Method 1.
  4. Default to Method 1, for backward compatibility (see 3).
  5. Method 2 is more general:
  • It only needs safetensors files, while Method 1 also needs a HuggingFace model class.
  • e.g., for deepseek3.2, we have to use Method 2; Method 1 is unavailable (DeepseekV32ForCausalLM is still in PR).
  • e.g., we must use it to convert weights omitted by the Transformers class (the Multi-Token-Prediction weights layers.61 are not loaded by deepseek-ai/DeepSeek-V3). This unblocks the MTP migration, b/483746010.

Collaborator Author

@shuningjin shuningjin Mar 26, 2026


Please double-check the updated comments. Feedback appreciated.

Collaborator


Thanks!

Collaborator

@hengtaoguo hengtaoguo left a comment


LGTM!

@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 91acedc to efdfc8e Compare March 26, 2026 00:21
@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt3 branch 2 times, most recently from 6067df3 to 190dee7 Compare March 26, 2026 15:26
@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt3 branch from 190dee7 to 22da4de Compare March 26, 2026 15:48
Collaborator

@RissyRan RissyRan left a comment


Thanks for the change! Just a minor comment to make logging more useful.
