checkpoint utility: shard checkpoint, monitor peak #2974
Conversation
RissyRan left a comment
Do we still need to maintain standalone scripts? I thought most of the models were migrated to the bi-directional conversion utility.
```python
)
# Used to convert numpy weights to sharded jax arrays across simulated cpu devices
# If count=1, do not shard
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
```
Could you add some explanation of this flag, and why the default is set to 16 here? If we run a conversion on a single-CPU device but set this number higher than 1, what will happen? Thanks for the clarification!
Thanks for the review! I have added a detailed comment explaining this flag. To answer your questions:

> why default is set to 16

- Most of the previous conversion scripts use shard=16. See: llama & mixtral, deepseek & kimi, as well as gpt-oss and qwen3-moe.
- Since this has been used and tested for a long time, I think it is good to follow.
- In particular, all MoEs and the largest models (DeepSeek 671B, Kimi 1T) have been converted this way.
> If we run a conversion on a single-cpu device but set this number higher than 1, what will happen

By setting these flags, JAX can simulate multiple devices, even on a single-CPU host:

```python
jax.config.update("jax_platforms", "cpu")
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
```

In particular, `len(jax.devices())` will be equal to this count, and the simulated devices are used to create a mesh to shard the weights on.
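For illustration, here is a self-contained sketch of the same mechanism, assuming only that JAX is installed (the mesh axis name and array shape are made up for the example). Note that `XLA_FLAGS` must be set before JAX initializes its backend, hence before `import jax`:

```python
# Simulate multiple CPU devices on one host, then shard an array across them.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_platforms", "cpu")

assert len(jax.devices()) == 4  # four simulated CPU devices on a single host

# Build a 1-D mesh over the simulated devices and shard rows across it,
# mirroring how the conversion script shards weights before saving.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))
weights = jnp.zeros((8, 16))
sharded = jax.device_put(weights, sharding)
print([s.data.shape for s in sharded.addressable_shards])  # four (2, 16) chunks
```

With a count of 1 the mesh has a single device, so `device_put` leaves the array in one full-size chunk.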
```python
# Example: Embedding Layer shape=(151936, 1024)
# Case 1: simulated_cpu_devices_count=16 (Sharded)
#   sharding: NamedShardingMetadata(shape=[16], ...)
#   storage: chunk_shape=(9496, 1024)    <-- 1/16th of rows per chunk
# Case 2: simulated_cpu_devices_count=1 (Monolith)
#   sharding: None
#   storage: chunk_shape=(151936, 1024)  <-- Full layer in one chunk
```
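The chunk shapes above follow directly from dividing the leading axis by the device count (pure arithmetic, no JAX needed):

```python
# Chunk shape per simulated device: rows split evenly along axis 0.
rows, cols = 151936, 1024
for count in (16, 1):
    assert rows % count == 0  # rows divide evenly in this example
    print(count, (rows // count, cols))
# 16 -> (9496, 1024), 1 -> (151936, 1024)
```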
In most cases, we don't need to maintain standalone scripts. In this case, however,
Description
Fix: b/477648456
Main change:
- `to_maxtext.py`: add option to shard weights before saving the orbax checkpoint.
  - `--simulated_cpu_devices_count`, default 16. That is, shard the checkpoint across 16 simulated CPU devices. Note: this default value follows previous scripts; see below.
  - With `--simulated_cpu_devices_count=1`, skip sharding.
- Use `save_weights_to_checkpoint` from `MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt`.
- `save_weights_to_checkpoint`: add comments, use `pop()` rather than `pop(0)`, log time for shard and save.

Why?
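On the `pop()` vs `pop(0)` point in the list above: `list.pop()` removes from the end in O(1), while `list.pop(0)` shifts every remaining element, costing O(n) per call, so draining a large list front-first is quadratic overall. A minimal illustration (the names are hypothetical, not the actual MaxText code):

```python
# Drain a worklist from the end; each pop() is O(1), unlike pop(0).
names = [f"layer_{i}" for i in range(5)]

drained = []
while names:
    drained.append(names.pop())  # removes the last element; yields reverse order

print(drained)  # ['layer_4', 'layer_3', 'layer_2', 'layer_1', 'layer_0']
```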
Additional change:
- Move `MemoryMonitor`, `Tqdm` from `to_maxtext` to utils.

Tests

model: `qwen3-0.6b`

Auxiliary script to check checkpoint sharding:
`check_orbax.py`: https://paste.googleplex.com/4961902692270080

1. eager mode, cpu, `simulated_cpu_devices_count=16` (default)
   - log: https://paste.googleplex.com/6454418499305472
   - gs://runner-maxtext-logs/2026-01-28-01-19/0/items
   - INFO:absl:Peak Memory: 6.50 GB
   - check sharding: https://paste.googleplex.com/4954491986247680
2. eager mode, cpu, `simulated_cpu_devices_count=1` (no shard)
   - log: https://paste.googleplex.com/5084079739502592
   - gs://runner-maxtext-logs/2026-01-28-01-23/0/items
   - INFO:absl:Peak Memory: 6.72 GB
   - check sharding: https://paste.googleplex.com/4528057568329728
3. lazy mode, cpu, `simulated_cpu_devices_count=16` (default)
   - log: https://paste.googleplex.com/6152565647605760
   - gs://runner-maxtext-logs/2026-01-28-01-26/0/items
   - INFO:absl:Peak Memory: 3.53 GB
   - check sharding: https://paste.googleplex.com/6147058778112000

Sanity check: TPU v5p-8

4. lazy mode, cpu, `simulated_cpu_devices_count=1` (no shard)
   - log: https://paste.googleplex.com/5756578252849152
   - gs://runner-maxtext-logs/2026-01-28-01-29/0/items
   - INFO:absl:Peak Memory: 2.95 GB
   - check sharding: https://paste.googleplex.com/6370607933554688
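For context, per-process peak-memory figures like the `Peak Memory` lines in the logs above can be obtained from the standard library. This is only an illustrative sketch under that assumption, not the actual `MemoryMonitor` implementation:

```python
# Report this process's peak resident set size. On Linux, ru_maxrss is in
# kilobytes; on macOS it is in bytes. Unix-only (the resource module).
import resource
import sys

def peak_memory_gb() -> float:
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    if sys.platform == "darwin":
        return peak / 1024**3  # bytes -> GB
    return peak / 1024**2      # kilobytes -> GB

buf = bytearray(64 * 1024 * 1024)  # allocate ~64 MB so the peak is visible
print(f"Peak Memory: {peak_memory_gb():.2f} GB")
```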
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- `gemini-review` label.