Skip to content

Commit 30418aa

Browse files
committed
checkpoint utility: shard checkpoint, monitor peak
1 parent d4a259d commit 30418aa

6 files changed

Lines changed: 190 additions & 113 deletions

File tree

src/MaxText/utils/ckpt_conversion/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml
5050
* `hf_access_token`: Your Hugging Face token.
5151
* `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
5252
* `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage.
53-
* `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
53+
* `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/2f77e7b5fcc4b580bc2d109525c362f3d9056ec9/src/MaxText/utils/ckpt_conversion/utils/utils.py#L54-L82) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
5454

5555

5656
## MaxText to Hugging Face

src/MaxText/utils/ckpt_conversion/to_huggingface.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@
5555
import os
5656
from typing import Sequence
5757
import time
58-
from tqdm import tqdm
5958

6059
from transformers import AutoTokenizer, AutoProcessor
6160

@@ -77,6 +76,8 @@
7776
load_orbax_checkpoint,
7877
detect_and_extract_checkpoint,
7978
HF_IDS,
79+
MemoryMonitorTqdm,
80+
print_peak_memory,
8081
)
8182

8283
os.environ["JAX_PLATFORMS"] = "cpu"
@@ -179,7 +180,7 @@ def main(argv: Sequence[str]) -> None:
179180
start = time.time()
180181
processed_params_list = []
181182

182-
for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
183+
for key in MemoryMonitorTqdm(filtered_map_keys, total=len(filtered_map_keys), leave=True):
183184
if isinstance(key, tuple):
184185
# if key is tuple of param names, weight is list of param weights
185186
weight = [maxtext_state_dict[subkey] for subkey in key]
@@ -210,6 +211,7 @@ def main(argv: Sequence[str]) -> None:
210211
max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}")
211212
max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
212213
max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
214+
print_peak_memory()
213215

214216

215217
if __name__ == "__main__":

src/MaxText/utils/ckpt_conversion/to_maxtext.py

Lines changed: 41 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@
2828
--lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM
2929
usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM
3030
Defaults to False.
31-
--hf_model_path: (Optional) Specify a local HF path, rather than the default repo `HF_IDS[model_name]`.
32-
Useful for locally dequantized HF model like GPT-OSS or DeepSeek.
31+
--hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
32+
If unspecified, we use the default Hugging Face repository ID
33+
(e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
34+
This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
3335
3436
Environment Variables:
3537
HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
@@ -65,67 +67,38 @@
6567
from functools import partial
6668
from typing import Sequence, List, Any, Callable
6769
import numpy as np
68-
import jax
69-
import psutil
70-
from flax.training import train_state
71-
import flax.linen as nn
70+
import absl
71+
7272
from transformers import AutoConfig
73-
from tqdm import tqdm
7473
from huggingface_hub import hf_hub_download, list_repo_files
7574
from safetensors import safe_open
76-
import absl
77-
75+
import jax
76+
import flax.linen as nn
7877
from orbax.checkpoint import type_handlers
79-
from MaxText import checkpointing
78+
8079
from MaxText import max_logging
8180
from MaxText import max_utils
8281
from MaxText import maxtext_utils
8382
from MaxText import pyconfig
8483
from MaxText.common_types import MODEL_MODE_TRAIN
8584
from MaxText.inference_utils import str2bool
8685
from MaxText.layers import models, quantizations
87-
from MaxText.checkpointing import save_checkpoint
86+
from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
8887
from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
89-
from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys
88+
from MaxText.utils.ckpt_conversion.utils.utils import (
89+
apply_hook_fns,
90+
HF_IDS,
91+
get_hf_model,
92+
validate_and_filter_param_map_keys,
93+
MemoryMonitorTqdm,
94+
print_ram_usage,
95+
print_peak_memory,
96+
)
9097

91-
jax.config.update("jax_platform_name", "cpu")
9298

9399
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log
94100

95101

96-
class MemoryMonitorTqdm(tqdm):
97-
"""Custom tqdm class that displays memory usage in the progress bar."""
98-
99-
def format_meter(
100-
self,
101-
n,
102-
total,
103-
elapsed,
104-
postfix=None,
105-
**extra_kwargs,
106-
):
107-
"""Override to add memory usage info to the postfix."""
108-
# Get memory info
109-
memory = psutil.virtual_memory()
110-
used_gb = memory.used / (1024**3)
111-
total_gb = memory.total / (1024**3)
112-
memory_percent = memory.percent
113-
114-
# Create memory postfix
115-
memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
116-
117-
# Add memory info to postfix
118-
if postfix:
119-
if isinstance(postfix, dict):
120-
postfix["memory"] = memory_info
121-
else:
122-
postfix = f"{postfix}, {memory_info}"
123-
else:
124-
postfix = memory_info
125-
126-
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
127-
128-
129102
class LazyHFLoader:
130103
"""
131104
Loads Hugging Face weights on-demand to minimize RAM usage.
@@ -657,20 +630,12 @@ def _eager_getter(key):
657630
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
658631
max_logging.log("Parameter mappings and hooks obtained.")
659632

660-
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
661-
output_directory,
662-
enable_checkpointing=True,
663-
use_async=False, # Synchronous saving for simplicity in conversion script
664-
save_interval_steps=1, # Save at step 0
665-
use_ocdbt=config.checkpoint_storage_use_ocdbt,
666-
use_zarr3=config.checkpoint_storage_use_zarr3,
667-
)
668-
669633
maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config)
670634

671635
# Weight transformation
672636
max_logging.log("Starting weight transformation...")
673637
start = time.time()
638+
# Stores MaxText weights: numpy.ndarray
674639
final_mt_weights = [None] * len(maxtext_abstract_dict)
675640

676641
# Preprocess key
@@ -679,7 +644,7 @@ def _eager_getter(key):
679644
for mt_param_key_or_keys in MemoryMonitorTqdm(
680645
filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
681646
):
682-
if not use_lazy_load and config.scan_layers:
647+
if not use_lazy_load:
683648
max_logging.log(f"maxtext param: {mt_param_key_or_keys}")
684649

685650
hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
@@ -718,31 +683,24 @@ def _eager_getter(key):
718683
jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights)
719684
del final_mt_weights, abstract_params_treedef
720685

721-
# Create TrainState for saving.
722-
final_params_for_state = {"params": jax_weights}
723-
final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={})
724-
del final_params_for_state
725-
726686
print_ram_usage("Before saving")
727-
start = time.time()
728-
if checkpoint_manager is not None:
729-
if use_lazy_load:
730-
max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
731-
else:
732-
max_logging.log("Starting checkpoint save...")
733-
734-
if save_checkpoint(checkpoint_manager, 0, final_save_state):
735-
max_logging.log("saved a checkpoint at step 0")
687+
if use_lazy_load:
688+
max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
689+
else:
690+
max_logging.log("Starting checkpoint save...")
736691

737-
# Upon preemption, exit when and only when all ongoing saves are complete.
738-
if checkpoint_manager.reached_preemption(0):
739-
checkpoint_manager.wait_until_finished()
740-
sys.exit()
692+
save_weights_to_checkpoint(
693+
output_directory,
694+
jax_weights,
695+
test_args.simulated_cpu_devices_count,
696+
config.checkpoint_storage_use_ocdbt,
697+
config.checkpoint_storage_use_zarr3,
698+
)
741699

742700
print_ram_usage("Program Ends")
743701
max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}")
744-
max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
745702
max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
703+
print_peak_memory()
746704

747705

748706
if __name__ == "__main__":
@@ -757,13 +715,19 @@ def _eager_getter(key):
757715
default=False,
758716
help="Whether to use lazy loading of HF tensors.",
759717
)
760-
# if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
718+
# If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
761719
parser.add_argument(
762720
"--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
763721
)
722+
# Used to convert numpy weights to sharded jax arrays across simulated cpu devices
723+
# If count=1, do not shard
724+
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
764725
local_args, _ = parser.parse_known_args()
765726
model_args = sys.argv
766-
to_remove_args = ["--lazy_load_tensors", "--hf_model_path"]
727+
to_remove_args = ["--lazy_load_tensors", "--hf_model_path", "--simulated_cpu_devices_count"]
767728
for a in to_remove_args:
768729
model_args = [s for s in model_args if not s.startswith(a)]
730+
731+
os.environ["JAX_PLATFORMS"] = "cpu"
732+
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
769733
main(model_args, local_args)

src/MaxText/utils/ckpt_conversion/utils/utils.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import json
2323
from concurrent.futures import ThreadPoolExecutor
2424
from typing import Any
25+
from tqdm import tqdm
26+
import resource
2527

2628
import jax
2729
from jax.experimental import multihost_utils
@@ -753,6 +755,45 @@ def print_ram_usage(stage=""):
753755
)
754756

755757

758+
def print_peak_memory():
759+
# Returns peak usage in Kilobytes on Linux
760+
peak_memory_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
761+
max_logging.log(f"Peak Memory: {peak_memory_kb / 1024**2:.2f} GB")
762+
763+
764+
class MemoryMonitorTqdm(tqdm):
765+
"""Custom tqdm class that displays memory usage in the progress bar."""
766+
767+
def format_meter(
768+
self,
769+
n,
770+
total,
771+
elapsed,
772+
postfix=None,
773+
**extra_kwargs,
774+
):
775+
"""Override to add memory usage info to the postfix."""
776+
# Get memory info
777+
memory = psutil.virtual_memory()
778+
used_gb = memory.used / (1024**3)
779+
total_gb = memory.total / (1024**3)
780+
memory_percent = memory.percent
781+
782+
# Create memory postfix
783+
memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
784+
785+
# Add memory info to postfix
786+
if postfix:
787+
if isinstance(postfix, dict):
788+
postfix["memory"] = memory_info
789+
else:
790+
postfix = f"{postfix}, {memory_info}"
791+
else:
792+
postfix = memory_info
793+
794+
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
795+
796+
756797
def load_orbax_checkpoint(config) -> dict:
757798
"""Loads a full Orbax checkpoint from disk with unsharded arrays.
758799
@@ -898,7 +939,9 @@ def get_hf_model(model_id: str, token: str):
898939
if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]:
899940
from transformers import Qwen3OmniMoeForConditionalGeneration # pylint: disable=import-outside-toplevel
900941

901-
hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(model_id, token=token)
942+
model_class = Qwen3OmniMoeForConditionalGeneration.from_pretrained
902943
else:
903-
hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
944+
model_class = AutoModelForCausalLM
945+
946+
hf_model = model_class.from_pretrained(model_id, token=token)
904947
return hf_model

src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import os
2828
import pathlib
2929
import absl
30+
import time
3031

3132
os.environ["JAX_PLATFORMS"] = "cpu"
3233

@@ -35,6 +36,7 @@
3536
import psutil
3637
from safetensors import safe_open
3738
from tqdm import tqdm
39+
from MaxText.utils.ckpt_conversion.utils.utils import MemoryMonitorTqdm, print_peak_memory
3840

3941
from MaxText import max_logging
4042
from MaxText.inference_utils import str2bool
@@ -77,7 +79,7 @@ def _convert_huggingface_to_jax_weights(
7779
max_logging.log(f"Loading the base model from {base_model_path}")
7880
ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
7981
chkpt_vars = {}
80-
for i, ckpt_path in enumerate(ckpt_paths):
82+
for i, ckpt_path in tqdm(enumerate(ckpt_paths), total=len(ckpt_paths)):
8183
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
8284

8385
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
@@ -141,9 +143,9 @@ def _convert_huggingface_to_jax_weights(
141143

142144
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
143145

144-
# self attention ###############################################
146+
# layer weight: self attention ###############################################
145147
max_logging.log("Processing self attention")
146-
for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
148+
for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True):
147149
block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval)
148150
stack_shape = (base_num_decoder_layers // layer_cycle_interval,)
149151
self_attention = jax_weights["decoder"]["layers"][f"layers_{block_idx}"]["GptOssAttention"]
@@ -212,9 +214,9 @@ def _convert_huggingface_to_jax_weights(
212214

213215
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
214216

215-
# layer weight pre and post self attention norm ################
217+
# layer weight: pre and post self attention norm ################
216218
max_logging.log("Processing pre and post self attention norms")
217-
for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
219+
for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True):
218220
block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval)
219221
stack_shape = (base_num_decoder_layers // layer_cycle_interval,)
220222
layer_weight = jax_weights["decoder"]["layers"][f"layers_{block_idx}"]
@@ -246,10 +248,10 @@ def _convert_huggingface_to_jax_weights(
246248

247249
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
248250

249-
# layer weights ################################################
250-
max_logging.log("Processing layer weights")
251+
# layer weight: mlp ################################################
252+
max_logging.log("Processing mlp weights")
251253

252-
for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False):
254+
for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True):
253255
block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval)
254256
stack_shape = (base_num_decoder_layers // layer_cycle_interval,)
255257
mlp_weight = jax_weights["decoder"]["layers"][f"layers_{block_idx}"]["GptOssMlp"]
@@ -337,17 +339,27 @@ def convert_to_jax_weights(base_model_path: str, model_size: str):
337339
parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True)
338340
args = parser.parse_args()
339341

342+
overall_start = time.time()
343+
340344
if args.model_size not in MODEL_PARAMS_DICT:
341345
raise NotImplementedError(f"Model '{args.model_size}' is not supported.")
342346

343347
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}"
344348
base_weights_path = args.maxtext_model_path
345349

350+
# transform
351+
start = time.time()
352+
weights = convert_to_jax_weights(args.base_model_path, args.model_size)
353+
max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min")
354+
355+
# save
346356
save_weights_to_checkpoint(
347357
args.maxtext_model_path,
348-
convert_to_jax_weights(args.base_model_path, args.model_size),
358+
weights,
349359
args.simulated_cpu_devices_count,
350360
args.use_ocdbt,
351361
args.use_zarr3,
352362
)
353363
max_logging.log(f"Successfully saved base_weights to {base_weights_path}.")
364+
max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
365+
print_peak_memory()

0 commit comments

Comments
 (0)