30 changes: 28 additions & 2 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
@@ -121,10 +121,25 @@ def _get_model_mappings(
if model_name not in PARAM_MAPPING or model_name not in HF_SHAPE or model_name not in HOOK_FNS:
raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}")

param_mapping = PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers)
hook_fn_mapping = HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True)

# Promote composite hook keys into param_mapping.
# If HOOK_FNS defines a composite tuple key (e.g., (wi_0, wi_1) for MoE gate_up_proj),
# replace the individual entries in param_mapping with one composite entry so
# process_maxtext_param receives both tensors together and passes them to the hook.
for hook_key in list(hook_fn_mapping.keys()):
if isinstance(hook_key, tuple):
hf_path = param_mapping.get(hook_key[0])
if hf_path is not None:
param_mapping[hook_key] = hf_path
for k in hook_key:
param_mapping.pop(k, None)

return {
"param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers),
"param_mapping": param_mapping,
"shape_mapping": HF_SHAPE[model_name](hf_config_dict),
"hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True),
"hook_fn_mapping": hook_fn_mapping,
}
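For context, a rough sketch of how the downstream conversion loop (process_maxtext_param and the code around it) is assumed to consume such a composite key; the loop shape and the hf_shape_map / output_weights names are illustrative, not the actual implementation:

# Illustrative only: how a tuple key promoted above would flow through the loop.
for map_key in filtered_map_keys:
    hf_path = param_map[map_key]
    hook = hook_fn_map.get(map_key)
    if isinstance(map_key, tuple):
        # N-to-1 mapping: collect every MaxText tensor named in the tuple and
        # hand the list to the composite hook (e.g. moe_gate_up_hook).
        weight = [maxtext_state_dict[k] for k in map_key]
    else:
        weight = maxtext_state_dict[map_key]
    if hook is not None:
        weight = hook(weight, hf_shape_map[hf_path])  # hf_shape_map: assumed name
    output_weights.append((hf_path, weight))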


@@ -289,6 +304,17 @@ def main(argv: Sequence[str]) -> None:
# Validate that checkpoint keys match the parameter mapping
filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys())

# When not converting a multimodal model, skip vision encoder weights even if
# they are present in the checkpoint (e.g. when converting a text-only model
# from a multimodal checkpoint).
if not config.use_multimodal:
filtered_map_keys = [
k
for k in filtered_map_keys
if not (isinstance(k, str) and "vision_encoder" in k)
and not (isinstance(k, tuple) and any("vision_encoder" in sub for sub in k))
]

# Iterate through the parameter map to transform and collect weights.
# This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings
# (where multiple MaxText weights are combined into a single HF weight).
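A small illustration of the vision_encoder filter above, with made-up key names:

# Hypothetical keys, for illustration only.
example_keys = [
    "params-decoder-layers_0-self_attention-query-kernel",
    "params-vision_encoder-encoder-layers_0-mlp-wi-kernel",
    ("params-vision_encoder-pos_emb", "params-vision_encoder-patch_proj-kernel"),
]
# With use_multimodal=False only the first key survives: plain strings containing
# "vision_encoder" are dropped, as are tuples where any member contains it.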
@@ -47,7 +47,7 @@
"dtype": "bfloat16",
"enable_moe_block": True,
"eos_token_id": 1,
"expert_intermediate_size": 704,
"moe_intermediate_size": 704,
"final_logit_softcapping": 30.0,
"global_head_dim": 512,
"head_dim": 256,
@@ -123,7 +123,6 @@
gemma4_31b_dict["text_config"].update(
{
"enable_moe_block": False,
"expert_intermediate_size": None,
"hidden_size": 5376,
"intermediate_size": 21504,
"layer_types": [
133 changes: 133 additions & 0 deletions src/maxtext/checkpoint_conversion/utils/hf_shape.py
@@ -153,6 +153,137 @@ def GEMMA3_HF_WEIGHTS_TO_SHAPE(config):
return shapes


def GEMMA4_HF_WEIGHTS_TO_SHAPE(config):
"""Generates shape mapping for Hugging Face Gemma4 parameters.

Handles both multimodal (with vision tower) and text-only variants, as well
as MoE (26B) and dense (31B) text configurations. Shapes are per-layer aware:
local (sliding) attention layers use head_dim, while global (full) attention
layers use global_head_dim and num_global_key_value_heads.

Args:
config (dict): The Hugging Face model configuration dictionary. Must contain
'text_config' with architectural details. May contain 'vision_config' for
multimodal models.

Returns:
dict: A dictionary mapping Hugging Face parameter names to their shapes.
"""
shapes = {}

text_cfg = config.get("text_config", config)
vision_cfg = config.get("vision_config", {})
# text_base matches GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING logic
text_base = "model.language_model" if vision_cfg else "model"

hidden_size = text_cfg["hidden_size"]
intermediate_size = text_cfg["intermediate_size"]
num_hidden_layers = text_cfg["num_hidden_layers"]
num_attention_heads = text_cfg["num_attention_heads"]
num_key_value_heads = text_cfg["num_key_value_heads"]
num_global_key_value_heads = text_cfg.get("num_global_key_value_heads", num_key_value_heads)
head_dim = text_cfg["head_dim"]
global_head_dim = text_cfg.get("global_head_dim", head_dim)
vocab_size = text_cfg["vocab_size"]

num_experts = text_cfg.get("num_experts")
num_experts = num_experts if num_experts is not None else 1
# "moe_intermediate_size" is the canonical key in Gemma4 config; fall back to "expert_intermediate_size"
expert_intermediate_size = text_cfg.get("moe_intermediate_size") or text_cfg.get("expert_intermediate_size")

shapes[f"{text_base}.embed_tokens.weight"] = [vocab_size, hidden_size]
shapes[f"{text_base}.norm.weight"] = [hidden_size]

for i in range(num_hidden_layers):
hf_prefix = f"{text_base}.layers.{i}"
is_global = (i % 6) == 5

if is_global:
q_dim = num_attention_heads * global_head_dim
kv_dim = num_global_key_value_heads * global_head_dim
norm_dim = global_head_dim
else:
q_dim = num_attention_heads * head_dim
kv_dim = num_key_value_heads * head_dim
norm_dim = head_dim

shapes[f"{hf_prefix}.self_attn.q_proj.weight"] = [q_dim, hidden_size]
shapes[f"{hf_prefix}.self_attn.k_proj.weight"] = [kv_dim, hidden_size]
shapes[f"{hf_prefix}.self_attn.v_proj.weight"] = [kv_dim, hidden_size]
shapes[f"{hf_prefix}.self_attn.o_proj.weight"] = [hidden_size, q_dim]
shapes[f"{hf_prefix}.self_attn.q_norm.weight"] = [norm_dim]
shapes[f"{hf_prefix}.self_attn.k_norm.weight"] = [norm_dim]
# v_norm is conditional on maxtext_config.v_norm_with_scale; included here for completeness
shapes[f"{hf_prefix}.self_attn.v_norm.weight"] = [norm_dim]

shapes[f"{hf_prefix}.input_layernorm.weight"] = [hidden_size]
shapes[f"{hf_prefix}.post_attention_layernorm.weight"] = [hidden_size]
shapes[f"{hf_prefix}.pre_feedforward_layernorm.weight"] = [hidden_size]
shapes[f"{hf_prefix}.post_feedforward_layernorm.weight"] = [hidden_size]
shapes[f"{hf_prefix}.layer_scalar"] = [1]

if num_experts > 1:
shapes[f"{hf_prefix}.pre_feedforward_layernorm_2.weight"] = [hidden_size]
shapes[f"{hf_prefix}.post_feedforward_layernorm_1.weight"] = [hidden_size]
shapes[f"{hf_prefix}.post_feedforward_layernorm_2.weight"] = [hidden_size]
# router.scale has shape [hidden_size] (pre_forward_scale_2 in MaxText)
shapes[f"{hf_prefix}.router.scale"] = [hidden_size]
shapes[f"{hf_prefix}.router.proj.weight"] = [num_experts, hidden_size]
shapes[f"{hf_prefix}.router.per_expert_scale"] = [num_experts]
# Routed experts fused: gate_up [E, 2*FF, H], down [E, H, FF]
shapes[f"{hf_prefix}.experts.gate_up_proj"] = [num_experts, 2 * expert_intermediate_size, hidden_size]
shapes[f"{hf_prefix}.experts.down_proj"] = [num_experts, hidden_size, expert_intermediate_size]
# Shared expert dense MLP
shapes[f"{hf_prefix}.mlp.gate_proj.weight"] = [intermediate_size, hidden_size]
shapes[f"{hf_prefix}.mlp.up_proj.weight"] = [intermediate_size, hidden_size]
shapes[f"{hf_prefix}.mlp.down_proj.weight"] = [hidden_size, intermediate_size]
else:
shapes[f"{hf_prefix}.mlp.gate_proj.weight"] = [intermediate_size, hidden_size]
shapes[f"{hf_prefix}.mlp.up_proj.weight"] = [intermediate_size, hidden_size]
shapes[f"{hf_prefix}.mlp.down_proj.weight"] = [hidden_size, intermediate_size]

if vision_cfg:
vis_hidden = vision_cfg["hidden_size"]
vis_intermediate = vision_cfg["intermediate_size"]
vis_num_layers = vision_cfg["num_hidden_layers"]
vis_num_heads = vision_cfg["num_attention_heads"]
vis_head_dim = vision_cfg["head_dim"]
vis_q_dim = vis_num_heads * vis_head_dim
vis_kv_heads = vision_cfg.get("num_key_value_heads", vis_num_heads)
vis_kv_dim = vis_kv_heads * vis_head_dim
vis_pos_emb_size = vision_cfg.get("position_embedding_size", 10240)
vis_patch_size = vision_cfg.get("patch_size", 16)
num_channels = 3 # RGB
patch_flat = num_channels * vis_patch_size * vis_patch_size

# VisionEntry: input_proj is a linear [patch_flat, vis_hidden] transposed to [vis_hidden, patch_flat]
shapes["model.vision_tower.patch_embedder.input_proj.weight"] = [vis_hidden, patch_flat]
# pos_emb_param MaxText shape (N, 2, D) -> transpose(1,0,2) -> HF (2, N, D)
shapes["model.vision_tower.patch_embedder.position_embedding_table"] = [2, vis_pos_emb_size, vis_hidden]
shapes["model.vision_tower.std_scale"] = [vis_hidden]
shapes["model.vision_tower.std_bias"] = [vis_hidden]
# Vision projector: [vis_hidden, hidden_size] -> reshape_kernel -> [hidden_size, vis_hidden]
shapes["model.embed_vision.embedding_projection.weight"] = [hidden_size, vis_hidden]

for i in range(vis_num_layers):
vis_prefix = f"model.vision_tower.encoder.layers.{i}"
shapes[f"{vis_prefix}.self_attn.q_proj.linear.weight"] = [vis_q_dim, vis_hidden]
shapes[f"{vis_prefix}.self_attn.k_proj.linear.weight"] = [vis_kv_dim, vis_hidden]
shapes[f"{vis_prefix}.self_attn.v_proj.linear.weight"] = [vis_kv_dim, vis_hidden]
shapes[f"{vis_prefix}.self_attn.o_proj.linear.weight"] = [vis_hidden, vis_q_dim]
shapes[f"{vis_prefix}.self_attn.q_norm.weight"] = [vis_head_dim]
shapes[f"{vis_prefix}.self_attn.k_norm.weight"] = [vis_head_dim]
shapes[f"{vis_prefix}.input_layernorm.weight"] = [vis_hidden]
shapes[f"{vis_prefix}.post_attention_layernorm.weight"] = [vis_hidden]
shapes[f"{vis_prefix}.pre_feedforward_layernorm.weight"] = [vis_hidden]
shapes[f"{vis_prefix}.post_feedforward_layernorm.weight"] = [vis_hidden]
shapes[f"{vis_prefix}.mlp.gate_proj.linear.weight"] = [vis_intermediate, vis_hidden]
shapes[f"{vis_prefix}.mlp.up_proj.linear.weight"] = [vis_intermediate, vis_hidden]
shapes[f"{vis_prefix}.mlp.down_proj.linear.weight"] = [vis_hidden, vis_intermediate]

return shapes
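
A minimal usage sketch of the new shape map; the config below is a toy placeholder, not a real Gemma4 configuration:

toy_cfg = {
    "text_config": {
        "hidden_size": 64,
        "intermediate_size": 256,
        "num_hidden_layers": 6,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "head_dim": 16,
        "vocab_size": 1000,
        "num_experts": 1,  # dense branch, no MoE entries
    }
}
shapes = GEMMA4_HF_WEIGHTS_TO_SHAPE(toy_cfg)
# No vision_config, so text_base == "model".
assert shapes["model.embed_tokens.weight"] == [1000, 64]
assert shapes["model.layers.0.self_attn.q_proj.weight"] == [4 * 16, 64]
# Layer 5 is global (i % 6 == 5); without global_* overrides it falls back to head_dim.
assert shapes["model.layers.5.self_attn.k_proj.weight"] == [2 * 16, 64]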


def GEMMA2_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace weights path and weights shape.

@@ -787,6 +918,8 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config):
"gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE,
"gemma4-26b": GEMMA4_HF_WEIGHTS_TO_SHAPE,
"gemma4-31b": GEMMA4_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-1.5b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE,
"qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE,
39 changes: 31 additions & 8 deletions src/maxtext/checkpoint_conversion/utils/param_mapping.py
@@ -2240,7 +2240,7 @@ def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False
"params-token_embedder-embedding": f"{text_base}.embed_tokens.weight",
"params-decoder-decoder_norm-scale": f"{text_base}.norm.weight",
}
if maxtext_config.use_multimodal and vcfg:
if vcfg:
nvis = vcfg.get("num_hidden_layers", 0)
mapping.update(
{
@@ -2577,6 +2577,14 @@ def reshape_moe_wo(input_tensor, target_shape):
# input_tensor: [E, H, FF], target: [E, FF, H]
return input_tensor.transpose(0, 2, 1)

def moe_gate_up_hook(weight_list, target_shape):
# Inverse of split_moe_wi0/wi1: fuse MaxText wi_0, wi_1 → HF experts.gate_up_proj.
# weight_list: [wi_0, wi_1], each [..., H, FF]
# Returns: [..., 2*FF, H]
wi_0 = jnp.asarray(weight_list[0])
wi_1 = jnp.asarray(weight_list[1])
return jnp.swapaxes(jnp.concatenate([wi_0, wi_1], axis=-1), -2, -1)

hooks["params-token_embedder-embedding"] = pad_hf_embedding_layer
hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm_layer
# REMOVED: logits_dense-kernel hook (handled by logits_via_embedding: True)
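A quick shape check of the fusion, with arbitrary sizes, assuming the inner helper were lifted out (or exercised through the hook mapping) for a standalone test:

import jax.numpy as jnp
E, H, FF = 4, 8, 16                    # experts, hidden, per-expert FF (arbitrary)
wi_0 = jnp.zeros((E, H, FF))           # MaxText gating projection
wi_1 = jnp.zeros((E, H, FF))           # MaxText up projection
fused = moe_gate_up_hook([wi_0, wi_1], target_shape=None)
assert fused.shape == (E, 2 * FF, H)   # matches experts.gate_up_proj [E, 2*FF, H]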
@@ -2631,7 +2639,7 @@ def reshape_moe_wo(input_tensor, target_shape):
# of norm_keys means they perfectly default to the identity mapping.

vcfg = config.get("vision_config", {})
if maxtext_config.use_multimodal and vcfg:
if vcfg:
nvis = vcfg.get("num_hidden_layers", 0)

def reshape_vision_patch(x, target_shape):
@@ -2680,8 +2688,13 @@ def reshape_pos_emb(x, target_shape):
hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer

# Add these specialized 3D tensor hooks inside the loop
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
if saving_to_hf and num_experts > 1:
wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"
wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"
hooks[(wi0_key, wi1_key)] = moe_gate_up_hook
else:
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo

# Remainder sub-layer prefixes
@@ -2700,8 +2713,13 @@ def reshape_pos_emb(x, target_shape):
hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer

# Add these specialized 3D tensor hooks inside the loop
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
if saving_to_hf and num_experts > 1:
wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"
wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"
hooks[(wi0_key, wi1_key)] = moe_gate_up_hook
else:
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo
else:
for i in range(nlayers):
@@ -2717,8 +2735,13 @@ def reshape_pos_emb(x, target_shape):
hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer

# Add these specialized 3D tensor hooks inside the loop
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
if saving_to_hf and num_experts > 1:
wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"
wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"
hooks[(wi0_key, wi1_key)] = moe_gate_up_hook
else:
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1
hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo
return hooks

12 changes: 9 additions & 3 deletions src/maxtext/checkpoint_conversion/utils/utils.py
@@ -99,11 +99,19 @@ def validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys):
# 1 Validate: every maxtext state key must be covered by param map
missing_keys = maxtext_state_keys - flattened_map_keys
if missing_keys:
hint = ""
ckpt_has_scanned = any("scanned_blocks" in k for k in missing_keys)
map_has_scanned = any("scanned_blocks" in k for k in flattened_map_keys)
if ckpt_has_scanned and not map_has_scanned:
hint = "\nHint: checkpoint keys contain 'scanned_blocks' but param_map does not — try scan_layers=True."
elif map_has_scanned and not ckpt_has_scanned:
hint = "\nHint: param_map contains 'scanned_blocks' keys but checkpoint does not — try scan_layers=False."
raise ValueError(
"maxtext_state_dict must be a subset of flattened param_map"
+ f"\nparam map\n{param_map_keys}"
+ f"\nmaxtext:\n{maxtext_state_keys}"
+ f"\nmissing keys:\n{missing_keys}"
+ hint
)

# 2 Filter: param map may have extra keys
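
For example, the first hint fires on a mismatch like this (key names invented for illustration):

param_map_keys = {"params-decoder-layers_0-mlp-wi_0"}            # unscanned names
maxtext_state_keys = {"params-decoder-scanned_blocks-mlp-wi_0"}  # scanned checkpoint
validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys)
# missing_keys is non-empty and only the checkpoint side contains "scanned_blocks",
# so the raised ValueError ends with the "try scan_layers=True" hint.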
@@ -172,9 +180,7 @@ def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shap
# If hook is unspecified, use identity
if current_hook_fns:
processed_slice = apply_hook_fns(processed_slice, target_hf_shape, current_hook_fns)
numpy_slice = convert_jax_weight_to_numpy(processed_slice, save_dtype).squeeze()
if numpy_slice.shape != tuple(target_hf_shape):
raise ValueError(f"Shape mismatch for {hf_path}: Expect {target_hf_shape}, got {numpy_slice.shape}")
numpy_slice = convert_jax_weight_to_numpy(processed_slice, save_dtype).reshape(target_hf_shape)
output_weights.append((hf_path, numpy_slice))
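
The switch from squeeze-plus-strict-check to an explicit reshape lets hook outputs with a leading singleton axis (or a size-correct but differently laid out result) land in the exact HF shape; a small numpy sketch of the behavior, with illustrative shapes:

import numpy as np
processed = np.zeros((1, 32, 8), dtype=np.float32)  # e.g. a per-layer slice with a singleton layer axis
target_hf_shape = [32, 8]
numpy_slice = processed.reshape(target_hf_shape)    # reshape still raises if element counts differ
assert numpy_slice.shape == tuple(target_hf_shape)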


12 changes: 9 additions & 3 deletions tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh
@@ -4,7 +4,7 @@ set -ex
idx=$(date +%Y-%m-%d-%H-%M)

MODEL_NAME='gemma4-26b'
export MODEL_VARIATION='26b'
export MODEL_VARIATION='26b-it'
TOKENIZER_PATH='google/gemma-4-26b-a4b-it'
# To convert the multimodal model, make sure use_multimodal is set to true
USE_MULTIMODAL=false
@@ -30,7 +30,9 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MA

export MAXTEXT_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/${idx}/0/items


# Run forward pass logit checker to validate the converted checkpoint.
# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips,
# not for performance reasons.
if [ ${USE_MULTIMODAL} == true ]; then
# Set the shared multimodal prompt and image
TEST_PROMPT='Describe image <|image|>'
@@ -55,13 +57,13 @@ if [ ${USE_MULTIMODAL} == true ]; then
dtype=float32 \
matmul_precision=highest \
per_device_batch_size=1 \
vision_output_length=280 \
attention=dot_product \
prompt="${TEST_PROMPT}" \
image_path=${TEST_IMAGE} \
--max_kl_div=0.03 \
--golden_logits_path=${GOLDEN_LOGITS_PATH}
else
echo "=== Running MaxText Forward Pass Logit Checker ==="
python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \
tokenizer_path=${TOKENIZER_PATH} \
load_parameters_path=${MAXTEXT_CKPT_PATH} \
@@ -70,6 +72,10 @@ else
scan_layers=${USE_SCAN_LAYERS} \
per_device_batch_size=1 \
dtype=float32 \
wi_tile_fwd_embed_dim=512 \
wi_tile_fwd_mlp_dim=512 \
wo_tile_fwd_embed_dim=512 \
wo_tile_fwd_mlp_dim=512 \
--max_kl_div=0.03 \
--run_hf_model=true \
--hf_model_path=${HF_MODEL}