From 4e68630dafadc44b9df634244291263fe94117a9 Mon Sep 17 00:00:00 2001 From: Shuning Jin Date: Wed, 28 Jan 2026 23:52:58 +0000 Subject: [PATCH] checkpoint utility: shard checkpoint, monitor peak memory --- .../convert_checkpoint.md | 2 +- .../utils/ckpt_conversion/to_huggingface.py | 12 +- .../utils/ckpt_conversion/to_maxtext.py | 140 ++++++++---------- .../utils/ckpt_conversion/utils/utils.py | 47 +++++- .../ckpt_scripts/convert_gpt_oss_ckpt.py | 30 ++-- .../ckpt_scripts/llama_or_mistral_ckpt.py | 109 +++++++++++--- 6 files changed, 223 insertions(+), 117 deletions(-) diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index 6d5a588312..b37d2923c8 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -90,7 +90,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \ - `checkpoint_storage_use_zarr3`: # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways. - `checkpoint_storage_use_ocdbt`: # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways. - `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. For large models, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. -- `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. +- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/0d909c44391539db4e8cc2a33de9d77a891beb31/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. The above command downloads the Hugging Face model to the local machine, converts it to the MaxText format, and saves it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
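For concreteness, a sketch of a full `to_maxtext` invocation combining the documented flags with the `--simulated_cpu_devices_count` flag added later in this patch; the model name, output path, and config keys shown here are illustrative placeholders, not values taken from this change:

```bash
python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
  model_name=llama3.1-70b \
  base_output_directory=${MODEL_CHECKPOINT_DIRECTORY} \
  checkpoint_storage_use_zarr3=true \
  checkpoint_storage_use_ocdbt=true \
  --lazy_load_tensors=true \
  --hf_model_path=/local/path/to/hf/model \
  --simulated_cpu_devices_count=16
```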
diff --git a/src/MaxText/utils/ckpt_conversion/to_huggingface.py b/src/MaxText/utils/ckpt_conversion/to_huggingface.py index 3aa51f2d99..622a9f02fb 100644 --- a/src/MaxText/utils/ckpt_conversion/to_huggingface.py +++ b/src/MaxText/utils/ckpt_conversion/to_huggingface.py @@ -55,7 +55,6 @@ import os from typing import Sequence import time -from tqdm import tqdm from transformers import AutoTokenizer, AutoProcessor @@ -77,11 +76,10 @@ load_orbax_checkpoint, detect_and_extract_checkpoint, HF_IDS, + MemoryMonitorTqdm, + print_peak_memory, ) -os.environ["JAX_PLATFORMS"] = "cpu" -os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16" - def _get_model_mappings( model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters @@ -125,6 +123,9 @@ def main(argv: Sequence[str]) -> None: jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16" + # Initialize maxtext config config = pyconfig.initialize(argv) assert ( @@ -179,7 +180,7 @@ def main(argv: Sequence[str]) -> None: start = time.time() processed_params_list = [] - for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)): + for key in MemoryMonitorTqdm(filtered_map_keys, total=len(filtered_map_keys), leave=True): if isinstance(key, tuple): # if key is tuple of param names, weight is list of param weights weight = [maxtext_state_dict[subkey] for subkey in key] @@ -210,6 +211,7 @@ def main(argv: Sequence[str]) -> None: max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}") max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min") max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") + print_peak_memory() if __name__ == "__main__": diff --git a/src/MaxText/utils/ckpt_conversion/to_maxtext.py b/src/MaxText/utils/ckpt_conversion/to_maxtext.py index f002a15bf9..160251e533 100644 --- a/src/MaxText/utils/ckpt_conversion/to_maxtext.py +++ b/src/MaxText/utils/ckpt_conversion/to_maxtext.py @@ -28,6 +28,10 @@ lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM usage during conversion. Recommended if 2 * model_size (GB) >= system RAM. Defaults to False. + --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights. + If unspecified, we use the default Hugging Face repository ID + (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`). + This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
Environment Variables: HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to @@ -63,66 +67,38 @@ from functools import partial from typing import Sequence, List, Any, Callable import numpy as np -import jax -import psutil -from flax.training import train_state -import flax.linen as nn +import absl + from transformers import AutoConfig -from tqdm import tqdm from huggingface_hub import hf_hub_download, list_repo_files from safetensors import safe_open -import absl - +import jax +import flax.linen as nn from orbax.checkpoint import type_handlers + from MaxText import max_logging from MaxText import max_utils from MaxText import maxtext_utils from MaxText import pyconfig from MaxText.common_types import MODEL_MODE_TRAIN from MaxText.layers import models, quantizations +from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys +from MaxText.utils.ckpt_conversion.utils.utils import ( + apply_hook_fns, + HF_IDS, + get_hf_model, + validate_and_filter_param_map_keys, + MemoryMonitorTqdm, + print_ram_usage, + print_peak_memory, +) from maxtext.inference.inference_utils import str2bool -from maxtext.common import checkpointing -jax.config.update("jax_platform_name", "cpu") absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log -class MemoryMonitorTqdm(tqdm): - """Custom tqdm class that displays memory usage in the progress bar.""" - - def format_meter( - self, - n, - total, - elapsed, - postfix=None, - **extra_kwargs, - ): - """Override to add memory usage info to the postfix.""" - # Get memory info - memory = psutil.virtual_memory() - used_gb = memory.used / (1024**3) - total_gb = memory.total / (1024**3) - memory_percent = memory.percent - - # Create memory postfix - memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)" - - # Add memory info to postfix - if postfix: - if isinstance(postfix, dict): - postfix["memory"] = memory_info - else: - postfix = f"{postfix}, {memory_info}" - else: - postfix = memory_info - - return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs) - - class LazyHFLoader: """ Loads Hugging Face weights on-demand to minimize RAM usage. 
@@ -654,20 +630,12 @@ def _eager_getter(key): hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False) max_logging.log("Parameter mappings and hooks obtained.") - checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( - output_directory, - enable_checkpointing=True, - use_async=False, # Synchronous saving for simplicity in conversion script - save_interval_steps=1, # Save at step 0 - use_ocdbt=config.checkpoint_storage_use_ocdbt, - use_zarr3=config.checkpoint_storage_use_zarr3, - ) - maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config) # Weight transformation max_logging.log("Starting weight transformation...") start = time.time() + # Stores MaxText weights: numpy.ndarray final_mt_weights = [None] * len(maxtext_abstract_dict) # Preprocess key @@ -676,7 +644,7 @@ def _eager_getter(key): for mt_param_key_or_keys in MemoryMonitorTqdm( filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True ): - if not use_lazy_load and config.scan_layers: + if not use_lazy_load: max_logging.log(f"maxtext param: {mt_param_key_or_keys}") hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys) @@ -715,37 +683,34 @@ def _eager_getter(key): jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights) del final_mt_weights, abstract_params_treedef - # Create TrainState for saving. - final_params_for_state = {"params": jax_weights} - final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={}) - del final_params_for_state - print_ram_usage("Before saving") - start = time.time() - if checkpoint_manager is not None: - if use_lazy_load: - max_logging.log("Starting checkpoint save (loading weights just-in-time)...") - else: - max_logging.log("Starting checkpoint save...") - - if checkpointing.save_checkpoint(checkpoint_manager, 0, final_save_state): - max_logging.log("saved a checkpoint at step 0") + if use_lazy_load: + max_logging.log("Starting checkpoint save (loading weights just-in-time)...") + else: + max_logging.log("Starting checkpoint save...") - # Upon preemption, exit when and only when all ongoing saves are complete. - if checkpoint_manager.reached_preemption(0): - checkpoint_manager.wait_until_finished() - sys.exit() + # Save the converted weights to a MaxText checkpoint. + # If simulated_cpu_devices_count > 1, weights are promoted from NumPy to JAX arrays + # and sharded across virtual devices. + save_weights_to_checkpoint( + output_directory, + jax_weights, + test_args.simulated_cpu_devices_count, + config.checkpoint_storage_use_ocdbt, + config.checkpoint_storage_use_zarr3, + ) print_ram_usage("Program Ends") max_logging.log(f"Conversion complete. 
Checkpoint saved to {output_directory}") - max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min") max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min") + print_peak_memory() if __name__ == "__main__": jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Suppress TensorFlow logging + # Define local parser parser = argparse.ArgumentParser() parser.add_argument( "--lazy_load_tensors", type=str2bool, required=False, default=False, help="Whether to use lazy loading of HF tensors.", ) - # if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name] + # If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name] parser.add_argument( "--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo" ) - local_args, _ = parser.parse_known_args() - model_args = sys.argv - to_remove_args = ["--lazy_load_tensors", "--hf_model_path"] - for a in to_remove_args: - model_args = [s for s in model_args if not s.startswith(a)] + # Determines the logical sharding of the output checkpoint by partitioning + # weights across virtual XLA devices. + # - Even on a single CPU host, JAX can simulate multiple devices (e.g., 16) + # - If set to 1, sharding is skipped. + # - Sharding is preferred: for downstream loading on TPU pods, it helps prevent OOM and speeds up loading. + # + # Example: Embedding Layer shape=(151936, 1024) + # Case 1: simulated_cpu_devices_count=16 (Sharded) + # sharding: NamedShardingMetadata(shape=[16], ...) + # storage: chunk_shape=(9496, 1024) <-- 1/16th of rows per chunk + # Case 2: simulated_cpu_devices_count=1 (Monolith) + # sharding: None + # storage: chunk_shape=(151936, 1024) <-- Full layer in one chunk + parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) + + # Parse local arguments + # parse_known_args returns the namespace and the list of remaining arguments + local_args, remaining_args = parser.parse_known_args() + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args + + # Set JAX environment + jax.config.update("jax_platforms", "cpu") + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" main(model_args, local_args) diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py index ca514f649e..8b4027e907 100644 --- a/src/MaxText/utils/ckpt_conversion/utils/utils.py +++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py @@ -22,6 +22,8 @@ import json from concurrent.futures import ThreadPoolExecutor from typing import Any +from tqdm import tqdm +import resource import jax from jax.experimental import multihost_utils @@ -753,6 +755,45 @@ def print_ram_usage(stage=""): ) + +def print_peak_memory(): + # ru_maxrss reports peak RSS in kilobytes on Linux (bytes on macOS) + peak_memory_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + max_logging.log(f"Peak Memory: {peak_memory_kb / 1024**2:.2f} GB") + + +class MemoryMonitorTqdm(tqdm): + """Custom tqdm class that displays memory usage in the progress bar.""" + + def format_meter( + self, + n, + total, + elapsed, + postfix=None, + **extra_kwargs, + ): + """Override to add memory usage info to the postfix.""" + # Get memory info + memory = psutil.virtual_memory() + used_gb = memory.used / (1024**3) + total_gb = memory.total / (1024**3) + memory_percent = 
memory.percent + + # Create memory postfix + memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)" + + # Add memory info to postfix + if postfix: + if isinstance(postfix, dict): + postfix["memory"] = memory_info + else: + postfix = f"{postfix}, {memory_info}" + else: + postfix = memory_info + + return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs) + + def load_orbax_checkpoint(config) -> dict: """Loads a full Orbax checkpoint from disk with unsharded arrays. @@ -898,7 +939,9 @@ def get_hf_model(model_id: str, token: str): if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]: from transformers import Qwen3OmniMoeForConditionalGeneration # pylint: disable=import-outside-toplevel - hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(model_id, token=token) + model_class = Qwen3OmniMoeForConditionalGeneration else: - hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=token) + model_class = AutoModelForCausalLM + + hf_model = model_class.from_pretrained(model_id, token=token) return hf_model diff --git a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py index d6a8ac3e3f..5c5088d4b9 100644 --- a/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/convert_gpt_oss_ckpt.py @@ -27,6 +27,7 @@ import os import pathlib import absl +import time os.environ["JAX_PLATFORMS"] = "cpu" @@ -39,6 +40,7 @@ from MaxText import max_logging from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint from MaxText.utils.ckpt_scripts.convert_gpt_oss_unscanned_ckpt import MODEL_PARAMS_DICT, _hf_to_maxtext_mapping, _pt_to_np +from MaxText.utils.ckpt_conversion.utils.utils import MemoryMonitorTqdm, print_peak_memory from maxtext.inference.inference_utils import str2bool absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log @@ -77,7 +79,7 @@ def _convert_huggingface_to_jax_weights( max_logging.log(f"Loading the base model from {base_model_path}") ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors")) chkpt_vars = {} - for i, ckpt_path in enumerate(ckpt_paths): + for i, ckpt_path in tqdm(enumerate(ckpt_paths), total=len(ckpt_paths)): max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") with safe_open(ckpt_path, framework="pt", device="cpu") as f: @@ -141,9 +143,9 @@ def _convert_huggingface_to_jax_weights( logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) - # self attention ############################################### + # layer weight: self attention ############################################### max_logging.log("Processing self attention") - for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False): + for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True): block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval) stack_shape = (base_num_decoder_layers // layer_cycle_interval,) self_attention = jax_weights["decoder"]["layers"][f"layers_{block_idx}"]["GptOssAttention"] @@ -212,9 +214,9 @@ def _convert_huggingface_to_jax_weights( logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) - # layer weight: pre and post self attention norm ################ + # layer weight: pre and post self attention norm ################ max_logging.log("Processing pre and post self attention norms") - for layer_idx in 
tqdm(range(base_num_decoder_layers), desc="layers", leave=False): + for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True): block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval) stack_shape = (base_num_decoder_layers // layer_cycle_interval,) layer_weight = jax_weights["decoder"]["layers"][f"layers_{block_idx}"] @@ -246,10 +248,10 @@ def _convert_huggingface_to_jax_weights( logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) - # layer weights ################################################ - max_logging.log("Processing layer weights") + # layer weight: mlp ################################################ + max_logging.log("Processing mlp weights") - for layer_idx in tqdm(range(base_num_decoder_layers), desc="layers", leave=False): + for layer_idx in MemoryMonitorTqdm(range(base_num_decoder_layers), desc="layers", leave=True): block_layer_idx, block_idx = divmod(layer_idx, layer_cycle_interval) stack_shape = (base_num_decoder_layers // layer_cycle_interval,) mlp_weight = jax_weights["decoder"]["layers"][f"layers_{block_idx}"]["GptOssMlp"] @@ -337,17 +339,27 @@ def convert_to_jax_weights(base_model_path: str, model_size: str): parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True) args = parser.parse_args() + overall_start = time.time() + if args.model_size not in MODEL_PARAMS_DICT: raise NotImplementedError(f"Model '{args.model_size}' is not supported.") os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}" base_weights_path = args.maxtext_model_path + # Transform weights + start = time.time() + weights = convert_to_jax_weights(args.base_model_path, args.model_size) + max_logging.log(f"Elapsed time for transform: {(time.time() - start) / 60:.2f} min") + + # Save checkpoint save_weights_to_checkpoint( args.maxtext_model_path, - convert_to_jax_weights(args.base_model_path, args.model_size), + weights, args.simulated_cpu_devices_count, args.use_ocdbt, args.use_zarr3, ) max_logging.log(f"Successfully saved base_weights to {base_weights_path}.") + max_logging.log(f"Overall elapsed time: {(time.time() - overall_start) / 60:.2f} min") + print_peak_memory() diff --git a/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py b/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py index cec9ce5663..b51d060c47 100644 --- a/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py +++ b/src/MaxText/utils/ckpt_scripts/llama_or_mistral_ckpt.py @@ -48,6 +48,7 @@ import psutil from tqdm import tqdm +import time import numpy as np @@ -1631,52 +1632,114 @@ def convert_to_jax_weights(base_model_path: str, model_size: str, huggingface_ck return _convert_pytorch_to_jax_weights(base_model_path, model_size, model_params, mem_info) -def save_weights_to_checkpoint( - maxtext_model_path: str, jax_weights: dict, device_count: int, use_ocdbt: bool, use_zarr3: bool -): - """ - Function to save jax_weights ready for MaxText to a parameters checkpoint. +def shard_checkpoint(jax_weights, device_count, mem_info): + """Shards the checkpoint weights across the simulated devices. Args: - maxtext_model_path: Path to save the MaxText checkpoint. - jax_weights: The JAX model weights to be saved. - device_count: The number of simulated devices. 
+ mem_info: Process object to track memory usage. + + Returns: + Pytree of sharded JAX arrays. """ - mem_info = psutil.Process() - logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) - gc.collect() + # Set up mesh & sharding specs + if len(jax.devices()) != device_count: + max_logging.log( + "WARNING: hardware/simulated device mismatch. " + f"Actual JAX devices: {len(jax.devices())}, Requested count: {device_count}." + ) + max_logging.log(f"Sharding weights across {len(jax.devices())} devices") + # Pre-define sharding specs mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") - s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) # shards first axis - s2 = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis") - ) # shards second axis - s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) # no sharding + # Sharding along axis 0 + s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis")) + # Sharding along axis 1 + s2 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis")) + # No sharding (replicated) + s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None)) def checkpoint_device_put(arr): + """Determines correct sharding spec based on shape and shards the input array. + + Args: + arr: A numpy array (or jax array). + + Returns: + A sharded jax array. + """ + if not isinstance(arr, (np.ndarray, jax.Array)): + # materialize lazy tensor + arr = np.array(arr) + if arr.shape[0] % device_count == 0: - max_logging.log("sharding first axis") + max_logging.log("sharding axis 0") return jax.device_put(arr, device=s1) elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0: - max_logging.log("sharding second axis") + max_logging.log("sharding axis 1") return jax.device_put(arr, device=s2) else: max_logging.log("no sharding was possible, replicating") return jax.device_put(arr, device=s3) + # Weight sharding + start = time.time() # convert all weights to jax.numpy with sharding if applicable jax_weights_flat, jax_weights_struct = tree.flatten(jax_weights) + del jax_weights + gc.collect() + jax_weights_new = [] + # Reverse so pop() takes weights from the original front in O(1) while freeing memory as we go + jax_weights_flat.reverse() + num_weights = len(jax_weights_flat) + for _ in tqdm(range(num_weights)): + jax_weight = jax_weights_flat.pop() jax_weights_new.append(checkpoint_device_put(jax_weight)) del jax_weight gc.collect() logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) jax_weights = tree.unflatten(jax_weights_struct, jax_weights_new) + max_logging.log(f"Elapsed time for checkpoint sharding: {(time.time() - start) / 60:.2f} min") + + return jax_weights + + +def save_weights_to_checkpoint( + maxtext_model_path: str, + jax_weights: dict, + device_count: int, + use_ocdbt: bool, + use_zarr3: bool, +): + """Saves model weights to a MaxText-compatible checkpoint with optional sharding. + + This function handles the conversion of NumPy weights into sharded JAX arrays + across a specified number of simulated devices. If the device count is 1, + the sharding and JAX conversion steps are skipped. + Args: + maxtext_model_path: The destination directory or URI for the MaxText checkpoint. + jax_weights: A dictionary mapping parameter names to weight arrays (typically NumPy). + device_count: The number of simulated devices to shard across. 
If 1, weights + are saved in their original format. + use_ocdbt: If True, enables the Optimized Checkpoint Database with Transactions + (OCDBT) format for improved metadata handling. + use_zarr3: If True, uses the Zarr3 storage format for the underlying array data. + """ + mem_info = psutil.Process() + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) + gc.collect() + + # Weight sharding + if device_count > 1: + jax_weights = shard_checkpoint(jax_weights, device_count, mem_info) + else: + # With a single simulated device, skip both sharding and the NumPy-to-JAX conversion. + max_logging.log("Single device: skipping sharding") + + # Save checkpoint + start = time.time() # dummy configs for the checkpoint_manager step_number_to_save_new_ckpt = 0 enable_checkpointing = True @@ -1703,6 +1766,8 @@ def checkpoint_device_put(arr): # Upon preemption, exit when and only when all ongoing saves are complete. checkpoint_manager.wait_until_finished() + max_logging.log(f"Elapsed time for checkpoint save: {(time.time() - start) / 60:.2f} min") + def list_folders_pathlib(directory: str): """Lists folders in a directory using pathlib module.
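For intuition, here is a minimal, self-contained sketch of the axis-selection rule implemented by `checkpoint_device_put` above (shard axis 0 first, then axis 1, else replicate), run on simulated CPU devices. The device count of 4 and the array shapes are illustrative assumptions, not values from this patch:

```python
import os

# Simulate 4 CPU devices before importing jax; same mechanism as
# --xla_force_host_platform_device_count in the conversion scripts.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

device_count = len(jax.devices())
mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
s1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))
s2 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis"))
s3 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))


def pick_sharding(arr):
  """Same precedence as checkpoint_device_put: axis 0, then axis 1, else replicate."""
  if arr.shape[0] % device_count == 0:
    return s1
  if arr.ndim > 1 and arr.shape[1] % device_count == 0:
    return s2
  return s3


for shape in [(151936, 1024), (7, 1024), (7, 5)]:
  w = np.zeros(shape, dtype=np.float32)
  sharded = jax.device_put(w, device=pick_sharding(w))
  # Each shard holds 1/device_count of the partitioned axis, or a full replica.
  print(shape, "->", sharded.sharding.spec, sharded.addressable_shards[0].data.shape)
```

Because axis 0 is checked first, a weight divisible along both axes (like the embedding example in the patch comment) is always sharded along its rows.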