
Conversation

Collaborator

@shuningjin shuningjin commented Jan 20, 2026

Description

Fix: b/477648456

Main change: to_maxtext.py adds an option to shard weights before saving the Orbax checkpoint.

  • Controlled by --simulated_cpu_devices_count, which defaults to 16, i.e., the checkpoint is sharded across 16 simulated CPU devices (see the sketch after this list). Note: this default follows the previous scripts; see below.
    • Setting --simulated_cpu_devices_count=1 skips sharding.
  • Reuse save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt.
    • Refine save_weights_to_checkpoint: add comments, use pop() rather than pop(0), and log the time spent on sharding and saving.
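
For illustration, here is a minimal sketch of the sharding step (a simplified stand-in, not the exact save_weights_to_checkpoint code; the mesh axis name and embedding shape are taken from the checkpoint metadata in the Tests section below):

  # Sketch: shard a converted weight across simulated CPU devices before the Orbax save.
  import os

  # Must be set before the JAX backend initializes.
  os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

  import numpy as np
  import jax
  from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

  jax.config.update("jax_platforms", "cpu")

  def shard_weight(weight: np.ndarray) -> jax.Array:
    """Shards dim 0 across the (simulated) CPU devices; skips sharding if there is only one device."""
    devices = jax.devices()
    if len(devices) == 1:
      return jax.device_put(weight, devices[0])  # count=1: keep the array monolithic
    mesh = Mesh(np.array(devices), axis_names=("checkpoint_sharding_axis",))
    sharding = NamedSharding(mesh, P("checkpoint_sharding_axis"))  # shard the 0th dim
    return jax.device_put(weight, sharding)

  # Example: the qwen3-0.6b token embedding (151936, 1024) becomes 16 chunks of (9496, 1024).
  embedding = shard_weight(np.zeros((151936, 1024), dtype=np.float32))
  print(embedding.sharding)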

Why?

  • In most of the previous conversion scripts, the weights are sharded before saving (here), typically with simulated_cpu_devices_count=16 and sharding along the 0th dimension.
  • From previous investigation, the main issues with an unsharded checkpoint are running out of RAM on TPU and slow loading speed.
    • b/302192179#comment22: llama2-70b OOM on v5e with 197GB RAM
    • b/326133855#comment5
    • Following Ansiha ckpt perf cpu #466, from 2024-03.
  • In conclusion, checkpoint sharding has been used and tested for a long time. In particular, all MoEs and the largest models (deepseek-671b, kimi-1T) have been converted this way, so it is a good practice to follow.

Additional changes:

  • Print peak memory (a rough sketch of such a monitor follows this list).
  • Move MemoryMonitorTqdm from to_maxtext to utils.
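
As a rough idea of what such a monitor can look like (hypothetical class name, based on psutil and tqdm; the actual MemoryMonitorTqdm implementation in utils may differ):

  # Sketch: a tqdm wrapper that tracks peak process RSS while the conversion loop runs.
  import psutil
  from absl import logging
  from tqdm import tqdm

  class PeakMemoryTqdm(tqdm):  # hypothetical name
    """Progress bar that records the peak resident memory seen across updates."""

    def __init__(self, *args, **kwargs):
      super().__init__(*args, **kwargs)
      self._process = psutil.Process()
      self.peak_bytes = self._process.memory_info().rss

    def update(self, n=1):
      self.peak_bytes = max(self.peak_bytes, self._process.memory_info().rss)
      return super().update(n)

  # Usage around the per-weight conversion loop, then log something like
  # "Peak Memory: 6.50 GB" at the end:
  #   bar = PeakMemoryTqdm(total=num_weights)
  #   ... bar.update(1) per weight ...
  #   logging.info("Peak Memory: %.2f GB", bar.peak_bytes / 1024**3)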

Tests

Model: qwen3-0.6b

Auxiliary script to check checkpoint sharding, check_orbax.py: https://paste.googleplex.com/4961902692270080
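
Since the paste is internal, here is a hypothetical sketch of what such an inspection script can look like (assuming orbax-checkpoint's PyTreeCheckpointer.metadata() API; the actual check_orbax.py is at the link above):

  # Sketch: print per-array metadata (shape, dtype, sharding, chunk layout) without loading weights.
  import sys

  import jax
  import orbax.checkpoint as ocp

  def main(ckpt_dir: str):
    metadata = ocp.PyTreeCheckpointer().metadata(ckpt_dir)
    # Each leaf is an ArrayMetadata; its sharding field is None for unsharded arrays.
    leaves, _ = jax.tree_util.tree_flatten_with_path(metadata)
    for path, leaf in leaves:
      print("ArrayMetadata : ", jax.tree_util.keystr(path), leaf)

  if __name__ == "__main__":
    main(sys.argv[1])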

Test 1: eager mode, CPU, simulated_cpu_devices_count=16 (default)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--simulated_cpu_devices_count=16

log: https://paste.googleplex.com/6454418499305472
gs://runner-maxtext-logs/2026-01-28-01-19/0/items
INFO:absl:Peak Memory: 6.50 GB

INFO:absl:shard weights across 16 devices
  0%|                                                   | 0/13 [00:00<?, ?it/s]
INFO:absl:sharding axis 0
INFO:absl:Elapse for checkpoint sharding: 0.16 min

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-28-01-19/0/items

https://paste.googleplex.com/4954491986247680

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-28-01-19/0/items,  shape=(151936, 1024),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], axis_types=(Auto,), partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=float32,  storage=StorageMetadata(chunk_shape=(9496, 1024), write_shape=(9496, 1024)),

Test 2: eager mode, CPU, simulated_cpu_devices_count=1 (no shard)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--simulated_cpu_devices_count=1

log: https://paste.googleplex.com/5084079739502592
gs://runner-maxtext-logs/2026-01-28-01-23/0/items
INFO:absl:Peak Memory: 6.72 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-28-01-23/0/items

https://paste.googleplex.com/4528057568329728

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-28-01-23/0/items,  shape=(151936, 1024),  sharding=None,  dtype=float32,  storage=StorageMetadata(chunk_shape=(151936, 1024), write_shape=None),

Test 3: lazy mode, CPU, simulated_cpu_devices_count=16 (default)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--lazy_load_tensors=true --simulated_cpu_devices_count=16

log: https://paste.googleplex.com/6152565647605760
gs://runner-maxtext-logs/2026-01-28-01-26/0/items
INFO:absl:Peak Memory: 3.53 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-28-01-26/0/items

https://paste.googleplex.com/6147058778112000

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-28-01-26/0/items,  shape=(151936, 1024),  sharding=NamedShardingMetadata(shape=[16], axis_names=['checkpoint_sharding_axis'], axis_types=(Auto,), partition_spec=('checkpoint_sharding_axis',)) device_mesh=DeviceMetadataMesh(mesh=[DeviceMetadata(id=0), DeviceMetadata(id=1), DeviceMetadata(id=2), DeviceMetadata(id=3), DeviceMetadata(id=4), DeviceMetadata(id=5), DeviceMetadata(id=6), DeviceMetadata(id=7), DeviceMetadata(id=8), DeviceMetadata(id=9), DeviceMetadata(id=10), DeviceMetadata(id=11), DeviceMetadata(id=12), DeviceMetadata(id=13), DeviceMetadata(id=14), DeviceMetadata(id=15)]),  dtype=bfloat16,  storage=StorageMetadata(chunk_shape=(9496, 1024), write_shape=(9496, 1024)),

Sanity check: TPU v5p-8

Test 4: lazy mode, CPU, simulated_cpu_devices_count=1 (no shard)

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-0.6b scan_layers=true \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True \
attention=dot_product \
--lazy_load_tensors=true --simulated_cpu_devices_count=1

log: https://paste.googleplex.com/5756578252849152
gs://runner-maxtext-logs/2026-01-28-01-29/0/items
INFO:absl:Peak Memory: 2.95 GB

check sharding

python check_orbax.py gs://runner-maxtext-logs/2026-01-28-01-29/0/items

https://paste.googleplex.com/6370607933554688

ArrayMetadata :  name=params.params.token_embedder.embedding,  directory=gs://runner-maxtext-logs/2026-01-28-01-29/0/items,  shape=(151936, 1024),  sharding=None,  dtype=float32,  storage=StorageMetadata(chunk_shape=(151936, 1024), write_shape=None),

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@shuningjin shuningjin changed the title checkpoint utility: save shard checkpoint and improve mem monitor checkpoint utility: shard checkpoint, monitor peak Jan 20, 2026

codecov bot commented Jan 20, 2026

Codecov Report

❌ Patch coverage is 0% with 41 lines in your changes missing coverage. Please review.

Files with missing lines                                Patch %   Lines
src/MaxText/utils/ckpt_conversion/utils/utils.py          0.00%   21 Missing ⚠️
src/MaxText/utils/ckpt_conversion/to_maxtext.py           0.00%   16 Missing ⚠️
...rc/MaxText/utils/ckpt_conversion/to_huggingface.py     0.00%   4 Missing ⚠️


Collaborator

@RissyRan RissyRan left a comment


Do we still need to maintain standalone scripts? I thought most models have been migrated to the bi-directional conversion utility.

)
# Used to convert numpy weights to sharded jax arrays across simulated cpu devices
# If count=1, do not shard
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
Collaborator


Could you add some explanation about this flag, and why the default is set to 16 here? If we run a conversion on a single-CPU device but set this number higher than 1, what will happen? Thanks for the clarification!

Collaborator Author

@shuningjin shuningjin Jan 28, 2026


Thanks for the review! I have added a detailed comment to explain this flag. To answer your questions:

Why the default is set to 16

  • Most of the previous conversion scripts use shard=16. See: llama & mixtral, deepseek & kimi, as well as gpt-oss, qwen3-moe.
  • Since this has been used and tested for a long time, I think it is good to follow.
  • In particular, all MoEs and the largest models (deepseek 671b, kimi 1T) have been converted this way.

If we run a conversion on a single-CPU device but set this number higher than 1, what will happen?

By setting these flags, JAX can simulate multiple devices, even on a single CPU host.

  jax.config.update("jax_platforms", "cpu")
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

In particular, len(jax.devices()) will be equal to this count, and the simulated devices are used to create the mesh that the weights are sharded over.

  # Example: Embedding Layer shape=(151936, 1024)
  # Case 1: simulated_cpu_devices_count=16 (Sharded)
  #   sharding: NamedShardingMetadata(shape=[16], ...)
  #   storage:  chunk_shape=(9496, 1024)  <-- 1/16th of rows per chunk
  # Case 2: simulated_cpu_devices_count=1 (Monolith)
  #   sharding: None
  #   storage:  chunk_shape=(151936, 1024) <-- Full layer in one chunk

@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt2 branch 2 times, most recently from 3744bc0 to 6bae3c4 on January 28, 2026 at 01:03
Collaborator Author

shuningjin commented Jan 28, 2026

Do we still need to maintain standalone scripts? I thought most models have been migrated to the bi-directional conversion utility.

In most cases, we don't need to maintain standalone scripts. In this case, however:

  • I am reusing save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt.
  • I am adding logging metrics to MaxText.utils.ckpt_scripts.convert_gpt_oss_ckpt.py, as I want to subsequently compare the script/tool performance.

@shuningjin shuningjin force-pushed the shuningjin-ckpt-opt2 branch from 04c3e3d to 4e68630 on January 28, 2026 at 23:53
@copybara-service copybara-service bot merged commit 5b481be into main Jan 29, 2026
24 of 25 checks passed
@copybara-service copybara-service bot deleted the shuningjin-ckpt-opt2 branch January 29, 2026 01:07