2828 --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM
2929 usage during conversion. Recommended if 2 * model_size (GB) >= system RAM
3030 Defaults to False.
31- --hf_model_path: (Optional) Specify a local HF path, rather than the default repo `HF_IDS[model_name]`.
32- Useful for locally dequantized HF model like GPT-OSS or DeepSeek.
31+ --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
32+ If unspecified, we use the default Hugging Face repository ID
33+ (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
34+ This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
3335
3436Environment Variables:
3537 HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
6567from functools import partial
6668from typing import Sequence , List , Any , Callable
6769import numpy as np
68- import jax
69- import psutil
70- from flax .training import train_state
71- import flax .linen as nn
70+ import absl
71+
7272from transformers import AutoConfig
73- from tqdm import tqdm
7473from huggingface_hub import hf_hub_download , list_repo_files
7574from safetensors import safe_open
76- import absl
77-
75+ import jax
76+ import flax . linen as nn
7877from orbax .checkpoint import type_handlers
79- from MaxText import checkpointing
8078from MaxText import max_logging
8179from MaxText import max_utils
8280from MaxText import maxtext_utils
8381from MaxText import pyconfig
8482from MaxText .common_types import MODEL_MODE_TRAIN
8583from MaxText .inference_utils import str2bool
8684from MaxText .layers import models , quantizations
87- from MaxText .checkpointing import save_checkpoint
85+ from MaxText .utils . ckpt_scripts . llama_or_mistral_ckpt import save_weights_to_checkpoint
8886from MaxText .utils .ckpt_conversion .utils .param_mapping import HOOK_FNS , PARAM_MAPPING
89- from MaxText .utils .ckpt_conversion .utils .utils import apply_hook_fns , HF_IDS , print_ram_usage , get_hf_model , validate_and_filter_param_map_keys
87+ from MaxText .utils .ckpt_conversion .utils .utils import (
88+ apply_hook_fns ,
89+ HF_IDS ,
90+ get_hf_model ,
91+ validate_and_filter_param_map_keys ,
92+ MemoryMonitorTqdm ,
93+ print_ram_usage ,
94+ print_peak_memory ,
95+ )
9096
91- jax .config .update ("jax_platform_name" , "cpu" )
9297
9398absl .logging .set_verbosity (absl .logging .INFO ) # for max_logging.log
9499
95100
96- class MemoryMonitorTqdm (tqdm ):
97- """Custom tqdm class that displays memory usage in the progress bar."""
98-
99- def format_meter (
100- self ,
101- n ,
102- total ,
103- elapsed ,
104- postfix = None ,
105- ** extra_kwargs ,
106- ):
107- """Override to add memory usage info to the postfix."""
108- # Get memory info
109- memory = psutil .virtual_memory ()
110- used_gb = memory .used / (1024 ** 3 )
111- total_gb = memory .total / (1024 ** 3 )
112- memory_percent = memory .percent
113-
114- # Create memory postfix
115- memory_info = f"RAM: { used_gb :.1f} /{ total_gb :.1f} GB ({ memory_percent :.1f} %)"
116-
117- # Add memory info to postfix
118- if postfix :
119- if isinstance (postfix , dict ):
120- postfix ["memory" ] = memory_info
121- else :
122- postfix = f"{ postfix } , { memory_info } "
123- else :
124- postfix = memory_info
125-
126- return super ().format_meter (n = n , total = total , elapsed = elapsed , postfix = postfix , ** extra_kwargs )
101+ # Used to convert numpy weights to sharded jax arrays across simulated cpu devices
102+ SIMULATED_CPU_DEVICES_COUNT = 16
103+ os .environ ["JAX_PLATFORMS" ] = "cpu"
104+ os .environ ["XLA_FLAGS" ] = f"--xla_force_host_platform_device_count={ SIMULATED_CPU_DEVICES_COUNT } "
127105
128106
129107class LazyHFLoader :
@@ -657,20 +635,12 @@ def _eager_getter(key):
657635 hook_fn_map_mt = HOOK_FNS [model_key ](hf_config_obj .to_dict (), config , config .scan_layers , saving_to_hf = False )
658636 max_logging .log ("Parameter mappings and hooks obtained." )
659637
660- checkpoint_manager = checkpointing .create_orbax_checkpoint_manager (
661- output_directory ,
662- enable_checkpointing = True ,
663- use_async = False , # Synchronous saving for simplicity in conversion script
664- save_interval_steps = 1 , # Save at step 0
665- use_ocdbt = config .checkpoint_storage_use_ocdbt ,
666- use_zarr3 = config .checkpoint_storage_use_zarr3 ,
667- )
668-
669638 maxtext_abstract_dict , abstract_params_treedef = get_maxtext_model_info (config )
670639
671640 # Weight transformation
672641 max_logging .log ("Starting weight transformation..." )
673642 start = time .time ()
643+ # Stores MaxText weights: numpy.ndarray
674644 final_mt_weights = [None ] * len (maxtext_abstract_dict )
675645
676646 # Preprocess key
@@ -679,7 +649,7 @@ def _eager_getter(key):
679649 for mt_param_key_or_keys in MemoryMonitorTqdm (
680650 filtered_map_keys , desc = "Transforming weights" , unit = "param" , leave = True , dynamic_ncols = True
681651 ):
682- if not use_lazy_load and config . scan_layers :
652+ if not use_lazy_load :
683653 max_logging .log (f"maxtext param: { mt_param_key_or_keys } " )
684654
685655 hf_source_keys_or_key = param_map_mt_to_hf .get (mt_param_key_or_keys )
@@ -718,31 +688,24 @@ def _eager_getter(key):
718688 jax_weights = jax .tree_util .tree_unflatten (abstract_params_treedef , final_mt_weights )
719689 del final_mt_weights , abstract_params_treedef
720690
721- # Create TrainState for saving.
722- final_params_for_state = {"params" : jax_weights }
723- final_save_state = train_state .TrainState (step = 0 , apply_fn = None , params = final_params_for_state , tx = None , opt_state = {})
724- del final_params_for_state
725-
726691 print_ram_usage ("Before saving" )
727- start = time .time ()
728- if checkpoint_manager is not None :
729- if use_lazy_load :
730- max_logging .log ("Starting checkpoint save (loading weights just-in-time)..." )
731- else :
732- max_logging .log ("Starting checkpoint save..." )
733-
734- if save_checkpoint (checkpoint_manager , 0 , final_save_state ):
735- max_logging .log ("saved a checkpoint at step 0" )
692+ if use_lazy_load :
693+ max_logging .log ("Starting checkpoint save (loading weights just-in-time)..." )
694+ else :
695+ max_logging .log ("Starting checkpoint save..." )
736696
737- # Upon preemption, exit when and only when all ongoing saves are complete.
738- if checkpoint_manager .reached_preemption (0 ):
739- checkpoint_manager .wait_until_finished ()
740- sys .exit ()
697+ save_weights_to_checkpoint (
698+ output_directory ,
699+ jax_weights ,
700+ SIMULATED_CPU_DEVICES_COUNT ,
701+ config .checkpoint_storage_use_ocdbt ,
702+ config .checkpoint_storage_use_zarr3 ,
703+ )
741704
742705 print_ram_usage ("Program Ends" )
743706 max_logging .log (f"Conversion complete. Checkpoint saved to { output_directory } " )
744- max_logging .log (f"Elapse for save: { (time .time () - start ) / 60 :.2f} min" )
745707 max_logging .log (f"Overall Elapse: { (time .time () - overall_start ) / 60 :.2f} min" )
708+ print_peak_memory ()
746709
747710
748711if __name__ == "__main__" :
@@ -757,7 +720,7 @@ def _eager_getter(key):
757720 default = False ,
758721 help = "Whether to use lazy loading of HF tensors." ,
759722 )
760- # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
723+ # If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
761724 parser .add_argument (
762725 "--hf_model_path" , type = str , required = False , default = "" , help = "local path to hf model, or custom remote hf repo"
763726 )
0 commit comments