checkpoint utility: optimize to_maxtext, add deepseek #3184
shuningjin wants to merge 1 commit into main
Conversation
Force-pushed from 0798438 to 5702326
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces significant improvements and optimizations to the checkpoint conversion process in MaxText, specifically focusing on DeepSeek model support (V2-16B, V3-671B, and V3.2-671B). The implementation of LazyHFLoader and the adoption of dtype="auto" in Hugging Face loading are excellent additions that substantially reduce memory overhead, making the conversion of extremely large models more feasible.
🔍 General Feedback
- Efficiency: The shift towards memory-efficient loading strategies is a major highlight. Using `safetensors` on-demand avoids redundant memory consumption during the `to_maxtext` conversion.
- Support: Comprehensive support for DeepSeek's MLA architecture and MoE experts is well-integrated into both `hf_shape.py` and `param_mapping.py`.
- Maintainability: The refactoring of `forward_pass_logit_checker.py` and the grouping of reshape hooks in `param_mapping.py` significantly improve code clarity and ease of future extension.
```python
hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, dtype=dtype)
```

Suggested change:

```suggestion
hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, torch_dtype=dtype)
```
We should stick to `dtype`. I am seeing a deprecation warning when using `torch_dtype` in model loading.

- Added a comment:

```python
# Note: transformers deprecates `torch_dtype` in favor of standard `dtype` in model loading
```

- Test:

```python
# pip install transformers==4.57.3
import transformers
print(transformers.__version__)  # 4.57.3
from transformers import AutoModelForCausalLM
AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6b", torch_dtype="bfloat16")
# warnings.warn( ... `torch_dtype` is deprecated! Use `dtype` instead!
```

- This is with `transformers==4.57.3`, which is the version in the maxtext image.
- Also see: Update deprecated `torch_dtype` argument (huggingface/peft#2835).
```python
def reshape_audio_attn_out(input_tensor, target_shape):
  """Reshape audio attention output projection.
  F
  HF: (hidden_size, hidden_size)
  MaxText: (num_heads, head_dim, hidden_size)
  """
```
🟢 There is a typo (an extra `F`) on a single line in the docstring for `reshape_audio_attn_out`.
```suggestion
def reshape_audio_attn_out(input_tensor, target_shape):
  """Reshape audio attention output projection.

  HF: (hidden_size, hidden_size)
  MaxText: (num_heads, head_dim, hidden_size)
  """
```
This was removed previously.
RissyRan left a comment:

LGTM at high level! A few minor comments.
```diff
-| **DeepSeek3**   | 671B | - | - | √ | - |
+| **DeepSeek2**   | 16B  | √ | √ | √ | √ |
+| **DeepSeek3**   | 671B | √ | √ | √ | √ |
+| **DeepSeek3.2** | 671B | √ | √ | - | - |
```
Is Orbax (scan/unscan) → HF a follow-up PR? This could be helpful for the RL if needed.
For `to_huggingface`:

- The current deepseek3.2 param/hook mapping should work out of the box (expect no code change).
- Currently, there is an error when saving `config.json`, as `DeepseekV32ForCausalLM` is missing from the transformers library (still in PR).

Next steps:

- Verify the to_huggingface conversion after PR 41251 is merged in transformers. The logit check might also be easier then.
- Created b/496411531 to follow up on the verification.
```python
    # Sharding axis 0: Omit log for brevity per the summary log above.
    return jax.device_put(arr, device=s1)
  elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
    max_logging.log("sharding axis 1")
```
Are these logs actually helpful? Since they don't specify which tensor is being sharded and only show '0, 1, etc.', they add a lot of noise. Should we downgrade them to DEBUG with the tensor shape?
TL;DR: I added tensor shape logging. I prefer to keep the current logging level, as it is lightweight.

The current logging omits "Axis 0 sharding" (which covers most cases); previously, every tensor was logged. As a result, the sharding log is relatively lightweight. Examples:

deepseek3.2 (most other models look like this):

```
INFO:absl:Note: Axis 0 sharding is the default and will not be logged individually.
100%|...| 42/42 [30:13<00:00, 43.18s/it]
INFO:absl:Elapse for checkpoint sharding: 30.24 min
```

gpt-oss-20b scanned:

```
INFO:absl:Note: Axis 0 sharding is the default and will not be logged individually.
INFO:absl:Not sharding. Shape (8, 12, 64)
INFO:absl:Not sharding. Shape (8, 12, 64)   # repeated a few times during the run
100%|...| 41/41 [01:27<00:00, 2.14s/it]
INFO:absl:Elapse for checkpoint sharding: 1.47 min
```
Yeah, I noticed the axis 0 case. My concern is more about axis 1 and the other cases logged with `max_logging.log("sharding axis 1")`.
Could you update it to something like `max_logging.log(f"tensor shape with {arr.shape} and sharding axis 1")`?
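The axis-selection logic being discussed can be sketched as below. `choose_shard_axis` is a hypothetical helper for illustration, not the PR's actual code; the real implementation would also call `max_logging.log` with the tensor shape, as suggested above.

```python
def choose_shard_axis(shape, device_count):
    """Return the first axis whose size divides evenly across devices, or None.

    Mirrors the pattern in the PR: axis 0 is the common (unlogged) case;
    axis 1 and beyond are rare enough to be worth logging with the shape.
    """
    for axis, dim in enumerate(shape):
        if dim % device_count == 0:
            return axis
    return None

print(choose_shard_axis((128, 64), 8))   # 0 -> default case, not logged per tensor
print(choose_shard_axis((12, 64), 8))    # 1 -> would log shape and "sharding axis 1"
print(choose_shard_axis((9, 13, 7), 8))  # None -> "Not sharding. Shape (9, 13, 7)"
```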
```python
    "v_head_dim": 128,
    "vocab_size": 129280,
}
# TODO(shuningjin): replace with DeepseekV32Config when available on HF
```
Could you explain a little bit? https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json
`transformers.DeepseekV3Config` exists. By contrast, `transformers.DeepseekV32Config` does not exist, as the deepseek3.2 code is still in PR. So I am using the base class `PretrainedConfig` for now.

Updated the comment to be clearer: "when available on HF" -> "when available in transformers library".
```python
def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
  """Returns mapping between HuggingFace DeepseekV3 weights path and their shape.
```
Why did we remove this comment? Could we update it if it's not correct?
Added the comment back for the mapping check, with the example changed from deepseek3 to deepseek2-lite for practicality.
```python
    return v.to(torch.float32).numpy()
  # target dtype is bfloat16
  elif save_dtype == "bfloat16":
    # torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16
```
Shall we add some explanation here? I.e., NumPy doesn't accept bfloat16 directly, so upcast to float32.
Thanks, updated.
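For reference, the NumPy half of that conversion path can be sketched with `ml_dtypes`; the torch half is elided here, and the `np.array` below merely stands in for `t.to(torch.float32).numpy()`.

```python
import numpy as np
import ml_dtypes

# NumPy has no built-in bfloat16 dtype, so a torch.bfloat16 tensor cannot be
# handed to NumPy directly; the conversion path is
# torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16.
x = np.array([1.0, 2.5, 3.25], dtype=np.float32)  # stand-in for t.to(torch.float32).numpy()
v = x.astype(ml_dtypes.bfloat16)
print(v.dtype)  # bfloat16
```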
```python
    return hf_state_dict_numpy[key]
  v = hf_state_dict_numpy[key]
  # target dtype is float32
  if save_dtype == "float32":
```
Shall we leverage this to avoid a potential typo?

`maxtext/src/maxtext/configs/types.py`, line 62 in de51021
Thanks, updated.
```python
  del hf_model
  max_logging.log("HuggingFace model loaded and converted to NumPy.")

  if eager_load_method == "transformers":
```
Could you add some recommendations/comments on which option is recommended for which use case? Is the tested optimization boost coming from one strategy?
Method 1 (transformers with auto dtype) vs. Method 2 (safetensors with pt):

- The optimization boost comes from moving off the previous Method 0 (transformers with no specified dtype, defaulting to f32); the current methods (either 1 or 2) should be the same.
- For most models, 1 & 2 have the same behavior: the loaded states have the same dtype and model structure (e.g., deepseek2-16b).
- One exception is gemma3-4b (Method 2 does not have the "model" prefix); the current gemma3 mapping uses Method 1.
- Default to Method 1, for backward compatibility.
- Method 2 is more general:
  - It only needs safetensors, while Method 1 also needs a HuggingFace model.
  - E.g., for deepseek3.2, we have to use Method 2; we cannot use Method 1 (`DeepSeek32ForCausalLM` is still in PR).
  - E.g., we must use it to convert weights omitted by the Transformers class (the Multi-Token-Prediction weights in `layers.61` are not loaded by `deepseek-ai/DeepSeek-V3`). This unblocks MTP migration, b/483746010.
Please double-check the updated comments. Feedback appreciated.
Force-pushed from 91acedc to efdfc8e
Force-pushed from 6067df3 to 190dee7
Force-pushed from 190dee7 to 22da4de
RissyRan left a comment:

Thanks for the change! Just a minor comment to make logging more useful.
Description
Optimize `to_maxtext` loading and saving:

- Optimize `to_maxtext` eager loading and saving pipelines.
- Add DeepSeek mappings (`deepseek2-16b`, `deepseek3-671b`, `deepseek3.2-671b`).

What Changed

Previously, eager load defaulted to `transformers_class.from_pretrained(...)`, which loaded, converted, and saved checkpoints in `float32`. (If using the latest transformers version, the e2e run would error.)

This PR introduces two optimized loading methods and adds the ability to save in `bfloat16`:

- Method 1: `transformers_class.from_pretrained(..., dtype="auto")` to load the original tensor type.
- Method 2: `safetensors.safe_open(..., framework="pt")` to load natively from safetensors. Similar to Method 1, this can process either a remote repo or a local path.
- Saving: `bfloat16` is the recommended save option (with `float32` retained as a backup).

Why It Matters (Impact & Benefits)

- Smaller checkpoints: `gpt-oss-120b`: 1009.72 GB -> 511.47 GB.
- Faster loading, by avoiding `float32` casting and NumPy bottlenecks: `gpt-oss-120b`: 78 min -> 1 s; `deepseek3-671b`: 7.5 hr -> 4 min.
- Faster end-to-end conversion: `gpt-oss-120b`: 134.86 min -> 96.22 min.
- Lower memory: `deepseek3-671b` previously OOM'd on 3.7 TB RAM; it is now feasible with a peak of 2854.90 GB and a total conversion time of ~9.5 hours (`gpt-oss-120b` dropped from 100.17 GiB to 74.23 GiB).
- Lazy loading (via `safetensors`) allows us to convert models without full HuggingFace class support (e.g., `deepseek-ai/DeepSeek-V3.2`) and to convert weights omitted by the Transformers class (`layers.61` is not loaded by `deepseek-ai/DeepSeek-V3`).

Other changes

- `to_maxtext`: Reuse `HF_MODEL_CONFIGS` rather than `transformers.AutoConfig`. This accommodates models without full HuggingFace code support (e.g., `deepseek3.2`). This also aligns with how `to_huggingface` uses config.
- `to_huggingface`: Initially, maxtext weights are loaded via `set_decode_state`, which uses `config.weight_dtype`. It was subsequently changed to orbax restore, which loads the weights as-is. To control the save dtype, we now explicitly cast to `config.weight_dtype` in `utils._process`.

Tests

Test details in doc.

1. Performance (`gpt-oss-120b`): `hf-bf16` to maxtext scanned, `bfloat16` save.
2. Functionality (`qwen3-0.6b`): `hf-bf16` to maxtext scanned, `{lazy, method1, method2} x {bfloat16, float32}`.
3. Scalability (`deepseek3-671b`): `{to_maxtext} x {scanned}`. First known success of `to_maxtext` for this model class (previously, only `to_huggingface` was feasible due to OOM constraints).
4. New DeepSeek mappings:
   - `deepseek2-16b`: `{to_maxtext, to_huggingface} x {scanned, unscanned}`
   - `deepseek3-671b`: `{to_maxtext} x {unscanned}`
   - `deepseek3.2`: `{to_maxtext} x {scanned, unscanned}`. Note `to_huggingface` is not enabled, as `DeepSeek32ForCausalLM` is not supported yet.

Examples:
Checklist

Before submitting this PR, please make sure (put X in square brackets):

- [ ] `gemini-review` label.