Add checkpoint resharding script for faster loading#3801
Add checkpoint resharding script for faster loading#3801shuningjin wants to merge 1 commit intomainfrom
Conversation
d549c06 to
fa776ba
Compare
|
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details. |
There was a problem hiding this comment.
This Pull Request introduces a new script reshard_checkpoint.py designed to re-shard MaxText checkpoints on CPU. This utility is highly effective for reducing checkpoint loading times on TPUs, as demonstrated by the significant performance gains reported for DeepSeek-V3. The PR also includes minor robustness improvements and bug fixes in llama_or_mistral_ckpt.py.
🔍 General Feedback
- Performance: The reported 10x reduction in loading time (from 60 min to 6 min) for DeepSeek-V3 is a major improvement for large-scale model training and inference.
- Initialization Timing: A key concern is the timing of JAX initialization in the new script. Setting environment variables like
XLA_FLAGSafter importing JAX-dependent modules may lead to them being ignored if the XLA backend has already been initialized. - Flexibility: Adding a way to specify or preserve the
step_numberwould enhance the utility of the resharding script.
There was a problem hiding this comment.
Supplementing the previous review with the missed comment on JAX initialization timing. Overall, the PR is very valuable for optimizing large model checkpoints.
🔍 General Feedback
- Initialization Timing: Setting
XLA_FLAGSbefore JAX imports ensures the simulated CPU mesh is correctly established.
| - The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh, | ||
| and then saved to a new checkpoint. | ||
| - The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead. | ||
| E.g., when target sharding is fsdp=64, checkpoint loading time varies across source sharding (fsdp=64 < fsdp=16 < ep=16) |
There was a problem hiding this comment.
Have you tried fsdp=64 < fsdp=16?
There was a problem hiding this comment.
I only tried fsdp=16. Just removed fsdp=64 from comment for brevity.
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
510b209 to
fa776ba
Compare
0986380 to
3b125fc
Compare
Description
This script re-shards a MaxText checkpoint on CPU. The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead.
FIXES: b/504714612
Introduction
Problem: In checkpoint conversion, we typically shard along the 0th dimension (usually the expert dimension for MoE). Consequently, loading is fast when the target sharding is EP (e.g., a few minutes), but noticeably slow for FSDP (e.g., an hour). This is a major bottleneck because FSDP is our most common use case.
Effectiveness: Our experiments show that pre-sharding a checkpoint to fsdp=16 reduces the loading time of DeepSeek-V3 from 60 minutes to 6 minutes on a v5p-128 cluster targeting fsdp=64. Furthermore, the solution scales efficiently to v7x 1k chips, maintaining a brief 10-minute load time.
Generalizability: While this was built to solve the FSDP loading bottleneck, the solution generalizes to pre-shard checkpoints into other target sharding layout.
Method
The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh, and then saved to a new checkpoint.
Key operation trace: maxengine.load_params -> maxtext_utils.setup_decode_state -> checkpointing.load_params_from_path -> orbax.checkpoint.Checkpointer.restore
User Guide
Full details are in docstring.
Key Parameters:
--simulated_cpu_devices_count(defaults to 16). Examples:--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2weight_dtype: The dtype used to load and save the checkpoint. Highly recommend usingweight_dtype=bfloat16.Memory Requirements:
weight_dtype=bfloat16).Tests
deepseek3-671b with mtp
Full test details in b/504714612 (comment3, comment8)
deepseek2-16b
Reshard:
Inspect structure:
forward_pass_logit_checker, load with target sharding fsdp=16:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.