@@ -90,7 +90,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
- `checkpoint_storage_use_zarr3`: # Set to True to use zarr3 format (recommended for McJAX); set to False for Pathways.
- `checkpoint_storage_use_ocdbt`: # Set to True to use OCDBT format (recommended for McJAX); set to False for Pathways.
- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. For large models, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (optional): Specifies a local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/0d909c44391539db4e8cc2a33de9d77a891beb31/src/MaxText/utils/ckpt_conversion/utils/utils.py#L58-L85) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

The above command downloads the Hugging Face model to the local machine, converts it to the MaxText format, and saves it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
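
For readers who want a concrete picture of what "on-demand" loading means here, the following is a minimal editorial sketch of the idea (it is not the converter's actual `LazyHFLoader`); it assumes a local directory of `*.safetensors` shards and simply materializes one tensor at a time:

```python
# Illustrative only: stream tensors from safetensors shards one at a time,
# so peak RAM stays close to the size of the largest single tensor.
import glob
from safetensors import safe_open

def iter_weights(checkpoint_dir: str):
  """Yield (name, numpy array) pairs without loading a whole shard into RAM."""
  for shard in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
    with safe_open(shard, framework="np") as f:  # memory-mapped, not fully read
      for name in f.keys():
        yield name, f.get_tensor(name)  # only this tensor is materialized

# Usage sketch: transform each tensor and drop it before fetching the next.
# for name, weight in iter_weights("/path/to/local/hf_model"):
#   converted = weight  # stand-in for the real MaxText parameter hook
```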

12 changes: 7 additions & 5 deletions src/MaxText/utils/ckpt_conversion/to_huggingface.py
@@ -55,7 +55,6 @@
import os
from typing import Sequence
import time
from tqdm import tqdm

from transformers import AutoTokenizer, AutoProcessor

@@ -77,11 +76,10 @@
load_orbax_checkpoint,
detect_and_extract_checkpoint,
HF_IDS,
MemoryMonitorTqdm,
print_peak_memory,
)

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"


def _get_model_mappings(
model_name: str, scan_layers: bool, hf_config_dict: dict, maxtext_config: pyconfig.HyperParameters
@@ -125,6 +123,9 @@ def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

jax.config.update("jax_platforms", "cpu")
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

# Initialize maxtext config
config = pyconfig.initialize(argv)
assert (
@@ -179,7 +180,7 @@ def main(argv: Sequence[str]) -> None:
start = time.time()
processed_params_list = []

for key in tqdm(filtered_map_keys, total=len(filtered_map_keys)):
for key in MemoryMonitorTqdm(filtered_map_keys, total=len(filtered_map_keys), leave=True):
if isinstance(key, tuple):
# if key is tuple of param names, weight is list of param weights
weight = [maxtext_state_dict[subkey] for subkey in key]
@@ -210,6 +211,7 @@ def main(argv: Sequence[str]) -> None:
max_logging.log(f"✅ MaxText model successfully saved in HuggingFace format at {output_directory}")
max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
print_peak_memory()


if __name__ == "__main__":
140 changes: 62 additions & 78 deletions src/MaxText/utils/ckpt_conversion/to_maxtext.py
@@ -28,6 +28,10 @@
lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM
usage during conversion. Recommended if 2 * model_size (GB) >= system RAM.
Defaults to False.
--hf_model_path: (Optional) Specifies a local or remote directory containing the model weights.
If unspecified, we use the default Hugging Face repository ID
(e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `utils/ckpt_conversion/utils`).
This is necessary for locally dequantized models like GPT-OSS or DeepSeek.

Environment Variables:
HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to
@@ -63,66 +67,38 @@
from functools import partial
from typing import Sequence, List, Any, Callable
import numpy as np
import jax
import psutil
from flax.training import train_state
import flax.linen as nn
import absl

from transformers import AutoConfig
from tqdm import tqdm
from huggingface_hub import hf_hub_download, list_repo_files
from safetensors import safe_open
import absl

import jax
import flax.linen as nn
from orbax.checkpoint import type_handlers

from MaxText import max_logging
from MaxText import max_utils
from MaxText import maxtext_utils
from MaxText import pyconfig
from MaxText.common_types import MODEL_MODE_TRAIN
from MaxText.layers import models, quantizations
from MaxText.utils.ckpt_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint
from MaxText.utils.ckpt_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from MaxText.utils.ckpt_conversion.utils.utils import apply_hook_fns, HF_IDS, print_ram_usage, get_hf_model, validate_and_filter_param_map_keys
from MaxText.utils.ckpt_conversion.utils.utils import (
apply_hook_fns,
HF_IDS,
get_hf_model,
validate_and_filter_param_map_keys,
MemoryMonitorTqdm,
print_ram_usage,
print_peak_memory,
)
from maxtext.inference.inference_utils import str2bool
from maxtext.common import checkpointing

jax.config.update("jax_platform_name", "cpu")

absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log


class MemoryMonitorTqdm(tqdm):
"""Custom tqdm class that displays memory usage in the progress bar."""

def format_meter(
self,
n,
total,
elapsed,
postfix=None,
**extra_kwargs,
):
"""Override to add memory usage info to the postfix."""
# Get memory info
memory = psutil.virtual_memory()
used_gb = memory.used / (1024**3)
total_gb = memory.total / (1024**3)
memory_percent = memory.percent

# Create memory postfix
memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"

# Add memory info to postfix
if postfix:
if isinstance(postfix, dict):
postfix["memory"] = memory_info
else:
postfix = f"{postfix}, {memory_info}"
else:
postfix = memory_info

return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)


class LazyHFLoader:
"""
Loads Hugging Face weights on-demand to minimize RAM usage.
@@ -654,20 +630,12 @@ def _eager_getter(key):
hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False)
max_logging.log("Parameter mappings and hooks obtained.")

checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
output_directory,
enable_checkpointing=True,
use_async=False, # Synchronous saving for simplicity in conversion script
save_interval_steps=1, # Save at step 0
use_ocdbt=config.checkpoint_storage_use_ocdbt,
use_zarr3=config.checkpoint_storage_use_zarr3,
)

maxtext_abstract_dict, abstract_params_treedef = get_maxtext_model_info(config)

# Weight transformation
max_logging.log("Starting weight transformation...")
start = time.time()
# Stores MaxText weights: numpy.ndarray
final_mt_weights = [None] * len(maxtext_abstract_dict)

# Preprocess key
@@ -676,7 +644,7 @@ def _eager_getter(key):
for mt_param_key_or_keys in MemoryMonitorTqdm(
filtered_map_keys, desc="Transforming weights", unit="param", leave=True, dynamic_ncols=True
):
if not use_lazy_load and config.scan_layers:
if not use_lazy_load:
max_logging.log(f"maxtext param: {mt_param_key_or_keys}")

hf_source_keys_or_key = param_map_mt_to_hf.get(mt_param_key_or_keys)
@@ -715,37 +683,34 @@ def _eager_getter(key):
jax_weights = jax.tree_util.tree_unflatten(abstract_params_treedef, final_mt_weights)
del final_mt_weights, abstract_params_treedef

# Create TrainState for saving.
final_params_for_state = {"params": jax_weights}
final_save_state = train_state.TrainState(step=0, apply_fn=None, params=final_params_for_state, tx=None, opt_state={})
del final_params_for_state

print_ram_usage("Before saving")
start = time.time()
if checkpoint_manager is not None:
if use_lazy_load:
max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
else:
max_logging.log("Starting checkpoint save...")

if checkpointing.save_checkpoint(checkpoint_manager, 0, final_save_state):
max_logging.log("saved a checkpoint at step 0")
if use_lazy_load:
max_logging.log("Starting checkpoint save (loading weights just-in-time)...")
else:
max_logging.log("Starting checkpoint save...")

# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(0):
checkpoint_manager.wait_until_finished()
sys.exit()
# Save the converted weights to a MaxText checkpoint.
# If simulated_cpu_devices_count > 1, weights are promoted from NumPy to JAX arrays
# and sharded across virtual devices.
save_weights_to_checkpoint(
output_directory,
jax_weights,
test_args.simulated_cpu_devices_count,
config.checkpoint_storage_use_ocdbt,
config.checkpoint_storage_use_zarr3,
)

print_ram_usage("Program Ends")
max_logging.log(f"Conversion complete. Checkpoint saved to {output_directory}")
max_logging.log(f"Elapse for save: {(time.time() - start) / 60:.2f} min")
max_logging.log(f"Overall Elapse: {(time.time() - overall_start) / 60:.2f} min")
print_peak_memory()


if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Suppress TensorFlow logging

# Define local parser
parser = argparse.ArgumentParser()
parser.add_argument(
"--lazy_load_tensors",
@@ -754,13 +719,32 @@ def _eager_getter(key):
default=False,
help="Whether to use lazy loading of HF tensors.",
)
# if not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
# If not specified, default to MaxText.utils.ckpt_conversion.utils.utils.HF_IDS[model_name]
parser.add_argument(
"--hf_model_path", type=str, required=False, default="", help="local path to hf model, or custom remote hf repo"
)
local_args, _ = parser.parse_known_args()
model_args = sys.argv
to_remove_args = ["--lazy_load_tensors", "--hf_model_path"]
for a in to_remove_args:
model_args = [s for s in model_args if not s.startswith(a)]
# Determines the logical sharding of the output checkpoint by partitioning
# weights across virtual XLA devices.
# - Even on a single CPU host, JAX can simulate multiple devices (e.g., 16)
# - If set to 1, sharding is skipped.
# - Sharding is preferred. For downstream loading on TPU pods, it helps prevent OOM and speeds up loading.
#
# Example: Embedding Layer shape=(151936, 1024)
# Case 1: simulated_cpu_devices_count=16 (Sharded)
# sharding: NamedShardingMetadata(shape=[16], ...)
# storage: chunk_shape=(9496, 1024) <-- 1/16th of rows per chunk
# Case 2: simulated_cpu_devices_count=1 (Monolith)
# sharding: None
# storage: chunk_shape=(151936, 1024) <-- Full layer in one chunk
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
Collaborator:

Could you add some explanation about this flag, and why default is set to 16 here? If we run a conversion on a single-cpu device but set this number higher than 1, what will happen? Thanks for the clarification!

shuningjin (Collaborator, Author), Jan 28, 2026:

Thanks for the review! I have added a detailed comment to explain this flag. To answer your questions:

> why default is set to 16

  • Most of the previous conversion scripts use shard=16. See: llama & mixtral, deepseek & kimi, as well as gpt-oss, qwen3-moe.
  • Since this has been used and tested for a long time, I think it is good to follow.
  • In particular, all MoEs and the largest models (deepseek 671b, kimi 1T) have been converted this way.

> If we run a conversion on a single-cpu device but set this number higher than 1, what will happen

By setting these flags, JAX can simulate multiple devices, even on a single CPU host.

  jax.config.update("jax_platforms", "cpu")
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

In particular, len(jax.devices()) will be equal to this count, and the simulated devices are used to create the mesh that the weights are sharded over.

  # Example: Embedding Layer shape=(151936, 1024)
  # Case 1: simulated_cpu_devices_count=16 (Sharded)
  #   sharding: NamedShardingMetadata(shape=[16], ...)
  #   storage:  chunk_shape=(9496, 1024)  <-- 1/16th of rows per chunk
  # Case 2: simulated_cpu_devices_count=1 (Monolith)
  #   sharding: None
  #   storage:  chunk_shape=(151936, 1024) <-- Full layer in one chunk


# Parse local arguments
# Parse known args returns the namespace AND the list of remaining arguments
local_args, remaining_args = parser.parse_known_args()
# Reconstruct model_args (script name + the args MaxText needs)
model_args = [sys.argv[0]] + remaining_args

# Set jax environment
jax.config.update("jax_platforms", "cpu")
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
main(model_args, local_args)
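
To make the virtual-device behavior discussed in the comment block and review thread above concrete, here is a short editorial sketch (not part of the diff); the array shape matches the embedding example, while the mesh axis name and the use of `jax.device_put` are illustrative assumptions:

```python
# Simulate 16 CPU devices on one host and shard an embedding-sized array row-wise.
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"  # set before JAX initializes its backend

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_platforms", "cpu")

print(len(jax.devices()))  # -> 16 simulated CPU devices

mesh = Mesh(np.asarray(jax.devices()), axis_names=("shard",))
sharding = NamedSharding(mesh, PartitionSpec("shard", None))
x = jax.device_put(np.zeros((151936, 1024), dtype=np.float32), sharding)
print(x.addressable_shards[0].data.shape)  # -> (9496, 1024), i.e. 1/16th of the rows per device
```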
47 changes: 45 additions & 2 deletions src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -22,6 +22,8 @@
import json
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from tqdm import tqdm
import resource

import jax
from jax.experimental import multihost_utils
@@ -753,6 +755,45 @@ def print_ram_usage(stage=""):
)


def print_peak_memory():
# Returns peak usage in Kilobytes on Linux
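# Note: macOS reports ru_maxrss in bytes rather than kilobytes, so this conversion is Linux-specific.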
peak_memory_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
max_logging.log(f"Peak Memory: {peak_memory_kb / 1024**2:.2f} GB")


class MemoryMonitorTqdm(tqdm):
"""Custom tqdm class that displays memory usage in the progress bar."""

def format_meter(
self,
n,
total,
elapsed,
postfix=None,
**extra_kwargs,
):
"""Override to add memory usage info to the postfix."""
# Get memory info
memory = psutil.virtual_memory()
used_gb = memory.used / (1024**3)
total_gb = memory.total / (1024**3)
memory_percent = memory.percent

# Create memory postfix
memory_info = f"RAM: {used_gb:.1f}/{total_gb:.1f}GB ({memory_percent:.1f}%)"

# Add memory info to postfix
if postfix:
if isinstance(postfix, dict):
postfix["memory"] = memory_info
else:
postfix = f"{postfix}, {memory_info}"
else:
postfix = memory_info

return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)


def load_orbax_checkpoint(config) -> dict:
"""Loads a full Orbax checkpoint from disk with unsharded arrays.

@@ -898,7 +939,9 @@ def get_hf_model(model_id: str, token: str):
if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]:
from transformers import Qwen3OmniMoeForConditionalGeneration # pylint: disable=import-outside-toplevel

hf_model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(model_id, token=token)
model_class = Qwen3OmniMoeForConditionalGeneration
else:
hf_model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
model_class = AutoModelForCausalLM

hf_model = model_class.from_pretrained(model_id, token=token)
return hf_model