
Commit 6bae3c4

checkpoint utility: shard checkpoint, monitor peak
1 parent e8cbb57 commit 6bae3c4

6 files changed

Lines changed: 223 additions & 117 deletions


src/MaxText/utils/ckpt_conversion/README.md

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml
 * `hf_access_token`: Your Hugging Face token.
 * `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
 * `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage.
-* `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+* `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.


 ## MaxText to Hugging Face
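For illustration, here is a standalone sketch of how `to_maxtext.py` (updated below) separates these script-local flags from the MaxText config arguments. The lambda is a stand-in for MaxText's `str2bool` helper; everything else mirrors the parsing added by this commit.

```python
import argparse
import sys

parser = argparse.ArgumentParser()
# Script-local flags; the first two are documented in the README above,
# --simulated_cpu_devices_count is added by this commit (see to_maxtext.py below).
parser.add_argument("--lazy_load_tensors", type=lambda s: s.lower() == "true", default=False)
parser.add_argument("--hf_model_path", type=str, default="")
parser.add_argument("--simulated_cpu_devices_count", type=int, default=16)

# parse_known_args() returns the namespace AND the unrecognized arguments,
# so MaxText's own config overrides pass through untouched.
local_args, remaining_args = parser.parse_known_args()
model_args = [sys.argv[0]] + remaining_args  # script name + the args MaxText needs

print(local_args)
print(model_args)
```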

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 7 additions & 5 deletions
@@ -55,7 +55,6 @@
 import os
 from typing import Sequence
 import time
-from tqdm import tqdm

 from transformers import AutoTokenizer, AutoProcessor

@@ -77,11 +76,10 @@
     load_orbax_checkpoint,
     detect_and_extract_checkpoint,
     HF_IDS,
+    MemoryMonitorTqdm,
+    print_peak_memory,
 )

-os.environ["JAX_PLATFORMS"] = "cpu"
-os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
-

 def _get_model_mappings(
     model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
@@ -125,6 +123,9 @@ def main(argv: Sequence[str]) -> None:
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

+  jax.config.update("jax_platforms", "cpu")
+  os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"
+
   # Initialize maxtext config
   config = pyconfig.initialize(argv)
   assert (
@@ -179,7 +180,7 @@ def main(argv: Sequence[str]) -> None:
   start = time.time()
   processed_params_list = []

-  for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
+  for key in MemoryMonitorTqdm(filtered_map_keys, total=len(filtered_map_keys), leave=True):
     if isinstance(key, tuple):
       # if key is tuple of param names, weight is list of param weights
       weight = [maxtext_state_dict[subkey] for subkey in key]
@@ -210,6 +211,7 @@ def main(argv: Sequence[str]) -> None:
   max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}")
   max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
   max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
+  print_peak_memory()


 if __name__ == "__main__":
src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 62 additions & 78 deletions
@@ -28,6 +28,10 @@
     lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
       usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM
       Defaults to False.
+    --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
+      If unspecified, we use the default Hugging Face repository ID
+      (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
+      This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

 Environment Variables:
     HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
@@ -63,66 +67,38 @@
 from functools import partial
 from typing import Sequence, List, Any, Callable
 import numpy as np
-import jax
-import psutil
-from flax.training import train_state
-import flax.linen as nn
+import absl
+
 from transformers import AutoConfig
-from tqdm import tqdm
 from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors import safe_open
-import absl
-
+import jax
+import flax.linen as nn
 from orbax.checkpoint import type_handlers
+
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_TRAIN
 from MaxText.inference_utils import str2bool
 from MaxText.layers import models, quantizations
+from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
-from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys
-from maxtext.common import checkpointing
+from MaxText.utils.ckpt_conversion.utils.utils import (
+    apply_hook_fns,
+    HF_IDS,
+    get_hf_model,
+    validate_and_filter_param_map_keys,
+    MemoryMonitorTqdm,
+    print_ram_usage,
+    print_peak_memory,
+)

-jax.config.update("jax_platform_name", "cpu")

 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log


-class MemoryMonitorTqdm(tqdm):
-  """Custom tqdm class that displays memory usage in the progress bar."""
-
-  def format_meter(
-      self,
-      n,
-      total,
-      elapsed,
-      postfix=None,
-      **extra_kwargs,
-  ):
-    """Override to add memory usage info to the postfix."""
-    # Get memory info
-    memory = psutil.virtual_memory()
-    used_gb = memory.used / (1024**3)
-    total_gb = memory.total / (1024**3)
-    memory_percent = memory.percent
-
-    # Create memory postfix
-    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
-
-    # Add memory info to postfix
-    if postfix:
-      if isinstance(postfix, dict):
-        postfix["memory"] = memory_info
-      else:
-        postfix = f"{postfix}, {memory_info}"
-    else:
-      postfix = memory_info
-
-    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
-
-
 class LazyHFLoader:
   """
   Loads Hugging Face weights on-demand to minimize RAM usage.
@@ -654,20 +630,12 @@ def _eager_getter(key):
   hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
   max_logging.log("Parameter mappings and hooks obtained.")

-  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
-      output_directory,
-      enable_checkpointing=True,
-      use_async=False,  # Synchronous saving for simplicity in conversion script
-      save_interval_steps=1,  # Save at step 0
-      use_ocdbt=config.checkpoint_storage_use_ocdbt,
-      use_zarr3=config.checkpoint_storage_use_zarr3,
-  )
-
   maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config)

   # Weight transformation
   max_logging.log("Starting weight transformation...")
   start = time.time()
+  # Stores MaxText weights: numpy.ndarray
   final_mt_weights = [None] * len(maxtext_abstract_dict)

   # Preprocess key
@@ -676,7 +644,7 @@ def _eager_getter(key):
   for mt_param_key_or_keys in MemoryMonitorTqdm(
       filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
   ):
-    if not use_lazy_load and config.scan_layers:
+    if not use_lazy_load:
       max_logging.log(f"maxtext param: {mt_param_key_or_keys}")

     hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
@@ -715,37 +683,34 @@ def _eager_getter(key):
   jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights)
   del final_mt_weights, abstract_params_treedef

-  # Create TrainState for saving.
-  final_params_for_state = {"params": jax_weights}
-  final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={})
-  del final_params_for_state
-
   print_ram_usage("Before saving")
-  start = time.time()
-  if checkpoint_manager is not None:
-    if use_lazy_load:
-      max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
-    else:
-      max_logging.log("Starting checkpoint save...")
-
-    if checkpointing.save_checkpoint(checkpoint_manager, 0, final_save_state):
-      max_logging.log("saved a checkpoint at step 0")
+  if use_lazy_load:
+    max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
+  else:
+    max_logging.log("Starting checkpoint save...")

-    # Upon preemption, exit when and only when all ongoing saves are complete.
-    if checkpoint_manager.reached_preemption(0):
-      checkpoint_manager.wait_until_finished()
-      sys.exit()
+  # Save the converted weights to a MaxText checkpoint.
+  # If simulated_cpu_devices_count > 1, weights are promoted from NumPy to JAX arrays
+  # and sharded across virtual devices.
+  save_weights_to_checkpoint(
+      output_directory,
+      jax_weights,
+      test_args.simulated_cpu_devices_count,
+      config.checkpoint_storage_use_ocdbt,
+      config.checkpoint_storage_use_zarr3,
+  )

   print_ram_usage("Program Ends")
   max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}")
-  max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
   max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
+  print_peak_memory()


 if __name__ == "__main__":
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Suppress TensorFlow logging

+  # Define local parser
   parser = argparse.ArgumentParser()
   parser.add_argument(
       "--lazy_load_tensors",
@@ -754,13 +719,32 @@ def _eager_getter(key):
       default=False,
       help="Whether to use lazy loading of HF tensors.",
   )
-  # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
+  # If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
   parser.add_argument(
       "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
   )
-  local_args, _ = parser.parse_known_args()
-  model_args = sys.argv
-  to_remove_args = ["--lazy_load_tensors", "--hf_model_path"]
-  for a in to_remove_args:
-    model_args = [s for s in model_args if not s.startswith(a)]
+  # Determines the logical sharding of the output checkpoint by partitioning
+  # weights across virtual XLA devices.
+  # - Even on a single CPU host, JAX can simulate multiple devices (e.g., 16).
+  # - If set to 1, sharding is skipped.
+  # - Sharding is preferred: for downstream loading on TPU pods, it helps prevent OOM and speeds up loading.
+  #
+  # Example: Embedding Layer shape=(151936, 1024)
+  #   Case 1: simulated_cpu_devices_count=16 (Sharded)
+  #     sharding: NamedShardingMetadata(shape=[16], ...)
+  #     storage: chunk_shape=(9496, 1024)  <-- 1/16th of rows per chunk
+  #   Case 2: simulated_cpu_devices_count=1 (Monolith)
+  #     sharding: None
+  #     storage: chunk_shape=(151936, 1024)  <-- Full layer in one chunk
+  parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
+
+  # Parse local arguments.
+  # parse_known_args() returns the namespace AND the list of remaining arguments.
+  local_args, remaining_args = parser.parse_known_args()
+  # Reconstruct model_args (script name + the args MaxText needs).
+  model_args = [sys.argv[0]] + remaining_args
+
+  # Set JAX environment.
+  jax.config.update("jax_platforms", "cpu")
+  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
   main(model_args, local_args)
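To make the sharding comment above concrete, here is a minimal standalone sketch (not code from this repository, and independent of `save_weights_to_checkpoint`) of how 16 simulated CPU devices split the example embedding table row-wise:

```python
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_platforms", "cpu")

# Embedding-sized array from the comment above: (151936, 1024).
weights = np.zeros((151936, 1024), dtype=np.float16)

# 1-D mesh over the simulated devices; partition the row axis across it.
mesh = Mesh(np.array(jax.devices()), axis_names=("rows",))
sharded = jax.device_put(weights, NamedSharding(mesh, PartitionSpec("rows", None)))

# Each device holds a (9496, 1024) slice -- the chunk size a sharded
# checkpoint would store, versus one (151936, 1024) chunk when count=1.
print(sharded.addressable_shards[0].data.shape)
```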

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 45 additions & 2 deletions
@@ -22,6 +22,8 @@
 import json
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
+from tqdm import tqdm
+import resource

 import jax
 from jax.experimental import multihost_utils
@@ -753,6 +755,45 @@ def print_ram_usage(stage=""):
   )


+def print_peak_memory():
+  # Returns peak usage in Kilobytes on Linux
+  peak_memory_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
+  max_logging.log(f"Peak Memory: {peak_memory_kb / 1024**2:.2f} GB")
+
+
+class MemoryMonitorTqdm(tqdm):
+  """Custom tqdm class that displays memory usage in the progress bar."""
+
+  def format_meter(
+      self,
+      n,
+      total,
+      elapsed,
+      postfix=None,
+      **extra_kwargs,
+  ):
+    """Override to add memory usage info to the postfix."""
+    # Get memory info
+    memory = psutil.virtual_memory()
+    used_gb = memory.used / (1024**3)
+    total_gb = memory.total / (1024**3)
+    memory_percent = memory.percent
+
+    # Create memory postfix
+    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
+
+    # Add memory info to postfix
+    if postfix:
+      if isinstance(postfix, dict):
+        postfix["memory"] = memory_info
+      else:
+        postfix = f"{postfix}, {memory_info}"
+    else:
+      postfix = memory_info
+
+    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
+
+
 def load_orbax_checkpoint(config) -> dict:
   """Loads a full Orbax checkpoint from disk with unsharded arrays.

@@ -898,7 +939,9 @@ def get_hf_model(model_id: str, token: str):
   if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]:
     from transformers import Qwen3OmniMoeForConditionalGeneration  # pylint: disable=import-outside-toplevel

-    hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(model_id, token=token)
+    model_class = Qwen3OmniMoeForConditionalGeneration
   else:
-    hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
+    model_class = AutoModelForCausalLM
+
+  hf_model = model_class.from_pretrained(model_id, token=token)
   return hf_model