     lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
         usage during conversion. Recommended if 2 * model_size (GB) >= system RAM.
         Defaults to False.
+    --hf_model_path: (Optional) A local or remote directory containing the model weights.
+        If unspecified, the default Hugging Face repository ID is used
+        (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
+        This is necessary for locally dequantized models such as GPT-OSS or DeepSeek.
 
 Environment Variables:
     HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
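The lazy_load recommendation above is a rule of thumb; a minimal standalone sketch (assuming psutil is available, with a hypothetical helper name) of how one might apply it when deciding whether to pass the flag:

import psutil

def should_lazy_load(model_size_gb: float) -> bool:
  # Enable lazy loading when 2 * model_size would not fit in system RAM.
  total_ram_gb = psutil.virtual_memory().total / (1024**3)
  return 2 * model_size_gb >= total_ram_gb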
 from functools import partial
 from typing import Sequence, List, Any, Callable
 import numpy as np
-import jax
-import psutil
-from flax.training import train_state
-import flax.linen as nn
+import absl
+
 from transformers import AutoConfig
-from tqdm import tqdm
 from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors import safe_open
-import absl
-
+import jax
+import flax.linen as nn
 from orbax.checkpoint import type_handlers
+
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_TRAIN
 from MaxText.inference_utils import str2bool
 from MaxText.layers import models, quantizations
+from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
-from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys
-from maxtext.common import checkpointing
+from MaxText.utils.ckpt_conversion.utils.utils import (
+    apply_hook_fns,
+    HF_IDS,
+    get_hf_model,
+    validate_and_filter_param_map_keys,
+    MemoryMonitorTqdm,
+    print_ram_usage,
+    print_peak_memory,
+)
 
-jax.config.update("jax_platform_name", "cpu")
 
 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log
 
 
-class MemoryMonitorTqdm(tqdm):
-  """Custom tqdm class that displays memory usage in the progress bar."""
-
-  def format_meter(
-      self,
-      n,
-      total,
-      elapsed,
-      postfix=None,
-      **extra_kwargs,
-  ):
-    """Override to add memory usage info to the postfix."""
-    # Get memory info
-    memory = psutil.virtual_memory()
-    used_gb = memory.used / (1024**3)
-    total_gb = memory.total / (1024**3)
-    memory_percent = memory.percent
-
-    # Create memory postfix
-    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"
-
-    # Add memory info to postfix
-    if postfix:
-      if isinstance(postfix, dict):
-        postfix["memory"] = memory_info
-      else:
-        postfix = f"{postfix}, {memory_info}"
-    else:
-      postfix = memory_info
-
-    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
-
-
 class LazyHFLoader:
   """
   Loads Hugging Face weights on-demand to minimize RAM usage.
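For context, the on-demand pattern this class builds on can be sketched with safetensors' safe_open, which memory-maps a shard so that only the requested tensor is materialized (shard and tensor names below are hypothetical):

from safetensors import safe_open

def lazy_get_tensor(shard_path: str, tensor_name: str):
  # The shard is memory-mapped; get_tensor reads only the named tensor into RAM.
  with safe_open(shard_path, framework="np") as f:
    return f.get_tensor(tensor_name)

# e.g., lazy_get_tensor("model-00001-of-00002.safetensors", "model.embed_tokens.weight")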
@@ -654,20 +630,12 @@ def _eager_getter(key):
   hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
   max_logging.log("Parameter mappings and hooks obtained.")
 
-  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
-      output_directory,
-      enable_checkpointing=True,
-      use_async=False,  # Synchronous saving for simplicity in conversion script
-      save_interval_steps=1,  # Save at step 0
-      use_ocdbt=config.checkpoint_storage_use_ocdbt,
-      use_zarr3=config.checkpoint_storage_use_zarr3,
-  )
-
   maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config)
 
   # Weight transformation
   max_logging.log("Starting weight transformation...")
   start = time.time()
+  # Stores the transformed MaxText weights (as numpy.ndarray values).
   final_mt_weights = [None] * len(maxtext_abstract_dict)
 
   # Preprocess key
@@ -676,7 +644,7 @@ def _eager_getter(key):
   for mt_param_key_or_keys in MemoryMonitorTqdm(
       filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
   ):
-    if not use_lazy_load and config.scan_layers:
+    if not use_lazy_load:
       max_logging.log(f"maxtext param: {mt_param_key_or_keys}")
 
     hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
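To illustrate the mapping step in this loop, a toy version (hypothetical keys and hook function; the real tables are PARAM_MAPPING and HOOK_FNS): each MaxText parameter key resolves to one or more HF tensor names, and a hook reshapes the raw tensor(s) into MaxText layout:

import numpy as np

param_map_mt_to_hf = {"decoder-mlp-wi": ["model.layers.0.mlp.gate_proj.weight"]}
hook_fns = {"decoder-mlp-wi": lambda w: w.T}  # e.g., transpose into MaxText layout

mt_key = "decoder-mlp-wi"
hf_tensor = np.zeros((8, 4), dtype=np.float32)  # stand-in for a loaded HF tensor
mt_weight = hook_fns[mt_key](hf_tensor)         # shape (4, 8) after the hook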
@@ -715,37 +683,34 @@ def _eager_getter(key):
   jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights)
   del final_mt_weights, abstract_params_treedef
 
-  # Create TrainState for saving.
-  final_params_for_state = {"params": jax_weights}
-  final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={})
-  del final_params_for_state
-
   print_ram_usage("Before saving")
-  start = time.time()
-  if checkpoint_manager is not None:
-    if use_lazy_load:
-      max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
-    else:
-      max_logging.log("Starting checkpoint save...")
-
-    if checkpointing.save_checkpoint(checkpoint_manager, 0, final_save_state):
-      max_logging.log("saved a checkpoint at step 0")
+  if use_lazy_load:
+    max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
+  else:
+    max_logging.log("Starting checkpoint save...")
 
-    # Upon preemption, exit when and only when all ongoing saves are complete.
-    if checkpoint_manager.reached_preemption(0):
-      checkpoint_manager.wait_until_finished()
-      sys.exit()
+  # Save the converted weights to a MaxText checkpoint.
+  # If simulated_cpu_devices_count > 1, weights are promoted from NumPy to JAX arrays
+  # and sharded across the virtual devices.
+  save_weights_to_checkpoint(
+      output_directory,
+      jax_weights,
+      test_args.simulated_cpu_devices_count,
+      config.checkpoint_storage_use_ocdbt,
+      config.checkpoint_storage_use_zarr3,
+  )
 
   print_ram_usage("Program Ends")
   max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}")
-  max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
   max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
+  print_peak_memory()
 
 
 if __name__ == "__main__":
   jax.config.update("jax_default_prng_impl", "unsafe_rbg")
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"  # Suppress TensorFlow logging
 
+  # Define the local argument parser for conversion-specific flags.
   parser = argparse.ArgumentParser()
   parser.add_argument(
       "--lazy_load_tensors",
@@ -754,13 +719,32 @@ def _eager_getter(key):
       default=False,
       help="Whether to use lazy loading of HF tensors.",
   )
-  # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
+  # If not specified, defaults to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name].
   parser.add_argument(
       "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
   )
-  local_args, _ = parser.parse_known_args()
-  model_args = sys.argv
-  to_remove_args = ["--lazy_load_tensors", "--hf_model_path"]
-  for a in to_remove_args:
-    model_args = [s for s in model_args if not s.startswith(a)]
+  # Determines the logical sharding of the output checkpoint by partitioning
+  # weights across virtual XLA devices.
+  # - Even on a single CPU host, JAX can simulate multiple devices (e.g., 16).
+  # - If set to 1, sharding is skipped.
+  # - Sharding is preferred: when the checkpoint is later loaded on TPU pods, it
+  #   helps prevent OOM and speeds up loading.
+  #
+  # Example: embedding layer of shape (151936, 1024)
+  #   Case 1: simulated_cpu_devices_count=16 (sharded)
+  #     sharding: NamedShardingMetadata(shape=[16], ...)
+  #     storage: chunk_shape=(9496, 1024)   <-- 1/16th of the rows per chunk
+  #   Case 2: simulated_cpu_devices_count=1 (monolithic)
+  #     sharding: None
+  #     storage: chunk_shape=(151936, 1024) <-- the full layer in one chunk
+  parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
+
+  # Parse local arguments.
+  # parse_known_args returns the namespace AND the list of remaining arguments.
+  local_args, remaining_args = parser.parse_known_args()
+  # Reconstruct model_args (script name + the args MaxText needs).
+  model_args = [sys.argv[0]] + remaining_args
+
+  # Set up the JAX environment: run on CPU and simulate multiple devices.
+  jax.config.update("jax_platforms", "cpu")
+  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
   main(model_args, local_args)
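For reference, a standalone sketch (device count and shapes are assumptions, not values from this script) of what the simulated-device setup above enables: with --xla_force_host_platform_device_count, JAX exposes multiple virtual CPU devices, and an array can then be sharded across them along its first axis:

import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # must precede JAX backend init

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("rows",))
spec = NamedSharding(mesh, PartitionSpec("rows", None))  # split axis 0 across devices
arr = jax.device_put(np.zeros((16, 8), dtype=np.float32), spec)
print(len(arr.addressable_shards))  # 4 shards, each of shape (4, 8)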