Merged
14 changes: 9 additions & 5 deletions docs/guides/checkpointing_solutions/convert_checkpoint.md
@@ -16,7 +16,9 @@ The following models are supported:
| **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ |
| **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ |
| **GPT-OSS** | 20B, 120B | √ | √ | √ | √ |
| **DeepSeek3** | 671B | - | - | √ | - |
| **DeepSeek2** | 16B | √ | √ | √ | √ |
| **DeepSeek3** | 671B | √ | √ | √ | √ |
| **DeepSeek3.2** | 671B | √ | √ | - | - |
| **Qwen3 Next** | 80B | √ | √ | √ | √ |

## Prerequisites
@@ -60,7 +62,8 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
skip_jax_distributed_system=true \
checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
--lazy_load_tensors=${LAZY_LOAD_TENSORS?}
--lazy_load_tensors=${LAZY_LOAD_TENSORS?} \
--save_dtype=bfloat16
```

You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/items`.
@@ -74,7 +77,8 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i
- `hardware=cpu`: The conversion script runs on a CPU machine.
- `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False.
- `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
- `--hf_model_path` (Optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
- `--hf_model_path` (Optional): Specifies a custom remote or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., `openai/gpt-oss-20b`). This is required for locally dequantized models such as GPT-OSS or DeepSeek.
- `--save_dtype` (Optional): Specifies the data type of the saved model weights. Defaults to `bfloat16` to save memory.
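The effect of `--lazy_load_tensors` can be illustrated with a small sketch. This is not MaxText's actual loader; it uses NumPy's memory-mapped loading as a stand-in to show how on-demand reads keep peak RAM low, since only the accessed slices are paged in from disk.

```python
import os
import tempfile

import numpy as np


def eager_load(path):
  # Materializes the full array in RAM at once -- what lazy loading avoids.
  return np.load(path)


def lazy_load(path):
  # mmap_mode maps the file into the address space; pages are only read
  # from disk when the corresponding slices are actually accessed.
  return np.load(path, mmap_mode="r")


with tempfile.TemporaryDirectory() as d:
  path = os.path.join(d, "weights.npy")
  np.save(path, np.arange(1_000_000, dtype=np.float32))

  lazy = lazy_load(path)
  # Only this small slice is paged in; the other ~4 MB stay on disk.
  first_row = np.array(lazy[:4])
  print(first_row.tolist())  # [0.0, 1.0, 2.0, 3.0]
```

The real conversion path streams safetensors shards similarly, but the memory-vs-latency trade-off is the same: lazy access trades a little read latency for a much smaller resident set.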

## MaxText to Hugging Face

@@ -118,7 +122,7 @@ python3 -m maxtext.checkpoint_conversion.to_huggingface \
- `use_multimodal`: Indicates whether multimodality is used; important for Gemma3.
- `hardware=cpu`: The conversion script runs on a CPU machine.
- `base_output_directory`: The path where the converted checkpoint will be stored; it can be Google Cloud Storage (GCS), Hugging Face Hub or local.
- `weight_dtype`: dtype for MaxText weights. It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion.
- `weight_dtype`: Controls the dtype of the resulting Hugging Face weights. The default value is `float32`; we recommend `bfloat16` to save memory and speed up conversion.

## Verifying conversion correctness

@@ -226,7 +230,7 @@ To extend conversion support to a new model architecture, you must define its sp

- In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.

2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion.
2. **Add Hugging Face weights shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shapes of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This ensures the tensor shapes match after the to_huggingface conversion.

3. **Register model key**: In [`maxtext/utils/globals.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key to `HF_IDS`.
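The three registration steps above can be sketched together. All names below (`NEWMODEL_...`, the param path, the config keys) are hypothetical placeholders, not the actual MaxText APIs; the point is how the mapping, hook, and shape table fit together.

```python
import numpy as np

# Step 1: map each MaxText param path to a Hugging Face tensor name.
NEWMODEL_MAXTEXT_TO_HF_PARAM_MAPPING = {
    "params-decoder-layers_0-self_attention-query-kernel":
        "model.layers.0.self_attn.q_proj.weight",
}


# Step 1 (hook): the per-layer transformation, e.g. transposing from a
# (in_features, out_features) kernel layout to HF's (out, in) layout.
def NEWMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(name, tensor):
  return tensor.T


# Step 2: expected HF shapes, used to validate converted tensors.
def NEWMODEL_HF_WEIGHTS_TO_SHAPE(config):
  return {
      "model.layers.0.self_attn.q_proj.weight":
          (config["num_heads"] * config["head_dim"], config["emb_dim"]),
  }


# Minimal end-to-end check of the sketch.
config = {"num_heads": 4, "head_dim": 8, "emb_dim": 16}
mt_tensor = np.zeros((16, 32))  # (emb_dim, num_heads * head_dim)
hf_name = NEWMODEL_MAXTEXT_TO_HF_PARAM_MAPPING[
    "params-decoder-layers_0-self_attention-query-kernel"]
hf_tensor = NEWMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(hf_name, mt_tensor)
assert hf_tensor.shape == NEWMODEL_HF_WEIGHTS_TO_SHAPE(config)[hf_name]
```

Step 3 is then a one-line registration of the model key in `HF_IDS` so the converter can resolve the default Hugging Face repository.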

12 changes: 4 additions & 8 deletions src/maxtext/checkpoint_conversion/compare_hf_ckpt.py
@@ -48,7 +48,7 @@
from safetensors import safe_open

from maxtext.configs import pyconfig
from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model
from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, load_hf_dict_from_transformers
from maxtext.utils import max_logging
from maxtext.utils.globals import HF_IDS

@@ -135,8 +135,7 @@ def get_hf_model_state_dict(model_id: str, token: str) -> Dict[str, np.ndarray]:
"""Loads the HuggingFace model state dict and converts to numpy."""
max_logging.log(f"Loading reference model from HuggingFace: {model_id}...")

hf_model = get_hf_model(model_id, token)
state_dict = hf_model.state_dict()
state_dict = load_hf_dict_from_transformers(model_id, token)
numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()}

return numpy_state_dict
@@ -261,12 +260,9 @@ def main(args: Sequence[str], test_args: argparse.Namespace) -> None:
help="Absolute tolerance for numpy.allclose",
)

local_args, _ = parser.parse_known_args()
logging.set_verbosity(logging.INFO)

# Filter args for MaxText config parsing
model_args = sys.argv
to_remove_args = ["--candidate_path", "--reference_path", "--max_workers", "--rtol", "--atol"]
model_args = [s for s in model_args if not any(s.startswith(a) for a in to_remove_args)]
local_args, remaining_args = parser.parse_known_args()
model_args = [sys.argv[0]] + remaining_args

main(model_args, local_args)
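The replacement above leans on a standard `argparse` feature: `parse_known_args` returns both the parsed script-specific flags and the leftover arguments, so the leftovers can be forwarded to the MaxText config parser instead of filtering `sys.argv` by hand against a hard-coded flag list. A minimal sketch of that pattern (with a made-up argv):

```python
import argparse

# Script-specific flags; everything else belongs to the MaxText config parser.
parser = argparse.ArgumentParser()
parser.add_argument("--rtol", type=float, default=1e-5)
parser.add_argument("--atol", type=float, default=1e-8)

argv = ["compare_hf_ckpt.py", "--rtol", "0.01", "config.yml", "run_name=test"]

# parse_known_args consumes --rtol and hands back the unrecognized rest,
# so no manual to_remove_args list can drift out of sync with the parser.
local_args, remaining_args = parser.parse_known_args(argv[1:])
model_args = [argv[0]] + remaining_args

print(local_args.rtol)  # 0.01
print(model_args)       # ['compare_hf_ckpt.py', 'config.yml', 'run_name=test']
```

This is also why the old code was fragile: a prefix-based filter like `s.startswith("--rtol")` silently breaks for the split form `--rtol 0.01`, which `parse_known_args` handles correctly.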
@@ -1649,7 +1649,8 @@ def shard_checkpoint(jax_weights, device_count, mem_info):
"WARNING: hardware/simulated device mismatch. "
f"Actual JAX devices: {len(jax.devices())}, Requested count: {device_count}."
)
max_logging.log(f"shard weights across {len(jax.devices())} devices")
max_logging.log(f"Shard weights across {len(jax.devices())} devices")
max_logging.log("Note: Axis 0 sharding is the default and will not be logged individually.")
# Pre-define sharding specs
mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
# Sharding along axis 0
@@ -1673,13 +1674,13 @@ def checkpoint_device_put(arr):
arr = np.array(arr)

if arr.shape[0] % device_count == 0:
max_logging.log("sharding axis 0")
# Axis 0 sharding is the default; per-tensor log omitted (see the summary log above).
return jax.device_put(arr, device=s1)
elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
max_logging.log("sharding axis 1")
max_logging.log(f"Sharding axis 1. Tensor shape {arr.shape}")
return jax.device_put(arr, device=s2)
else:
max_logging.log("no sharding was possible, replicating")
max_logging.log(f"Not sharding. Tensor shape {arr.shape}")
return jax.device_put(arr, device=s3)

# Weight sharding
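The axis-selection rule in `checkpoint_device_put` above can be isolated as a pure function: shard along axis 0 if its size is divisible by the device count, otherwise along axis 1 if that is divisible, otherwise replicate. This is an illustrative sketch only; the real code applies the choice via `jax.device_put` with the corresponding sharding specs.

```python
def choose_sharding(shape, device_count):
  """Mirrors the divisibility checks used to pick a sharding strategy."""
  if shape[0] % device_count == 0:
    return "axis_0"      # shard rows across devices
  if len(shape) > 1 and shape[1] % device_count == 0:
    return "axis_1"      # shard columns across devices
  return "replicate"     # no even split possible; copy to every device


print(choose_sharding((8192, 1024), 4))  # axis_0
print(choose_sharding((13, 1024), 4))    # axis_1
print(choose_sharding((13, 7), 4))       # replicate
```

Note the order matters: a tensor divisible along both axes is sharded on axis 0, which is why the PR's logging change treats axis-0 sharding as the unlogged default.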
20 changes: 19 additions & 1 deletion src/maxtext/checkpoint_conversion/to_huggingface.py
@@ -28,6 +28,9 @@
Defaults to "./mt_output/".
scan_layers: (bool) Whether the MaxText model was trained with scanned layers.
This must match the training configuration of the checkpoint.
weight_dtype: (Optional) Controls the dtype of the resulting Hugging Face
weights. The default value is `float32`; we recommend `bfloat16`
to save memory and speed up conversion.

Optional Flags:
--override_model_architecture: If set, overrides the HF model configuration
@@ -139,13 +142,25 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
attributes_to_check = [
("num_attention_heads", "num_query_heads"),
("num_key_value_heads", "num_kv_heads"),
("head_dim", "head_dim"),
("hidden_size", "emb_dim"),
("intermediate_size", "mlp_dim"),
("num_hidden_layers", "num_decoder_layers"),
("vocab_size", "vocab_size"),
]

if max_config.attention_type == "mla":
attributes_to_check.extend(
[
("qk_nope_head_dim", "qk_nope_head_dim"),
("qk_rope_head_dim", "qk_rope_head_dim"),
("v_head_dim", "v_head_dim"),
("kv_lora_rank", "kv_lora_rank"),
("q_lora_rank", "q_lora_rank"),
]
)
else:
attributes_to_check.append(("head_dim", "head_dim"))

mismatches = []

for hf_attr, mt_attr in attributes_to_check:
@@ -215,6 +230,7 @@ def main(argv: Sequence[str]) -> None:
checkpoint_dict = load_orbax_checkpoint(config)
max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min")

# Define output directory
if not config.base_output_directory:
output_directory = f"tmp/{config.run_name}"
else:
@@ -269,6 +285,8 @@ def main(argv: Sequence[str]) -> None:
processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
processed_params_list.extend(processed_params)

max_logging.log(f"Weight dtype after transform: {processed_params[0][1].dtype}")

transformed_hf_weights = dict(processed_params_list)
max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min")
