diff --git a/docs/guides/checkpointing_solutions/convert_checkpoint.md b/docs/guides/checkpointing_solutions/convert_checkpoint.md index eba5fc7261..60878723f2 100644 --- a/docs/guides/checkpointing_solutions/convert_checkpoint.md +++ b/docs/guides/checkpointing_solutions/convert_checkpoint.md @@ -16,7 +16,9 @@ The following models are supported: | **Qwen3 MoE** | 30B, 235B, 480B | √ | √ | √ | √ | | **Mixtral** | 8x7B, 8x22B | √ | √ | √ | √ | | **GPT-OSS** | 20B, 120B | √ | √ | √ | √ | -| **DeepSeek3** | 671B | - | - | √ | - | +| **DeepSeek2** | 16B | √ | √ | √ | √ | +| **DeepSeek3** | 671B | √ | √ | √ | √ | +| **DeepSeek3.2** | 671B | √ | √ | - | - | | **Qwen3 Next** | 80B | √ | √ | √ | √ | ## Prerequisites @@ -60,7 +62,8 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext \ skip_jax_distributed_system=true \ checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) \ checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ - --lazy_load_tensors=${LAZY_LOAD_TENSORS?} + --lazy_load_tensors=${LAZY_LOAD_TENSORS?} \ + --save_dtype=bfloat16 ``` You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/items`. @@ -74,7 +77,8 @@ You can find your converted checkpoint files under `${BASE_OUTPUT_DIRECTORY}/0/i - `hardware=cpu`: The conversion script runs on a CPU machine. - `checkpoint_storage_use_zarr3` and `checkpoint_storage_use_ocdbt`: These storage flags enable McJAX compatibility when set to True (the default). For Pathways, these should be False. - `--lazy_load_tensors` (Optional): Enables on-demand loading of weights to prevent OOM (Out of Memory) errors. Highly recommended for large models to reduce memory usage during conversion. For example, converting a Llama3.1-70B model with `--lazy_load_tensors=true` uses around 200GB of RAM and completes in ~10 minutes. -- `--hf_model_path` (Optional): Specifies a local or remote directory containing the model weights. 
If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/utils.py#L59-L91) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. +- `--hf_model_path` (Optional): Specifies a customized remote directory or local directory containing the model weights. If unspecified, we use the [default Hugging Face repository ID](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py) (e.g., openai/gpt-oss-20b). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. +- `--save_dtype` (Optional): Specifies the data type of saved model weights. Defaults to `bfloat16` to save memory. ## MaxText to Hugging Face @@ -118,7 +122,7 @@ python3 -m maxtext.checkpoint_conversion.to_huggingface \ - `use_multimodal`: Indicates if multimodality is used, important for Gemma3. - `hardware=cpu`: The conversion script runs on a CPU machine. - `base_output_directory`: The path where the converted checkpoint will be stored; it can be Google Cloud Storage (GCS), Hugging Face Hub or local. -- `weight_dtype`: dtype for MaxText weights. It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion. +- `weight_dtype`: It affects the resulting Hugging Face weight dtype. Default value is `float32`. We recommend using `bfloat16` to save memory and speed up conversion. ## Verifying conversion correctness @@ -226,7 +230,7 @@ To extend conversion support to a new model architecture, you must define its sp - In [`utils/param_mapping.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/param_mapping.py), add the `hook_fn` logic (`def {MODEL}_MAXTEXT_TO_HF_PARAM_HOOK_FN`). This is the transformation needed per layer. -2. 
**Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. +2. **Add Hugging Face weights Shape**: In [`utils/hf_shape.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/checkpoint_conversion/utils/hf_shape.py), define the tensor shape of Hugging Face format (`def {MODEL}_HF_WEIGHTS_TO_SHAPE`). This is used to ensure the tensor shape is matched after to_huggingface conversion. 3. **Register model key**: In [`utils/utils.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/utils/globals.py), add the new model key in `HF_IDS`. diff --git a/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py b/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py index 44b011343a..c320e4ac39 100644 --- a/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py +++ b/src/maxtext/checkpoint_conversion/compare_hf_ckpt.py @@ -48,7 +48,7 @@ from safetensors import safe_open from maxtext.configs import pyconfig -from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, get_hf_model +from maxtext.checkpoint_conversion.utils.utils import print_ram_usage, load_hf_dict_from_transformers from maxtext.utils import max_logging from maxtext.utils.globals import HF_IDS @@ -135,8 +135,7 @@ def get_hf_model_state_dict(model_id: str, token: str) -> Dict[str, np.ndarray]: """Loads the HuggingFace model state dict and converts to numpy.""" max_logging.log(f"Loading reference model from HuggingFace: {model_id}...") - hf_model = get_hf_model(model_id, token) - state_dict = hf_model.state_dict() + state_dict = load_hf_dict_from_transformers(model_id, token) numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()} return numpy_state_dict @@ -261,12 +260,9 @@ def main(args: 
Sequence[str], test_args: argparse.Namespace) -> None: help="Absolute tolerance for numpy.allclose", ) - local_args, _ = parser.parse_known_args() logging.set_verbosity(logging.INFO) - # Filter args for MaxText config parsing - model_args = sys.argv - to_remove_args = ["--candidate_path", "--reference_path", "--max_workers", "--rtol", "--atol"] - model_args = [s for s in model_args if not any(s.startswith(a) for a in to_remove_args)] + local_args, remaining_args = parser.parse_known_args() + model_args = [sys.argv[0]] + remaining_args main(model_args, local_args) diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py index fe2b651d5b..56f395eed9 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py @@ -1649,7 +1649,8 @@ def shard_checkpoint(jax_weights, device_count, mem_info): "WARNING: hardware/simulated device mismatch. " f"Actual JAX devices: {len(jax.devices())}, Requested count: {device_count}." ) - max_logging.log(f"shard weights across {len(jax.devices())} devices") + max_logging.log(f"Shard weights across {len(jax.devices())} devices") + max_logging.log("Note: Axis 0 sharding is the default and will not be logged individually.") # Pre-define sharding specs mesh = jax.sharding.Mesh(jax.devices(), "checkpoint_sharding_axis") # Sharding along axis 0 @@ -1673,13 +1674,13 @@ def checkpoint_device_put(arr): arr = np.array(arr) if arr.shape[0] % device_count == 0: - max_logging.log("sharding axis 0") + # Sharding axis 0: Omit log for brevity per the summary log above. return jax.device_put(arr, device=s1) elif len(arr.shape) > 1 and arr.shape[1] % device_count == 0: - max_logging.log("sharding axis 1") + max_logging.log(f"Sharding axis 1. 
Tensor shape {arr.shape}") return jax.device_put(arr, device=s2) else: - max_logging.log("no sharding was possible, replicating") + max_logging.log(f"Not sharding. Tensor shape {arr.shape}") return jax.device_put(arr, device=s3) # Weight sharding diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 94c70860c7..94d808d162 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -28,6 +28,9 @@ Defaults to "./mt_output/". scan_layers: (bool) Whether the MaxText model was trained with scanned layers. This must match the training configuration of the checkpoint. + weight_dtype: (Optional) It affects the resulting Hugging Face weight dtype. + Default value is `float32`. We recommend using `bfloat16` + to save memory and speed up conversion. Optional Flags: --override_model_architecture: If set, overrides the HF model configuration @@ -139,13 +142,25 @@ def _validate_or_update_architecture(hf_config, max_config, override: bool): attributes_to_check = [ ("num_attention_heads", "num_query_heads"), ("num_key_value_heads", "num_kv_heads"), - ("head_dim", "head_dim"), ("hidden_size", "emb_dim"), ("intermediate_size", "mlp_dim"), ("num_hidden_layers", "num_decoder_layers"), ("vocab_size", "vocab_size"), ] + if max_config.attention_type == "mla": + attributes_to_check.extend( + [ + ("qk_nope_head_dim", "qk_nope_head_dim"), + ("qk_rope_head_dim", "qk_rope_head_dim"), + ("v_head_dim", "v_head_dim"), + ("kv_lora_rank", "kv_lora_rank"), + ("q_lora_rank", "q_lora_rank"), + ] + ) + else: + attributes_to_check.append(("head_dim", "head_dim")) + mismatches = [] for hf_attr, mt_attr in attributes_to_check: @@ -215,6 +230,7 @@ def main(argv: Sequence[str]) -> None: checkpoint_dict = load_orbax_checkpoint(config) max_logging.log(f"Elapse for checkpoint load: {(time.time() - start) / 60:.2f} min") + # Define output directory if not 
config.base_output_directory: output_directory = f"tmp/{config.run_name}" else: @@ -269,6 +285,8 @@ def main(argv: Sequence[str]) -> None: processed_params = process_maxtext_param(key, weight, param_map, hook_fn_map, shape_map, config) processed_params_list.extend(processed_params) + max_logging.log(f"Weight dtype after transform: {type(processed_params[0][1].dtype)}") + transformed_hf_weights = dict(processed_params_list) max_logging.log(f"Elapse for transform: {(time.time() - start) / 60:.2f} min") diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index a77893df2f..df5393d04a 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -25,13 +25,15 @@ Defaults to "./mt_output/". scan_layers: (bool) Whether the MaxText model was trained with scanned layers. This must match the training configuration of the checkpoint. - lazy_load: (bool) If True, uses an on-demand loading strategy to minimize RAM + --lazy_load_tensors: (bool) If True, uses an on-demand loading strategy to minimize RAM usage during conversion. Recommended if, 2 * model_size (GB) >= system RAM Defaults to False. --hf_model_path: (Optional) Specifies a local or remote directory containing the model weights. If unspecified, we use the default Hugging Face repository ID (e.g., openai/gpt-oss-20b; see `HF_IDS[model_name]` in `maxtext.utils.globals`). This is necessary for locally dequantized models like GPT-OSS or DeepSeek. + --save_dtype: (Optional) Specifies the data type of saved model weights. + Defaults to `bfloat16` to save memory. 
Environment Variables: HF_AUTH_TOKEN: (Required) HuggingFace authentication token, needed to @@ -40,7 +42,7 @@ Example Usage: To convert a gemma2-2b model and save it to a specific directory: - /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ + python -m maxtext.checkpoint_conversion.to_maxtext \ maxtext/configs/base.yml model_name="gemma2-2b" \ base_output_directory="/path/to/your/output/directory" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ @@ -51,7 +53,7 @@ To convert a 70B model with minimal RAM usage: - /usr/bin/time -v python src/maxtext/checkpoint_conversion/to_maxtext.py \ + python -m maxtext.checkpoint_conversion.to_maxtext \ maxtext/configs/base.yml model_name="llama3.1-70b" \ base_output_directory="gs://my-bucket/maxtext-checkpoints" \ hf_access_token=${HF_TOKEN?} hardware=cpu skip_jax_distributed_system=True \ @@ -67,14 +69,18 @@ import time from typing import Any, Callable, List, Sequence import absl +import ml_dtypes +import torch import flax.linen as nn from huggingface_hub import hf_hub_download, list_repo_files import jax from maxtext.configs import pyconfig +from maxtext.configs.types import DType from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.checkpoint_conversion.standalone_scripts.llama_or_mistral_ckpt import save_weights_to_checkpoint +from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, get_hf_model, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys +from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, validate_and_filter_param_map_keys from maxtext.inference.inference_utils import str2bool from 
maxtext.layers import quantizations from maxtext.models import models @@ -83,7 +89,6 @@ import numpy as np from orbax.checkpoint import type_handlers from safetensors import safe_open -from transformers import AutoConfig absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log @@ -498,7 +503,7 @@ def _get_maxtext_weight( mt_target_shape_or_shapes, mt_param_key_or_keys, final_mt_weights, - config, + save_dtype, use_lazy_load, ): """Loads Hugging Face parameters and converts them to MaxText parameters. @@ -542,7 +547,7 @@ def _get_maxtext_weight( final_mt_tensor_numpy = LazyTensor( load_fn, mt_target_shape_or_shapes, - config.weight_dtype, + save_dtype, name=mt_param_key_or_keys, ) if not is_composite_mt_key: @@ -564,16 +569,18 @@ def _slicing_loader(base_loader, slice_idx): final_mt_weights[mt_target_idx] = LazyTensor( slicing_load_fn, mt_target_shape_or_shapes[i], - config.weight_dtype, + save_dtype, name=mt_param_key_or_keys[i], ) def main( args: Sequence[str], + lazy_load_tensors: bool = False, + eager_load_method: str = "transformers", hf_model_path: str | None = None, revision: str | None = None, - lazy_load_tensors: bool = False, + save_dtype: str = "bfloat16", simulated_cpu_devices_count: int = 16, ) -> None: overall_start = time.time() @@ -618,38 +625,81 @@ def main( if lazy_load_tensors: max_logging.log(f"Lazy loading ENABLED. Initializing LazyHFLoader for: {model_id}...") hf_loader = LazyHFLoader(model_id, hf_token, revision=revision) - hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision) + print_ram_usage("After LazyLoader init") tensor_getter = hf_loader.get_tensor else: max_logging.log(f"Lazy loading DISABLED. 
Loading full HuggingFace model: {model_id}...") - hf_config_obj = AutoConfig.from_pretrained(model_id, token=hf_token, revision=revision) - hf_model = get_hf_model(model_id, token=hf_token, revision=revision) - hf_state_dict_numpy = hf_model.state_dict() - # Convert all to numpy immediately in eager mode - for k, v in hf_state_dict_numpy.items(): - hf_state_dict_numpy[k] = v.numpy() - del hf_model - max_logging.log("HuggingFace model loaded and converted to NumPy.") + + # Eager load methods: + # - Method 1: transformers_class.from_pretrained(..., dtype="auto") + # - Method 2: safetensors.safe_open(..., framework="pt") + # + # Comparison: + # - Both methods result in the same dtype (usually bfloat16) and model structure + # for most models (e.g., DeepSeek-V2), with similar loading times. + # - Exception: Gemma-3 uses different internal naming (prefixes) between + # Method 1 and Method 2. Current MaxText 'param_mapping' for Gemma-3 assumes + # the Transformers-style structure (Method 1). + # - The 'safetensors' method is a necessary fallback for: + # 1. "Day-0" models where the official Transformers code hasn't been merged yet + # (e.g., DeepSeek-V3.2 during its initial release). + # 2. Weights omitted by official Transformers class + # (e.g., Multi-Token Prediction weights (`layers.61`) in DeepSeek-V3). + # + # Recommendation: + # - Use 'transformers' as the default for backward compatibility of mapping. + # - 'safetensors' is an interchangeable and valid alternative for most models, + # and is strictly required if the model or specific weights lack Transformers support. 
+ if eager_load_method == "transformers": + max_logging.log("Eager load with Transformers backend, from_pretrained with auto dtype") + # For auto mode, loaded dtype is the same as `dtype` specified in config.json (or `torch_dtype` for older version) + # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json#L54 + hf_state_dict_numpy = load_hf_dict_from_transformers(model_id, token=hf_token, revision=revision, dtype="auto") + elif eager_load_method == "safetensors": + max_logging.log("Eager load with Safetensors backend, safe_open with pt framework") + # For safe_open, loaded dtype is the same as original safetensor + # e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/model.safetensors.index.json + hf_state_dict_numpy = load_hf_dict_from_safetensors(model_id, token=hf_token, revision=revision, framework="pt") + else: + raise NotImplementedError + + unique_dtypes = {tensor.dtype for tensor in hf_state_dict_numpy.values()} + max_logging.log(f"HuggingFace model loaded. 
dtypes: {unique_dtypes}") print_ram_usage("After full HF model load") def _eager_getter(key): if key not in hf_state_dict_numpy: raise ValueError(f"HuggingFace key {key} not found in state_dict.") - return hf_state_dict_numpy[key] + v = hf_state_dict_numpy[key] + # target dtype is "float32" + if save_dtype == DType.FLOAT32: + return v.to(torch.float32).numpy() + # target dtype is "bfloat16" + elif save_dtype == DType.BFLOAT16: + # - torch.bfloat16 -> torch.float32 -> np.float32 -> ml_dtypes.bfloat16 + # As numpy doesn't accept bfloat16 directly, we convert to float32 first + # - torch.float16 -> np.float16 -> ml_dtypes.bfloat16 + # - torch.float32 -> np.float32 -> ml_dtypes.bfloat16 + if v.dtype == torch.bfloat16: + v = v.to(torch.float32) + return v.numpy().astype(ml_dtypes.bfloat16) + raise NotImplementedError(f"Save dtype {save_dtype} is not currently implemented.") tensor_getter = _eager_getter # Get parameter mappings and hooks + model_key = config.model_name + # load config + hf_config_obj = HF_MODEL_CONFIGS[model_key] + hf_config_dict = hf_config_obj.to_dict() # example of param mapping (gemma2, maxtext:huggingface): # "params-decoder-layers_{maxtext_layer_idx}-pre_self_attention_norm_global-scale": # f"model.layers.{global_layer_idx}.input_layernorm.weight", - model_key = config.model_name - param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_obj.to_dict(), config, config.scan_layers) - + param_map_mt_to_hf = PARAM_MAPPING[model_key](hf_config_dict, config, config.scan_layers) # Example of Hook FN mapping, to perform reshape: # f"params-decoder-layers_{maxtext_layer_idx}-self_attention_global-key-kernel": reshape_kernel, - hook_fn_map_mt = HOOK_FNS[model_key](hf_config_obj.to_dict(), config, config.scan_layers, saving_to_hf=False) + hook_fn_map_mt = HOOK_FNS[model_key](hf_config_dict, config, config.scan_layers, saving_to_hf=False) max_logging.log("Parameter mappings and hooks obtained.") maxtext_abstract_dict, abstract_params_treedef = 
get_maxtext_model_info(config) @@ -669,6 +719,7 @@ def _eager_getter(key): unit="param", leave=True, dynamic_ncols=True, + smoothing=0, ): if not lazy_load_tensors: max_logging.log(f"maxtext param: {mt_param_key_or_keys}") @@ -702,7 +753,7 @@ def _eager_getter(key): mt_target_shape_or_shapes, mt_param_key_or_keys, final_mt_weights, - config, + save_dtype, lazy_load_tensors, ) @@ -744,20 +795,51 @@ def _eager_getter(key): # Define local parser parser = argparse.ArgumentParser() + # Lazy load uses `safetensors.safe_open` with np parser.add_argument( "--lazy_load_tensors", type=str2bool, required=False, default=False, - help="Whether to use lazy loading of HF tensors.", + help="Whether to use lazy loading of HF tensors", + ) + # Eager load uses `transformers_class.from_pretrained` with auto dtype or `safetensors.safe_open` with pt. + # The two methods are interchangeable in most cases. + # Must use "transformers" for gemma3-4b due to mapping compatibility. + # Must use "safetensors" for models without official transformers support, like DeepSeek-V3.2. + # Must use "safetensors" for weights omitted by transformers class, + # like Multi-Token Prediction weights (`layers.61`) in DeepSeek-V3. + parser.add_argument( + "--eager_load_method", + type=str, + required=False, + default="transformers", + choices=["transformers", "safetensors"], + help="Backend to use for eager loading: `transformers_class.from_pretrained` or `safetensors.safe_open` with pt", ) # If not specified, default to maxtext.utils.globals.HF_IDS[model_name] parser.add_argument( "--hf_model_path", type=str, required=False, - default="", - help="local path to hf model, or custom remote hf repo", + default=None, + help="Customized remote HF repo, or local path to HF model", + ) + # If hf_model_path is set to a local path, this is ignored. 
+ parser.add_argument( + "--revision", + type=str, + required=False, + default=None, + help="Specific Hugging Face revision (branch/tag/commit)", + ) + parser.add_argument( + "--save_dtype", + type=str, + required=False, + default="bfloat16", + choices=["float32", "bfloat16"], + help="Save MaxText weights in specified dtype", ) # Determines the logical sharding of the output checkpoint by partitioning # weights across virtual XLA devices. @@ -772,16 +854,9 @@ def _eager_getter(key): # Case 2: simulated_cpu_devices_count=1 (Monolith) # sharding: None # storage: chunk_shape=(151936, 1024) <-- Full layer in one chunk - parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) - parser.add_argument( - "--revision", - type=str, - required=False, - default=None, - help="Specific Hugging Face revision (branch/tag/commit)", + "--simulated_cpu_devices_count", type=int, required=False, default=16, help="Sharding of checkpoint" ) - # Parse local arguments # Parse known args returns the namespace AND the list of remaining arguments local_args, remaining_args = parser.parse_known_args() @@ -793,8 +868,10 @@ def _eager_getter(key): os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" main( args=model_args, + lazy_load_tensors=local_args.lazy_load_tensors, + eager_load_method=local_args.eager_load_method, hf_model_path=local_args.hf_model_path, revision=local_args.revision, - lazy_load_tensors=local_args.lazy_load_tensors, + save_dtype=local_args.save_dtype, simulated_cpu_devices_count=local_args.simulated_cpu_devices_count, ) diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index 6103476d03..d54ced6983 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -19,6 +19,11 @@ import transformers +if transformers.__version__ 
>= "5.0.0": + from transformers.configuration_utils import PreTrainedConfig as PTConfig +else: + from transformers.configuration_utils import PretrainedConfig as PTConfig + gemma3_4b_config = transformers.Gemma3Config( architectures=["Gemma3ForConditionalGeneration"], boi_token_index=255999, @@ -520,7 +525,67 @@ vocab_size=151936, ) -# copy from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json +# from https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json +deepseek2_16b_dict = { + "architectures": ["DeepseekV2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_deepseek.DeepseekV2Config", + "AutoModel": "modeling_deepseek.DeepseekV2Model", + "AutoModelForCausalLM": "modeling_deepseek.DeepseekV2ForCausalLM", + }, + "aux_loss_alpha": 0.001, + "bos_token_id": 100000, + "eos_token_id": 100001, + "first_k_dense_replace": 1, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 10944, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "moe_intermediate_size": 1408, + "moe_layer_freq": 1, + "n_group": 1, + "n_routed_experts": 64, + "n_shared_experts": 2, + "norm_topk_prob": False, + "num_attention_heads": 16, + "num_experts_per_tok": 6, + "num_hidden_layers": 27, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "q_lora_rank": None, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 0.707, + "mscale_all_dim": 0.707, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + "rope_theta": 10000, + "routed_scaling_factor": 1.0, + "scoring_func": "softmax", + "seq_aux": True, + "tie_word_embeddings": False, + "topk_group": 1, + "topk_method": "greedy", + "torch_dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": True, + "v_head_dim": 128, + 
"vocab_size": 102400, +} +deepseek2_16b_config = transformers.DeepseekV2Config(**deepseek2_16b_dict) + +# from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json # remove fp8 quantization_config, since we are using bf16 deepseek3_671b_dict = { "architectures": ["DeepseekV3ForCausalLM"], @@ -580,7 +645,66 @@ } deepseek3_671b_config = transformers.DeepseekV3Config(**deepseek3_671b_dict) -# copy from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json +# from https://huggingface.co/deepseek-ai/DeepSeek-V3.2/blob/main/config.json +# remove fp8 quantization_config, since we are using bf16 +deepseek32_671b_dict = { + "architectures": ["DeepseekV32ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": 0, + "eos_token_id": 1, + "ep_size": 1, + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 7168, + "index_head_dim": 128, + "index_n_heads": 64, + "index_topk": 2048, + "initializer_range": 0.02, + "intermediate_size": 18432, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v32", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 256, + "n_shared_experts": 1, + "norm_topk_prob": True, + "num_attention_heads": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 61, + "num_key_value_heads": 128, + "num_nextn_predict_layers": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sigmoid", + "tie_word_embeddings": False, + "topk_group": 4, + "topk_method": "noaux_tc", + "torch_dtype": "bfloat16", + "transformers_version": "4.44.2", + "use_cache": True, + "v_head_dim": 128, + "vocab_size": 129280, +} +# 
TODO(shuningjin): replace with DeepseekV32Config when available in transformers library +deepseek32_671b_config = PTConfig(**deepseek32_671b_dict) + +# from https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json # remove mxfp4 quantization_config, since we are using bf16 gpt_oss_20b_dict = { "architectures": ["GptOssForCausalLM"], @@ -649,7 +773,7 @@ } gpt_oss_20b_config = transformers.GptOssConfig(**gpt_oss_20b_dict) -# copy from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json +# from https://huggingface.co/openai/gpt-oss-120b/blob/main/config.json # remove mxfp4 quantization_config, since we are using bf16 gpt_oss_120b_dict = { "architectures": ["GptOssForCausalLM"], @@ -887,7 +1011,9 @@ "qwen3-30b-a3b-base": qwen3_30b_a3b_thinking_2507_config, "qwen3-235b-a22b": qwen3_235b_a22b_thinking_2507_config, "qwen3-480b-a35b": qwen3_coder_480b_a35b_config, + "deepseek2-16b": deepseek2_16b_config, "deepseek3-671b": deepseek3_671b_config, + "deepseek3.2-671b": deepseek32_671b_config, "gpt-oss-20b": gpt_oss_20b_config, "gpt-oss-120b": gpt_oss_120b_config, "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config, diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py index d934178c8d..9f964ecef4 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py @@ -209,28 +209,25 @@ def GEMMA2_HF_WEIGHTS_TO_SHAPE(config): def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): - """Returns mapping between HuggingFace DeepseekV3 weights path and their shape. - - This mapping is derived by matching the provided config dictionary against - the model's parameter dump. 
- - To check this mapping, dump the huggingface model shapes: - from transformers import AutoModelForCausalLM - model_name = "deepseek-ai/DeepSeek-V3" - model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") - for name, val in model.named_parameters(): - print(name, val.shape) + """Returns mapping between HuggingFace weights path and their shape derived from HF config. Args: - config (dict): Model configuration dictionary (from HF DeepseekV3Config.to_dict()) - Expected keys: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json + config (dict): HF configuration dictionary + e.g., https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/config.json Returns: dict: A mapping where: - Keys are HuggingFace model parameter paths - Values are parameter shape as a list + + To check expected mapping: + from transformers import AutoModelForCausalLM + model_name = "deepseek-ai/DeepSeek-V2-Lite" + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") + for name, val in model.named_parameters(): + print(name, val.shape) """ - # --- Extract Core Config Values --- + # --- Core Config Values --- hidden_size = config["hidden_size"] num_hidden_layers = config["num_hidden_layers"] vocab_size = config["vocab_size"] @@ -240,13 +237,17 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): q_lora_rank = config["q_lora_rank"] kv_lora_rank = config["kv_lora_rank"] num_attention_heads = config["num_attention_heads"] + + # qk_head_dim is present in DeepseekV3Config, but missing from DeepseekV2Config + qk_head_dim = config.get("qk_head_dim", config["qk_nope_head_dim"] + config["qk_rope_head_dim"]) # Q projection dim - q_dim = num_attention_heads * config["qk_head_dim"] - # KV_b_proj output dim. 
+ q_dim = num_attention_heads * qk_head_dim + + # kv_b_proj output dim kv_b_dim = num_attention_heads * (config["qk_nope_head_dim"] + config["v_head_dim"]) # Output projection dim (input) o_proj_in_dim = num_attention_heads * config["v_head_dim"] - # kv_a_proj_with_mqa output dim. + # kv_a_proj_with_mqa output dim kv_a_proj_out_dim = config["kv_lora_rank"] + config["qk_rope_head_dim"] # --- MLP-related Dimensions --- @@ -257,7 +258,11 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): # This key determines which layers are dense vs. MoE first_k_dense = config.get("first_k_dense_replace", 0) - # --- Initialize Mapping --- + # --- Indexer Configuration (Optional) --- + index_head_dim = config.get("index_head_dim") + index_n_heads = config.get("index_n_heads") + + # --- Non-layer-specific weights --- mapping = { "model.embed_tokens.weight": [vocab_size, hidden_size], "model.norm.weight": [hidden_size], @@ -268,18 +273,20 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): for layer_idx in range(num_hidden_layers): layer_prefix = f"model.layers.{layer_idx}" - # Common layer components + # --- Attention weights --- layer_mapping = { + # norm f"{layer_prefix}.input_layernorm.weight": [hidden_size], f"{layer_prefix}.post_attention_layernorm.weight": [hidden_size], - # --- Attention projections --- + # kv projection f"{layer_prefix}.self_attn.kv_a_proj_with_mqa.weight": [kv_a_proj_out_dim, hidden_size], f"{layer_prefix}.self_attn.kv_a_layernorm.weight": [kv_lora_rank], f"{layer_prefix}.self_attn.kv_b_proj.weight": [kv_b_dim, kv_lora_rank], + # output projection f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, o_proj_in_dim], } - # --- Q-Projection (Conditional on LoRA) --- + # query projection if q_lora_rank is None: layer_mapping[f"{layer_prefix}.self_attn.q_proj.weight"] = [q_dim, hidden_size] else: @@ -291,7 +298,7 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): } ) - # --- Add conditional biases --- + # bias if attention_bias: if q_lora_rank is not None: 
layer_mapping[f"{layer_prefix}.self_attn.q_a_proj.bias"] = [q_lora_rank] @@ -302,9 +309,23 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): } ) - # --- Add MLP weights (Dense vs. MoE) --- + # indexer for sparse attention + if index_head_dim is not None and index_n_heads is not None and q_lora_rank is not None: + wq_b_dim_out = index_n_heads * index_head_dim + indexer_prefix = f"{layer_prefix}.self_attn.indexer" + layer_mapping.update( + { + f"{indexer_prefix}.k_norm.bias": [index_head_dim], + f"{indexer_prefix}.k_norm.weight": [index_head_dim], + f"{indexer_prefix}.weights_proj.weight": [index_n_heads, hidden_size], + f"{indexer_prefix}.wk.weight": [index_head_dim, hidden_size], + f"{indexer_prefix}.wq_b.weight": [wq_b_dim_out, q_lora_rank], + } + ) + + # --- MLP weights (Dense vs. MoE) --- if layer_idx < first_k_dense: - # This is a DENSE MLP layer (DeepseekV3MLP) + # This is a DENSE MLP layer layer_mapping.update( { f"{layer_prefix}.mlp.gate_proj.weight": [intermediate_size, hidden_size], @@ -313,8 +334,8 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): } ) else: - # This is a MoE MLP layer (DeepseekV3MoE) - # Add the router gate (DeepseekV3TopkRouter) + # This is a MoE MLP layer + # Add the router gate layer_mapping.update( { f"{layer_prefix}.mlp.gate.weight": [n_routed_experts, hidden_size], @@ -322,7 +343,7 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config): } ) - # Add routed experts (DeepseekV3NaiveMoe) + # Add routed experts for expert_j in range(n_routed_experts): expert_prefix = f"{layer_prefix}.mlp.experts.{expert_j}" layer_mapping.update( @@ -532,21 +553,21 @@ def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config): def QWEN_HF_WEIGHTS_TO_SHAPE(config): """Returns mapping between HuggingFace Qwen weights path and the HuggingFace weights shape. 
- To check this mapping, dump the huggingface model shapes: - from transformers import AutoModelForCausalLM - model_name = "Qwen/Qwen3-0.6B" - model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") - for name, val in model.named_parameters(): - print(name, val.shape) - Args: - config (dict): Model configuration dictionary (from HF Qwen3TextConfig.to_dict()) - Expected keys: https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json + config (dict): HF configuration dictionary (from Qwen3TextConfig.to_dict()) + e.g., https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json Returns: dict: A mapping where: - Keys are HuggingFace model parameter paths - Values are parameter shape as a list + + To check expected mapping: + from transformers import AutoModelForCausalLM + model_name = "Qwen/Qwen3-0.6B" + model = AutoModelForCausalLM.from_pretrained(model_name, dtype="auto") + for name, val in model.named_parameters(): + print(name, val.shape) """ hidden_size = config["hidden_size"] num_hidden_layers = config["num_hidden_layers"] @@ -780,7 +801,9 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config): "qwen3-30b-a3b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen3-235b-a22b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen3-480b-a35b": QWEN_HF_WEIGHTS_TO_SHAPE, + "deepseek2-16b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE, "deepseek3-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE, + "deepseek3.2-671b": DEEPSEEK_HF_WEIGHTS_TO_SHAPE, "gpt-oss-20b": GPT_OSS_HF_WEIGHTS_TO_SHAPE, "gpt-oss-120b": GPT_OSS_HF_WEIGHTS_TO_SHAPE, "mixtral-8x7b": MIXTRAL_HF_WEIGHTS_TO_SHAPE, diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index 7e318d7fe5..7e918b9128 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -1078,10 +1078,6 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fal scanned (list of strings), unscanned with expert stacking (list of 
strings), or scanned with expert stacking (nested list of strings). """ - # TODO(shuningjin): add unscan support, b/457820735 - if not scan_layers: - raise NotImplementedError("This conversion only supports scanned MaxText models.") - # Extract hf configuration parameters, without mtp num_main_layers = config["num_hidden_layers"] first_num_dense_layers = config["first_k_dense_replace"] @@ -1097,16 +1093,22 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fal attention_keys = { "pre_self_attention_layer_norm-scale": "input_layernorm.weight", "post_self_attention_layer_norm-scale": "post_attention_layernorm.weight", - "self_attention-wkv_a-kernel": "self_attn.kv_a_proj_with_mqa.weight", "self_attention-kv_norm-scale": "self_attn.kv_a_layernorm.weight", + "self_attention-wkv_a-kernel": "self_attn.kv_a_proj_with_mqa.weight", "self_attention-wkv_b-kernel": "self_attn.kv_b_proj.weight", "self_attention-out-kernel": "self_attn.o_proj.weight", + # v2 + "self_attention-query-kernel": "self_attn.q_proj.weight", # v3 - "self_attention-wq_a-kernel": "self_attn.q_a_proj.weight", "self_attention-q_norm-scale": "self_attn.q_a_layernorm.weight", + "self_attention-wq_a-kernel": "self_attn.q_a_proj.weight", "self_attention-wq_b-kernel": "self_attn.q_b_proj.weight", - # v2 - "self_attention-query-kernel": "self_attn.q_proj.weight", + # v3.2 + "self_attention-indexer-k_norm-bias": "self_attn.indexer.k_norm.bias", + "self_attention-indexer-k_norm-scale": "self_attn.indexer.k_norm.weight", + "self_attention-indexer-weights_proj-kernel": "self_attn.indexer.weights_proj.weight", + "self_attention-indexer-wk-kernel": "self_attn.indexer.wk.weight", + "self_attention-indexer-wq_b-kernel": "self_attn.indexer.wq_b.weight", } # Dense Layers dense_layer_keys = attention_keys | { @@ -1114,11 +1116,6 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fal "mlp-wi_1-kernel": "mlp.up_proj.weight", "mlp-wo-kernel": "mlp.down_proj.weight", } - 
for maxtext_key, hf_key in dense_layer_keys.items(): - mapping[f"params-decoder-dense_layers-{maxtext_key}"] = [ - f"model.layers.{i}.{hf_key}" for i in range(first_num_dense_layers) - ] - # MoE Layers moe_layer_keys = attention_keys | { "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel": "mlp.shared_experts.gate_proj.weight", @@ -1128,33 +1125,51 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=Fal # v3 "DeepSeekMoeBlock_0-MoeBlock_0-gate-bias": "mlp.gate.e_score_correction_bias", } - for maxtext_key, hf_key in moe_layer_keys.items(): - mapping[f"params-decoder-moe_layers-{maxtext_key}"] = [ - f"model.layers.{i}.{hf_key}" for i in range(first_num_dense_layers, num_main_layers) - ] - # MoE Experts (nested list mapping: [[e0_l0, e0_l1..], [e1_l0, e1_l1..]..]) moe_expert_keys = { "DeepSeekMoeBlock_0-MoeBlock_0-wi_0": "gate_proj.weight", "DeepSeekMoeBlock_0-MoeBlock_0-wi_1": "up_proj.weight", "DeepSeekMoeBlock_0-MoeBlock_0-wo": "down_proj.weight", } - for maxtext_key, hf_key in moe_expert_keys.items(): - mapping[f"params-decoder-moe_layers-{maxtext_key}"] = [ - [f"model.layers.{l}.mlp.experts.{e}.{hf_key}" for l in range(first_num_dense_layers, num_main_layers)] - for e in range(num_experts) - ] + + # scan + if scan_layers: + for maxtext_key, hf_key in dense_layer_keys.items(): + mapping[f"params-decoder-dense_layers-{maxtext_key}"] = [ + f"model.layers.{i}.{hf_key}" for i in range(first_num_dense_layers) + ] + + for maxtext_key, hf_key in moe_layer_keys.items(): + mapping[f"params-decoder-moe_layers-{maxtext_key}"] = [ + f"model.layers.{i}.{hf_key}" for i in range(first_num_dense_layers, num_main_layers) + ] + + for maxtext_key, hf_key in moe_expert_keys.items(): + mapping[f"params-decoder-moe_layers-{maxtext_key}"] = [ + [f"model.layers.{i}.mlp.experts.{e}.{hf_key}" for i in range(first_num_dense_layers, num_main_layers)] + for e in range(num_experts) + ] + # unscan + else: + for i in range(first_num_dense_layers): + for maxtext_key, 
hf_key in dense_layer_keys.items(): + mapping[f"params-decoder-dense_layers_{i}-{maxtext_key}"] = f"model.layers.{i}.{hf_key}" + + for i in range(first_num_dense_layers, num_main_layers): + moe_layer_idx = i - first_num_dense_layers + + for maxtext_key, hf_key in moe_layer_keys.items(): + mapping[f"params-decoder-moe_layers_{moe_layer_idx}-{maxtext_key}"] = f"model.layers.{i}.{hf_key}" + + for maxtext_key, hf_key in moe_expert_keys.items(): + mapping[f"params-decoder-moe_layers_{moe_layer_idx}-{maxtext_key}"] = [ + f"model.layers.{i}.mlp.experts.{e}.{hf_key}" for e in range(num_experts) + ] return mapping def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False): """Creates parameter transformation functions for Deepseek.""" - # TODO(shuningjin): support hf->orbax(scan), b/457820372 - if not saving_to_hf: - raise NotImplementedError("This conversion only supports saving_to_hf") - # TODO(shuningjin): add unscan support, b/457820735 - if not scan_layers: - raise NotImplementedError("This conversion only supports scanned MaxText models.") def reshape_kernel(input_tensor, target_shape): """Reshapes and transposes kernel weights between MaxText and HF.""" @@ -1164,39 +1179,60 @@ def reshape_kernel(input_tensor, target_shape): else: return input_tensor.T.reshape(target_shape) + num_main_layers = config["num_hidden_layers"] + first_num_dense_layers = config["first_k_dense_replace"] + mapping = { "params-decoder-logits_dense-kernel": reshape_kernel, } - # all keys that need the reshape hook - params_need_reshape = { - # Dense Layers - "params-decoder-dense_layers-self_attention-query-kernel", - "params-decoder-dense_layers-self_attention-wq_a-kernel", - "params-decoder-dense_layers-self_attention-wq_b-kernel", - "params-decoder-dense_layers-self_attention-wkv_a-kernel", - "params-decoder-dense_layers-self_attention-wkv_b-kernel", - "params-decoder-dense_layers-self_attention-out-kernel", - 
"params-decoder-dense_layers-mlp-wi_0-kernel", - "params-decoder-dense_layers-mlp-wi_1-kernel", - "params-decoder-dense_layers-mlp-wo-kernel", - # MoE Layers - "params-decoder-moe_layers-self_attention-query-kernel", - "params-decoder-moe_layers-self_attention-wq_a-kernel", - "params-decoder-moe_layers-self_attention-wq_b-kernel", - "params-decoder-moe_layers-self_attention-wkv_a-kernel", - "params-decoder-moe_layers-self_attention-wkv_b-kernel", - "params-decoder-moe_layers-self_attention-out-kernel", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wi_1-kernel", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-shared_experts-wo-kernel", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-MoeBlock_0-wi_0", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-MoeBlock_0-wi_1", - "params-decoder-moe_layers-DeepSeekMoeBlock_0-MoeBlock_0-wo", + + attention_need_reshape = { + "self_attention-wkv_a-kernel", # transpose + "self_attention-wkv_b-kernel", + "self_attention-out-kernel", + # v2 + "self_attention-query-kernel", + # v3 + "self_attention-wq_a-kernel", # transpose + "self_attention-wq_b-kernel", + # v3.2 + "self_attention-indexer-weights_proj-kernel", # transpose + "self_attention-indexer-wk-kernel", # transpose + "self_attention-indexer-wq_b-kernel", + } + + dense_need_reshape = attention_need_reshape | { + "mlp-wi_0-kernel", # transpose + "mlp-wi_1-kernel", # transpose + "mlp-wo-kernel", # transpose + } + + moe_need_reshape = attention_need_reshape | { + "DeepSeekMoeBlock_0-shared_experts-wi_0-kernel", # transpose + "DeepSeekMoeBlock_0-shared_experts-wi_1-kernel", # transpose + "DeepSeekMoeBlock_0-shared_experts-wo-kernel", # transpose + "DeepSeekMoeBlock_0-MoeBlock_0-gate-kernel", # transpose + "DeepSeekMoeBlock_0-MoeBlock_0-wi_0", # transpose + "DeepSeekMoeBlock_0-MoeBlock_0-wi_1", # transpose + 
"DeepSeekMoeBlock_0-MoeBlock_0-wo", # transpose } - for key in params_need_reshape: - mapping[key] = reshape_kernel + # scan + if scan_layers: + for key in dense_need_reshape: + mapping[f"params-decoder-dense_layers-{key}"] = reshape_kernel + for key in moe_need_reshape: + mapping[f"params-decoder-moe_layers-{key}"] = reshape_kernel + # unscan + else: + for i in range(first_num_dense_layers): + for key in dense_need_reshape: + mapping[f"params-decoder-dense_layers_{i}-{key}"] = reshape_kernel + for i in range(first_num_dense_layers, num_main_layers): + moe_layer_idx = i - first_num_dense_layers + for key in moe_need_reshape: + mapping[f"params-decoder-moe_layers_{moe_layer_idx}-{key}"] = reshape_kernel + return mapping @@ -1747,9 +1783,9 @@ def reshape_audio_attn_qkv(input_tensor, target_shape): def reshape_audio_attn_out(input_tensor, target_shape): """Reshape audio attention output projection. - F - HF: (hidden_size, hidden_size) - MaxText: (num_heads, head_dim, hidden_size) + + HF: (hidden_size, hidden_size) + MaxText: (num_heads, head_dim, hidden_size) """ if saving_to_hf: # MaxText -> HF: (num_heads, head_dim, hidden_size) -> (hidden_size, hidden_size) @@ -2379,7 +2415,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-coder-480b-a35b": QWEN_MAXTEXT_TO_HF_PARAM_MAPPING, + "deepseek2-16b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, + "deepseek3.2-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING, "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING, @@ -2419,7 +2457,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape): "qwen3-30b-a3b-base": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-235b-a22b": QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, "qwen3-coder-480b-a35b": 
QWEN_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "deepseek2-16b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, "deepseek3-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, + "deepseek3.2-671b": DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN, "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN, "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN, diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index bb40812ea3..69f6741e26 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -24,29 +24,28 @@ from typing import Any from tqdm import tqdm import resource +import numpy as np +import psutil +import pathlib +from etils import epath import jax from jax.experimental import multihost_utils - from jaxtyping import Array -import numpy as np - from google.cloud.storage import Client, transfer_manager +from safetensors import safe_open from safetensors.numpy import save_file as numpy_save_file from safetensors.numpy import save as numpy_save from safetensors.flax import save as save_flax_to_bytes -from huggingface_hub import HfApi, repo_exists +from huggingface_hub import HfApi, repo_exists, snapshot_download from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES from transformers import AutoModelForCausalLM from maxtext.utils import max_logging -import psutil - -from etils import epath import orbax.checkpoint as ocp @@ -165,7 +164,7 @@ def convert_jax_weight_to_numpy(weight: "jax.Array", dtype_str: None | str = Non return np_array.reshape(expected_shape) # Reshape for safety, though usually preserved. 
-def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shape_map): +def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shape_map, save_dtype): """Applies hooks, converts a JAX slice to NumPy, and appends it to the output list, used in to_huggingface""" if hf_path not in hf_shape_map: raise ValueError(f"HF path '{hf_path}' not found in hf_shape_map.") @@ -173,7 +172,7 @@ def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shap # If hook is unsepecified, use identity if current_hook_fns: processed_slice = apply_hook_fns(processed_slice, target_hf_shape, current_hook_fns) - numpy_slice = convert_jax_weight_to_numpy(processed_slice).squeeze() + numpy_slice = convert_jax_weight_to_numpy(processed_slice, save_dtype).squeeze() if numpy_slice.shape != tuple(target_hf_shape): raise ValueError(f"Shape mismatch for {hf_path}: Expect {target_hf_shape}, got {numpy_slice.shape}") output_weights.append((hf_path, numpy_slice)) @@ -236,7 +235,14 @@ def process_maxtext_param( if not isinstance(hf_target_paths, list): max_logging.log("\tunscan") hf_path = hf_target_paths - _process(hf_path, maxtext_param_weight, output_weights, current_hook_fns, hf_shape_map) + _process( + hf_path, + maxtext_param_weight, + output_weights, + current_hook_fns, + hf_shape_map, + save_dtype=maxtext_config.weight_dtype, + ) return output_weights # Stacked MaxText weight @@ -270,7 +276,14 @@ def process_maxtext_param( else: # For `atomic_mt_key` mappings, slice the single MaxText tensor. 
weight_slice = jax.lax.index_in_dim(maxtext_param_weight, i, axis=axis_to_slice, keepdims=False) - _process(hf_path, weight_slice, output_weights, current_hook_fns, hf_shape_map) + _process( + hf_path, + weight_slice, + output_weights, + current_hook_fns, + hf_shape_map, + save_dtype=maxtext_config.weight_dtype, + ) return output_weights @@ -292,7 +305,14 @@ def process_maxtext_param( # Slice the expert tensor along the layer axis to get the final individual weight. # axis is 0 on the new sliced tensor layer_tensor_slice = jax.lax.index_in_dim(expert_tensor_slice, layer_idx, axis=0, keepdims=False) - _process(hf_path, layer_tensor_slice, output_weights, current_hook_fns, hf_shape_map) + _process( + hf_path, + layer_tensor_slice, + output_weights, + current_hook_fns, + hf_shape_map, + save_dtype=maxtext_config.weight_dtype, + ) return output_weights @@ -903,7 +923,7 @@ def detect_and_extract_checkpoint(checkpoint_dict: dict) -> dict[str, np.ndarray return extract_linen_weights(actual_weights_dict) -def get_hf_model(model_id: str, token: str, revision: str = None): +def load_hf_dict_from_transformers(model_id: str, token: str, revision: str = None, dtype: str = "auto"): """Loads the HuggingFace model based on model_id (Eager mode only), used in to_maxtext""" if model_id in ["Qwen/Qwen3-Omni-30B-A3B-Instruct"]: from transformers import Qwen3OmniMoeForConditionalGeneration # pylint: disable=import-outside-toplevel @@ -912,5 +932,35 @@ def get_hf_model(model_id: str, token: str, revision: str = None): else: model_class = AutoModelForCausalLM - hf_model = model_class.from_pretrained(model_id, token=token, revision=revision) - return hf_model + # Note: transformers deprecates `torch_dtype` in favor of standard `dtype` in model loading + hf_model = model_class.from_pretrained(model_id, token=token, revision=revision, dtype=dtype) + + return hf_model.state_dict() + + +def load_hf_dict_from_safetensors(model_id_or_path, token, revision, framework="pt"): + """ + If the 
safetensor contains more HF keys than MaxText model, + these HF keys will be loaded but ignored during conversion. + For example, if maxtext has deepseek3 with mtp=false, + then safetensor weight with prefix `model.layers.61` will be the extra keys. + """ + # Determine if the path is local or remote + if os.path.isdir(model_id_or_path): + local_path = model_id_or_path + else: + # Download only the safetensors files to the local HF cache + local_path = snapshot_download( + repo_id=model_id_or_path, + token=token, + revision=revision, + ) + # load safetensors + ckpt_paths = sorted(pathlib.Path(local_path).glob("[!.]*.safetensors")) + hf_state_dict = {} + max_logging.log(f"Loading {len(ckpt_paths)} checkpoints") + for ckpt_path in tqdm(ckpt_paths, total=len(ckpt_paths)): + with safe_open(ckpt_path, framework=framework, device="cpu") as f: + for key in f.keys(): + hf_state_dict[key] = f.get_tensor(key) + return hf_state_dict diff --git a/src/maxtext/utils/globals.py b/src/maxtext/utils/globals.py index 203d7a6165..e787cb7176 100644 --- a/src/maxtext/utils/globals.py +++ b/src/maxtext/utils/globals.py @@ -66,7 +66,9 @@ "qwen3-30b-a3b": "Qwen/Qwen3-30B-A3B-Thinking-2507", "qwen3-235b-a22b": "Qwen/Qwen3-235B-A22B-Thinking-2507", "qwen3-480b-a35b": "Qwen/Qwen3-Coder-480B-A35B-Instruct", + "deepseek2-16b": "deepseek-ai/DeepSeek-V2-Lite", "deepseek3-671b": "deepseek-ai/DeepSeek-V3", + "deepseek3.2-671b": "deepseek-ai/DeepSeek-V3.2", "gpt-oss-20b": "openai/gpt-oss-20b", "gpt-oss-120b": "openai/gpt-oss-120b", "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct", diff --git a/src/maxtext/utils/muon_utils.py b/src/maxtext/utils/muon_utils.py index 5435633905..f50acd269f 100644 --- a/src/maxtext/utils/muon_utils.py +++ b/src/maxtext/utils/muon_utils.py @@ -168,7 +168,7 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False): if __name__ == "__main__": if len(sys.argv) != 3: - print("Usage: python3 -m MaxText.muon_utils ") + print("Usage: python3 -m 
maxtext.utils.muon_utils ") sys.exit(1) model_name_arg = sys.argv[1] scan_layers_arg = sys.argv[2].lower() == "true" diff --git a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh index 9f5d4c320b..b0bf20480a 100644 --- a/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v2-16b/test_deepseek.sh @@ -33,7 +33,7 @@ BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} # Step 1: Checkpoint conversion -# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite, and dequantize it to bf16 +# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite (bf16) # Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET # Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own # Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory diff --git a/tests/end_to_end/tpu/deepseek/v3-671b/1_test_deepseek.sh b/tests/end_to_end/tpu/deepseek/v3-671b/1_test_deepseek.sh index 987157aa51..660d221338 100644 --- a/tests/end_to_end/tpu/deepseek/v3-671b/1_test_deepseek.sh +++ b/tests/end_to_end/tpu/deepseek/v3-671b/1_test_deepseek.sh @@ -25,7 +25,7 @@ BASE_OUTPUT_PATH=${BASE_OUTPUT_PATH%/} echo using BASE_OUTPUT_PATH = ${BASE_OUTPUT_PATH} # Step 1: Checkpoint conversion -# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3, and dequantize it to bf16 +# You can use the HuggingFace checkpoint at https://huggingface.co/deepseek-ai/DeepSeek-V3 (fp8), and dequantize it to bf16 # Assume HF checkpoints are uploaded to GCS bucket at CKPT_BUCKET # Non-Googlers please remember to point `CKPT_BUCKET` to GCS buckets that you own # Copying the HF checkpoint into a local directory `/tmp` -- you are free to use a different directory diff --git a/tests/utils/forward_pass_logit_checker.py 
b/tests/utils/forward_pass_logit_checker.py index b7a31acae6..1b59daf00e 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -166,7 +166,7 @@ def check_kl_divergence(model_logits, golden_logits, atol=0.02): log_target=False, ) - max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.6f}") + max_logging.log(f"\nAverage KL divergence per token (D_KL(P_golden || Q_model)): {kl_div_value.item():.4e}") # To find the max KL divergence for any single token in the set # use reduction='none'. @@ -177,13 +177,13 @@ def check_kl_divergence(model_logits, golden_logits, atol=0.02): ) # Sum over the vocab dim to get a single KL value per token # Log per-token KL divergences - formatted_list = [f"{x:.6f}" for x in kl_divs_per_token.tolist()] + formatted_list = [f"{x:.4e}" for x in kl_divs_per_token.tolist()] max_logging.log(f"Per-token KL Divergences: \n{formatted_list}") max_kl_div = kl_divs_per_token.max() - max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.6f}") + max_logging.log(f"\nMax KL divergence for a single token in the set: {max_kl_div.item():.4e}") - assert max_kl_div < atol, f"KL divergence values {max_kl_div.item():.6f} exceed the threshold {atol}" + assert max_kl_div < atol, f"KL divergence values {max_kl_div.item():.4e} exceed the threshold {atol}" def get_data(golden_data_point, config): @@ -324,10 +324,10 @@ def main(config, test_args): # pylint: disable=W0621 max_rel_diff_val = rel_diff[max_rel_diff_idx] msg = ( "\n[numerical difference]\n" - f"Max absolute difference: {max_abs_diff_val:.6f} at index {max_abs_diff_idx}\n" - f" (Train: {train_logits_slice[max_abs_diff_idx]:.6f}, Golden: {golden_logits_slice[max_abs_diff_idx]:.6f})\n" - f"Max relative difference: {max_rel_diff_val:.6f} at index {max_rel_diff_idx}\n" - f" (Train: {train_logits_slice[max_rel_diff_idx]:.6f}, Golden: {golden_logits_slice[max_rel_diff_idx]:.6f})" 
+ f"Max absolute difference: {max_abs_diff_val:.4e} at index {max_abs_diff_idx}\n" + f" (Train: {train_logits_slice[max_abs_diff_idx]:.4e}, Golden: {golden_logits_slice[max_abs_diff_idx]:.4e})\n" + f"Max relative difference: {max_rel_diff_val:.4e} at index {max_rel_diff_idx}\n" + f" (Train: {train_logits_slice[max_rel_diff_idx]:.4e}, Golden: {golden_logits_slice[max_rel_diff_idx]:.4e})" ) max_logging.log(msg) @@ -532,25 +532,11 @@ def main(config, test_args): # pylint: disable=W0621 default=False, help="Skip the first token during comparison to ignore BOS/init mismatches.", ) - test_args, _ = parser.parse_known_args() - - # Remove args defined in this test file to avoid error from pyconfig - model_args = sys.argv - to_remove_args = [ - "--atol", - "--rtol", - "--token_size", - "--max_kl_div", - "--golden_logits_path", - "--hf_model_path", - "--run_hf_model", - "--output_logits_path", - "--gcs_output_logits_path", - "--clip_logits_epsilon", - "--skip_first_token", - ] - for arg in to_remove_args: - model_args = [s for s in model_args if not s.startswith(arg)] + + # Parse known args returns the namespace AND the list of remaining arguments + test_args, remaining_args = parser.parse_known_args() + # Reconstruct model_args (script name + the args MaxText needs) + model_args = [sys.argv[0]] + remaining_args cfg = pyconfig.initialize(model_args) assert (