
Commit 5702326

checkpoint utility: optimize to_maxtext, add deepseek
1 parent 9f6b09a commit 5702326

14 files changed

Lines changed: 465 additions & 183 deletions


docs/guides/checkpointing_solutions/convert_checkpoint.md

Lines changed: 11 additions & 7 deletions
@@ -16,7 +16,9 @@ The following models are supported:
 | **Qwen3 MoE** | 30B, 235B, 480B |||||
 | **Mixtral** | 8x7B, 8x22B |||||
 | **GPT-OSS** | 20B, 120B |||||
-| **DeepSeek3** | 671B | - | - || - |
+| **DeepSeek2** | 16B |||||
+| **DeepSeek3** | 671B |||||
+| **DeepSeek3.2** | 671B ||| - | - |
 | **Qwen3 Next** | 80B |||||
 
 ## Prerequisites
@@ -73,24 +75,26 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \
     model_name=${MODEL_NAME?} \
     hf_access_token=${HF_TOKEN?} \
     base_output_directory=${MODEL_CHECKPOINT_DIRECTORY?} \
-    scan_layers=True \
+    scan_layers=true \
     use_multimodal=false \
     hardware=cpu \
     skip_jax_distributed_system=true \
     checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \
     checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
-    --lazy_load_tensors=${LAZY_LOAD_TENSORS?}
+    --lazy_load_tensors=${LAZY_LOAD_TENSORS?} \
+    --save_dtype=bfloat16
 ```
 
 - `model_name`: The model identifier, which should be defined in `src/maxtext/configs/types.py`.
 - `scan_layers`: Indicates if the output checkpoint is [scanned](https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/reference/core_concepts/checkpoints.md) (scan_layers=true) or unscanned (scan_layers=false).
 - `use_multimodal`: Indicates if multimodality is used, important for Gemma3.
 - `hf_access_token`: Your Hugging Face token.
-- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Googld Cloud Storage (GCS) or local. If not set, the default output directory is `Maxtext/tmp`.
+- `base_output_directory`: The path where the converted Orbax checkpoint will be stored; it can be Google Cloud Storage (GCS) or local. If not set, the default output directory is `maxtext/tmp`.
 - `hardware=cpu`: run the conversion script on a CPU machine.
 - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: Set to True for McJAX (default, `USE_PATHWAYS=0`); set to False for Pathways (`USE_PATHWAYS=1`). Both are controlled by the `$((1 - USE_PATHWAYS))` calculation in the example above.
-- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on-demand to minimize RAM usage. When memory is constrained, it is recommended to use the `--lazy_load_tensors=true` flag to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes.
-- `--hf_model_path` (optional): Specifies a local or remote directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--lazy_load_tensors` (optional): If `true`, loads Hugging Face weights on demand to minimize RAM usage. When memory is constrained, it is recommended to use `--lazy_load_tensors=true` to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with lazy loading uses around 200GB of RAM and completes in ~10 minutes.
+- `--hf_model_path` (optional): Specifies a custom remote or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek.
+- `--save_dtype` (optional): Specifies the dtype of the saved model weights. Defaults to `bfloat16` to save memory.
 
 The above command will download the Hugging Face model to the local machine if `hf_model_path` is unspecified, or reuse the checkpoint in `hf_model_path`. It will convert the checkpoint to the MaxText format and save it to `${MODEL_CHECKPOINT_DIRECTORY}/0/items`.
 
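For intuition on the new `--save_dtype=bfloat16` default: bfloat16 stores each weight in two bytes instead of float32's four, halving checkpoint size and peak save memory. A minimal JAX sketch of the effect (our illustration, not the converter's code):

```python
import jax.numpy as jnp

# A single 4096x4096 weight matrix in float32 vs. bfloat16.
w32 = jnp.zeros((4096, 4096), dtype=jnp.float32)
w16 = w32.astype(jnp.bfloat16)  # conceptually what --save_dtype=bfloat16 applies before saving

print(f"float32: {w32.nbytes / 2**20:.0f} MiB, bfloat16: {w16.nbytes / 2**20:.0f} MiB")
# float32: 64 MiB, bfloat16: 32 MiB
```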
@@ -217,7 +221,7 @@ To extend conversion support to a new model architecture, you must define its sp
    - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer.
 
 2. **Add Hugging Face weights shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of the Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape matches after the to_huggingface conversion.
-3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
+3. **Register model key**: In [`utils/globals.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`.
 4. **Add transformer config**: In [`utils/hf_model_configs.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/checkpoint_conversion/utils/hf_model_configs.py), add the `transformers.Config` object, describing the Hugging Face model configuration (defined in [`src/maxtext/configs/models`](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/maxtext/configs/models)). **Note**: This configuration must precisely match the MaxText model's architecture.
 
 Here is an example [PR to add support for gemma3 multi-modal model](https://github.com/AI-Hypercomputer/maxtext/pull/1983).
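To make the steps above concrete, here is a hypothetical sketch for a model called `mymodel`. The function names follow the `{MODEL}_...` patterns listed above, but the parameter paths, shapes, signatures, and repo id are illustrative assumptions, not MaxText's actual mappings:

```python
# Step 1 (param_mapping.py): map MaxText parameter paths to HF tensor names,
# plus a per-layer hook for any tensor transformation (sketch only).
def MYMODEL_MAXTEXT_TO_HF_PARAM_MAPPING(config):
  return {
      "params-token_embedder-embedding": "model.embed_tokens.weight",
      "params-decoder-layers_0-self_attention-query-kernel": "model.layers.0.self_attn.q_proj.weight",
  }

def MYMODEL_MAXTEXT_TO_HF_PARAM_HOOK_FN(config):
  # Example transformation: transpose MaxText (in, out) kernels to HF's (out, in).
  return {"params-decoder-layers_0-self_attention-query-kernel": lambda t: t.T}

# Step 2 (hf_shape.py): expected HF tensor shapes, checked after conversion.
def MYMODEL_HF_WEIGHTS_TO_SHAPE(config):
  return {"model.embed_tokens.weight": (config.vocab_size, config.emb_dim)}

# Step 3 (globals.py): register the default Hugging Face repository id.
HF_IDS["mymodel-1b"] = "my-org/mymodel-1b"
```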

src/maxtext/checkpoint_conversion/compare_hf_ckpt.py

Lines changed: 4 additions & 8 deletions
@@ -48,7 +48,7 @@
 from safetensors import safe_open
 
 from maxtext.configs import pyconfig
-from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model
+from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, load_hf_dict_from_transformers
 from maxtext.utils import max_logging
 from maxtext.utils.globals import HF_IDS
 

@@ -135,8 +135,7 @@ def get_hf_model_state_dict(model_id: str, token: str) -> Dict[str, np.ndarray]:
   """Loads the HuggingFace model state dict and converts to numpy."""
   max_logging.log(f"Loading reference model from HuggingFace: {model_id}...")
 
-  hf_model = get_hf_model(model_id, token)
-  state_dict = hf_model.state_dict()
+  state_dict = load_hf_dict_from_transformers(model_id, token)
   numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()}
 
   return numpy_state_dict
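The renamed helper returns the state dict directly instead of a model object. Its implementation is not shown in this diff; a minimal sketch of the contract it needs to satisfy, assuming a standard `transformers` load, would be:

```python
import torch
from transformers import AutoModelForCausalLM

def load_hf_dict_from_transformers(model_id: str, token: str) -> dict[str, torch.Tensor]:
  """Sketch only: return a name -> tensor mapping for the reference model.

  The real helper may load tensors more efficiently; this version simply
  instantiates the model and returns its state dict.
  """
  model = AutoModelForCausalLM.from_pretrained(model_id, token=token)
  return model.state_dict()
```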
@@ -261,12 +260,9 @@ def main(args: Sequence[str], test_args: argparse.Namespace) -> None:
       help="Absolute tolerance for numpy.allclose",
   )
 
-  local_args, _ = parser.parse_known_args()
   logging.set_verbosity(logging.INFO)
 
-  # Filter args for MaxText config parsing
-  model_args = sys.argv
-  to_remove_args = ["--candidate_path", "--reference_path", "--max_workers", "--rtol", "--atol"]
-  model_args = [s for s in model_args if not any(s.startswith(a) for a in to_remove_args)]
+  local_args, remaining_args = parser.parse_known_args()
+  model_args = [sys.argv[0]] + remaining_args
 
   main(model_args, local_args)
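The refactor drops the hand-maintained `to_remove_args` filter: `parse_known_args` already splits argv into the flags this parser owns and everything else, so the MaxText config arguments fall out as the remainder. A self-contained sketch of the pattern:

```python
import argparse
import sys

parser = argparse.ArgumentParser()
parser.add_argument("--rtol", type=float, default=1e-2)

# parse_known_args returns (parsed namespace, unrecognized argv entries), so
# MaxText-style key=value overrides pass through without manual filtering.
local_args, remaining_args = parser.parse_known_args(["--rtol", "0.1", "model_name=llama3.1-70b"])
assert local_args.rtol == 0.1
assert remaining_args == ["model_name=llama3.1-70b"]

model_args = [sys.argv[0]] + remaining_args  # script name + config overrides
```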

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 3 additions & 2 deletions
@@ -1649,7 +1649,8 @@ def shard_checkpoint(jax_weights, device_count, mem_info):
         "WARNING: hardware/simulated device mismatch. "
         f"Actual JAX devices: {len(jax.devices())}, Requested count: {device_count}."
     )
-  max_logging.log(f"shard weights across {len(jax.devices())} devices")
+  max_logging.log(f"Shard weights across {len(jax.devices())} devices")
+  max_logging.log("Note: Axis 0 sharding is the default and will not be logged individually.")
   # Pre-define sharding specs
   mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis")
   # Sharding along axis 0
@@ -1673,7 +1674,7 @@ def checkpoint_device_put(arr):
     arr = np.array(arr)
 
     if arr.shape[0] % device_count == 0:
-      max_logging.log("sharding axis 0")
+      # Sharding axis 0: Omit log for brevity per the summary log above.
       return jax.device_put(arr, device=s1)
     elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0:
       max_logging.log("sharding axis 1")
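For context, `checkpoint_device_put` picks a sharding axis by divisibility: axis 0 when the leading dimension divides evenly across devices (the common case, hence the single summary log added above), otherwise axis 1. A minimal, self-contained sketch of that decision; the spec names here are ours, not the script's:

```python
import jax
import numpy as np

devices = jax.devices()
mesh = jax.sharding.Mesh(devices, "checkpoint_sharding_axis")
spec_axis0 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("checkpoint_sharding_axis"))
spec_axis1 = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "checkpoint_sharding_axis"))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

def place(arr: np.ndarray) -> jax.Array:
  n = len(devices)
  if arr.shape[0] % n == 0:
    return jax.device_put(arr, spec_axis0)  # shard rows (axis 0, the default case)
  if arr.ndim > 1 and arr.shape[1] % n == 0:
    return jax.device_put(arr, spec_axis1)  # shard columns (axis 1, still logged)
  return jax.device_put(arr, replicated)    # fall back to replication
```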

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 15 additions & 1 deletion
@@ -139,13 +139,25 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool):
   attributes_to_check = [
       ("num_attention_heads", "num_query_heads"),
       ("num_key_value_heads", "num_kv_heads"),
-      ("head_dim", "head_dim"),
       ("hidden_size", "emb_dim"),
       ("intermediate_size", "mlp_dim"),
       ("num_hidden_layers", "num_decoder_layers"),
       ("vocab_size", "vocab_size"),
   ]
 
+  if max_config.attention_type == "mla":
+    attributes_to_check.extend(
+        [
+            ("qk_nope_head_dim", "qk_nope_head_dim"),
+            ("qk_rope_head_dim", "qk_rope_head_dim"),
+            ("v_head_dim", "v_head_dim"),
+            ("kv_lora_rank", "kv_lora_rank"),
+            ("q_lora_rank", "q_lora_rank"),
+        ]
+    )
+  else:
+    attributes_to_check.append(("head_dim", "head_dim"))
+
   mismatches = []
 
   for hf_attr, mt_attr in attributes_to_check:
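Background: DeepSeek-style multi-head latent attention (MLA) has no single `head_dim`; queries and keys split into RoPE and no-RoPE parts with their own dims, values have a separate head dim, and K/V (and optionally Q) go through low-rank projections, so those five fields are validated instead. The comparison loop itself sits outside this hunk; a plausible sketch of how the pairs are consumed (an assumption, not the file's actual code):

```python
# Sketch: compare each (HF attribute, MaxText attribute) pair, collecting
# mismatches or overwriting the HF config when override is set.
for hf_attr, mt_attr in attributes_to_check:
  hf_val = getattr(hf_config, hf_attr, None)
  mt_val = getattr(max_config, mt_attr, None)
  if hf_val != mt_val:
    if override:
      setattr(hf_config, hf_attr, mt_val)
    else:
      mismatches.append(f"{hf_attr}: hf={hf_val} vs maxtext {mt_attr}={mt_val}")
```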
@@ -269,6 +281,8 @@ def main(argv: Sequence[str]) -> None:
     processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config)
     processed_params_list.extend(processed_params)
 
+  max_logging.log(f"Weight dtype after transform: {type(processed_params[0][1].dtype)}")
+
   transformed_hf_weights = dict(processed_params_list)
   max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min")
 
