@@ -28,8 +28,10 @@
   --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM
       usage during conversion. Recommended if 2 * model_size (GB) >= system RAM.
       Defaults to False.
-  --hf_model_path: (Optional) Specify a local HF path, rather than the default repo `HF_IDS[model_name]`.
-      Useful for locally dequantized HF model like GPT-OSS or DeepSeek.
+  --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
+      If unspecified, we use the default Hugging Face repository ID
+      (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
+      This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

 Environment Variables:
   HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
@@ -65,67 +67,38 @@
 from functools import partial
 from typing import Sequence, List, Any, Callable
 import numpy as np
-import jax
-import psutil
-from flax.training import train_state
-import flax.linen as nn
+import absl
+
 from transformers import AutoConfig
-from tqdm import tqdm
 from huggingface_hub import hf_hub_download, list_repo_files
 from safetensors import safe_open
-import absl
-
+import jax
+import flax.linen as nn
 from orbax.checkpoint import type_handlers
-from MaxText import checkpointing
+
 from MaxText import max_logging
 from MaxText import max_utils
 from MaxText import maxtext_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_TRAIN
 from MaxText.inference_utils import str2bool
 from MaxText.layers import models, quantizations
-from MaxText.checkpointing import save_checkpoint
+from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
 from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
-from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys
+from MaxText.utils.ckpt_conversion.utils.utils import (
+    apply_hook_fns,
+    HF_IDS,
+    get_hf_model,
+    validate_and_filter_param_map_keys,
+    MemoryMonitorTqdm,
+    print_ram_usage,
+    print_peak_memory,
+)

-jax.config.update("jax_platform_name", "cpu")

 absl.logging.set_verbosity(absl.logging.INFO)  # for max_logging.log


-class MemoryMonitorTqdm(tqdm):
-  """Custom tqdm class that displays memory usage in the progress bar."""
-
-  def format_meter(
-      self,
-      n,
-      total,
-      elapsed,
-      postfix=None,
-      **extra_kwargs,
-  ):
-    """Override to add memory usage info to the postfix."""
-    # Get memory info
-    memory = psutil.virtual_memory()
-    used_gb = memory.used / (1024**3)
-    total_gb = memory.total / (1024**3)
-    memory_percent = memory.percent
-
-    # Create memory postfix
-    memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f} GB ({memory_percent:.1f}%)"
-
-    # Add memory info to postfix
-    if postfix:
-      if isinstance(postfix, dict):
-        postfix["memory"] = memory_info
-      else:
-        postfix = f"{postfix}, {memory_info}"
-    else:
-      postfix = memory_info
-
-    return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
-
-
 class LazyHFLoader:
   """
   Loads Hugging Face weights on-demand to minimize RAM usage.
@@ -657,20 +630,12 @@ def _eager_getter(key):
   hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
   max_logging.log("Parameter mappings and hooks obtained.")

-  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
-      output_directory,
-      enable_checkpointing=True,
-      use_async=False,  # Synchronous saving for simplicity in conversion script
-      save_interval_steps=1,  # Save at step 0
-      use_ocdbt=config.checkpoint_storage_use_ocdbt,
-      use_zarr3=config.checkpoint_storage_use_zarr3,
-  )
-
   maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config)

   # Weight transformation
   max_logging.log("Starting weight transformation...")
   start = time.time()
+  # Stores MaxText weights: numpy.ndarray
   final_mt_weights = [None] * len(maxtext_abstract_dict)

   # Preprocess key
@@ -679,7 +644,7 @@ def _eager_getter(key):
   for mt_param_key_or_keys in MemoryMonitorTqdm(
       filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
   ):
-    if not use_lazy_load and config.scan_layers:
+    if not use_lazy_load:
       max_logging.log(f"maxtext param: {mt_param_key_or_keys}")

     hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
@@ -718,31 +683,24 @@ def _eager_getter(key):
   jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights)
   del final_mt_weights, abstract_params_treedef

-  # Create TrainState for saving.
-  final_params_for_state = {"params": jax_weights}
-  final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={})
-  del final_params_for_state
-
   print_ram_usage("Before saving")
-  start = time.time()
-  if checkpoint_manager is not None:
-    if use_lazy_load:
-      max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
-    else:
-      max_logging.log("Starting checkpoint save...")
-
-    if save_checkpoint(checkpoint_manager, 0, final_save_state):
-      max_logging.log("saved a checkpoint at step 0")
+  if use_lazy_load:
+    max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
+  else:
+    max_logging.log("Starting checkpoint save...")

-    # Upon preemption, exit when and only when all ongoing saves are complete.
-    if checkpoint_manager.reached_preemption(0):
-      checkpoint_manager.wait_until_finished()
-      sys.exit()
+  save_weights_to_checkpoint(
+      output_directory,
+      jax_weights,
+      test_args.simulated_cpu_devices_count,
+      config.checkpoint_storage_use_ocdbt,
+      config.checkpoint_storage_use_zarr3,
+  )

   print_ram_usage("Program Ends")
   max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}")
-  max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
   max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
+  print_peak_memory()


 if __name__ == "__main__":
@@ -757,13 +715,19 @@ def _eager_getter(key):
       default=False,
       help="Whether to use lazy loading of HF tensors.",
   )
-  # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
+  # If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
   parser.add_argument(
       "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
   )
+  # Used to convert numpy weights to sharded jax arrays across simulated cpu devices
+  # If count=1, do not shard
+  parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
   local_args, _ = parser.parse_known_args()
   model_args = sys.argv
-  to_remove_args = ["--lazy_load_tensors", "--hf_model_path"]
+  to_remove_args = ["--lazy_load_tensors", "--hf_model_path", "--simulated_cpu_devices_count"]
   for a in to_remove_args:
     model_args = [s for s in model_args if not s.startswith(a)]
+
+  os.environ["JAX_PLATFORMS"] = "cpu"
+  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
   main(model_args, local_args)
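For readers unfamiliar with the simulated-CPU-device mechanism this diff relies on, here is a minimal standalone sketch (not part of the commit; the 4-device count, array shape, and variable names are illustrative) of how forcing XLA's host platform device count lets a single CPU host shard arrays, which is what `--simulated_cpu_devices_count` and the `XLA_FLAGS` setting enable before `save_weights_to_checkpoint` runs:

# XLA reads these variables when JAX initializes, which is why the commit sets them
# in __main__ before calling main(). Here we pretend one CPU host has 4 devices.
import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

print(jax.devices())  # reports 4 CPU devices, despite running on a single host CPU

# Shard a host-resident numpy array across the simulated devices along its first axis,
# mirroring how converted numpy weights can become sharded jax.Arrays before saving.
mesh = Mesh(np.array(jax.devices()), axis_names=("devices",))
weights = np.arange(32, dtype=np.float32).reshape(8, 4)
sharded = jax.device_put(weights, NamedSharding(mesh, P("devices", None)))
print(sharded.sharding)  # rows spread over the 4 simulated devices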