diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 55f14b440d..86d7eac812 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -121,10 +121,25 @@ def _get_model_mappings( if model_name not in PARAM_MAPPING or model_name not in HF_SHAPE or model_name not in HOOK_FNS: raise ValueError(f"Mappings not found for model: {model_name}. Available PARAM_MAPPING keys: {PARAM_MAPPING.keys()}") + param_mapping = PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers) + hook_fn_mapping = HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True) + + # Promote composite hook keys into param_mapping. + # If HOOK_FNS defines a composite tuple key (e.g., (wi_0, wi_1) for MoE gate_up_proj), + # replace the individual entries in param_mapping with one composite entry so + # process_maxtext_param receives both tensors together and passes them to the hook. + for hook_key in list(hook_fn_mapping.keys()): + if isinstance(hook_key, tuple): + hf_path = param_mapping.get(hook_key[0]) + if hf_path is not None: + param_mapping[hook_key] = hf_path + for k in hook_key: + param_mapping.pop(k, None) + return { - "param_mapping": PARAM_MAPPING[model_name](hf_config_dict, maxtext_config, scan_layers), + "param_mapping": param_mapping, "shape_mapping": HF_SHAPE[model_name](hf_config_dict), - "hook_fn_mapping": HOOK_FNS[model_name](hf_config_dict, maxtext_config, scan_layers, saving_to_hf=True), + "hook_fn_mapping": hook_fn_mapping, } @@ -289,6 +304,17 @@ def main(argv: Sequence[str]) -> None: # Validate that checkpoint keys match the parameter mapping filtered_map_keys = validate_and_filter_param_map_keys(param_map.keys(), maxtext_state_dict.keys()) + # When not converting a multimodal model, skip vision encoder weights even if + # they are present in the checkpoint (e.g. converting text-only from a + # multimodal checkpoint). + if not config.use_multimodal: + filtered_map_keys = [ + k + for k in filtered_map_keys + if not (isinstance(k, str) and "vision_encoder" in k) + and not (isinstance(k, tuple) and any("vision_encoder" in sub for sub in k)) + ] + # Iterate through the parameter map to transform and collect weights. # This loop handles both simple 1-to-1 mappings and complex N-to-1 mappings # (where multiple MaxText weights are combined into a single HF weight). 
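To make the composite-key promotion above concrete, here is a minimal standalone sketch of the same loop, run on toy mappings (the key names below are illustrative placeholders, not real MaxText or HF parameter paths):

    # Illustrative only: toy stand-ins for PARAM_MAPPING / HOOK_FNS output.
    param_mapping = {
        "layer0-wi_0": "model.layers.0.experts.gate_up_proj",
        "layer0-wi_1": "model.layers.0.experts.gate_up_proj",
        "layer0-wo": "model.layers.0.experts.down_proj",
    }
    hook_fn_mapping = {
        ("layer0-wi_0", "layer0-wi_1"): lambda ws, shape: ws,  # stand-in for a fused hook
        "layer0-wo": lambda w, shape: w,
    }

    # Same promotion logic as in _get_model_mappings above.
    for hook_key in list(hook_fn_mapping.keys()):
      if isinstance(hook_key, tuple):
        hf_path = param_mapping.get(hook_key[0])
        if hf_path is not None:
          param_mapping[hook_key] = hf_path   # one composite entry...
          for k in hook_key:
            param_mapping.pop(k, None)        # ...replaces the individual ones

    # param_mapping now maps ("layer0-wi_0", "layer0-wi_1") -> gate_up_proj and
    # leaves "layer0-wo" untouched, so process_maxtext_param receives both wi
    # tensors together and can pass them to the composite hook in one call.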
diff --git a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py index 7eda8abe0b..c78046bd56 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_model_configs.py @@ -47,7 +47,7 @@ "dtype": "bfloat16", "enable_moe_block": True, "eos_token_id": 1, - "expert_intermediate_size": 704, + "moe_intermediate_size": 704, "final_logit_softcapping": 30.0, "global_head_dim": 512, "head_dim": 256, @@ -123,7 +123,6 @@ gemma4_31b_dict["text_config"].update( { "enable_moe_block": False, - "expert_intermediate_size": None, "hidden_size": 5376, "intermediate_size": 21504, "layer_types": [ diff --git a/src/maxtext/checkpoint_conversion/utils/hf_shape.py b/src/maxtext/checkpoint_conversion/utils/hf_shape.py index 973855dbeb..563401e100 100644 --- a/src/maxtext/checkpoint_conversion/utils/hf_shape.py +++ b/src/maxtext/checkpoint_conversion/utils/hf_shape.py @@ -153,6 +153,137 @@ def GEMMA3_HF_WEIGHTS_TO_SHAPE(config): return shapes +def GEMMA4_HF_WEIGHTS_TO_SHAPE(config): + """Generates shape mapping for Hugging Face Gemma4 parameters. + + Handles both multimodal (with vision tower) and text-only variants, as well + as MoE (26B) and dense (31B) text configurations. Shapes are per-layer aware: + local (sliding) attention layers use head_dim, while global (full) attention + layers use global_head_dim and num_global_key_value_heads. + + Args: + config (dict): The Hugging Face model configuration dictionary. Must contain + 'text_config' with architectural details. May contain 'vision_config' for + multimodal models. + + Returns: + dict: A dictionary mapping Hugging Face parameter names to their shapes. + """ + shapes = {} + + text_cfg = config.get("text_config", config) + vision_cfg = config.get("vision_config", {}) + # text_base matches GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING logic + text_base = "model.language_model" if vision_cfg else "model" + + hidden_size = text_cfg["hidden_size"] + intermediate_size = text_cfg["intermediate_size"] + num_hidden_layers = text_cfg["num_hidden_layers"] + num_attention_heads = text_cfg["num_attention_heads"] + num_key_value_heads = text_cfg["num_key_value_heads"] + num_global_key_value_heads = text_cfg.get("num_global_key_value_heads", num_key_value_heads) + head_dim = text_cfg["head_dim"] + global_head_dim = text_cfg.get("global_head_dim", head_dim) + vocab_size = text_cfg["vocab_size"] + + num_experts = text_cfg.get("num_experts") + num_experts = num_experts if num_experts is not None else 1 + # "moe_intermediate_size" is the canonical key in Gemma4 config; fall back to "expert_intermediate_size" + expert_intermediate_size = text_cfg.get("moe_intermediate_size") or text_cfg.get("expert_intermediate_size") + + shapes[f"{text_base}.embed_tokens.weight"] = [vocab_size, hidden_size] + shapes[f"{text_base}.norm.weight"] = [hidden_size] + + for i in range(num_hidden_layers): + hf_prefix = f"{text_base}.layers.{i}" + is_global = (i % 6) == 5 + + if is_global: + q_dim = num_attention_heads * global_head_dim + kv_dim = num_global_key_value_heads * global_head_dim + norm_dim = global_head_dim + else: + q_dim = num_attention_heads * head_dim + kv_dim = num_key_value_heads * head_dim + norm_dim = head_dim + + shapes[f"{hf_prefix}.self_attn.q_proj.weight"] = [q_dim, hidden_size] + shapes[f"{hf_prefix}.self_attn.k_proj.weight"] = [kv_dim, hidden_size] + shapes[f"{hf_prefix}.self_attn.v_proj.weight"] = [kv_dim, hidden_size] + 
shapes[f"{hf_prefix}.self_attn.o_proj.weight"] = [hidden_size, q_dim] + shapes[f"{hf_prefix}.self_attn.q_norm.weight"] = [norm_dim] + shapes[f"{hf_prefix}.self_attn.k_norm.weight"] = [norm_dim] + # v_norm is conditional on maxtext_config.v_norm_with_scale; included here for completeness + shapes[f"{hf_prefix}.self_attn.v_norm.weight"] = [norm_dim] + + shapes[f"{hf_prefix}.input_layernorm.weight"] = [hidden_size] + shapes[f"{hf_prefix}.post_attention_layernorm.weight"] = [hidden_size] + shapes[f"{hf_prefix}.pre_feedforward_layernorm.weight"] = [hidden_size] + shapes[f"{hf_prefix}.post_feedforward_layernorm.weight"] = [hidden_size] + shapes[f"{hf_prefix}.layer_scalar"] = [1] + + if num_experts > 1: + shapes[f"{hf_prefix}.pre_feedforward_layernorm_2.weight"] = [hidden_size] + shapes[f"{hf_prefix}.post_feedforward_layernorm_1.weight"] = [hidden_size] + shapes[f"{hf_prefix}.post_feedforward_layernorm_2.weight"] = [hidden_size] + # router.scale has shape [hidden_size] (pre_forward_scale_2 in MaxText) + shapes[f"{hf_prefix}.router.scale"] = [hidden_size] + shapes[f"{hf_prefix}.router.proj.weight"] = [num_experts, hidden_size] + shapes[f"{hf_prefix}.router.per_expert_scale"] = [num_experts] + # Routed experts fused: gate_up [E, 2*FF, H], down [E, H, FF] + shapes[f"{hf_prefix}.experts.gate_up_proj"] = [num_experts, 2 * expert_intermediate_size, hidden_size] + shapes[f"{hf_prefix}.experts.down_proj"] = [num_experts, hidden_size, expert_intermediate_size] + # Shared expert dense MLP + shapes[f"{hf_prefix}.mlp.gate_proj.weight"] = [intermediate_size, hidden_size] + shapes[f"{hf_prefix}.mlp.up_proj.weight"] = [intermediate_size, hidden_size] + shapes[f"{hf_prefix}.mlp.down_proj.weight"] = [hidden_size, intermediate_size] + else: + shapes[f"{hf_prefix}.mlp.gate_proj.weight"] = [intermediate_size, hidden_size] + shapes[f"{hf_prefix}.mlp.up_proj.weight"] = [intermediate_size, hidden_size] + shapes[f"{hf_prefix}.mlp.down_proj.weight"] = [hidden_size, intermediate_size] + + if vision_cfg: + vis_hidden = vision_cfg["hidden_size"] + vis_intermediate = vision_cfg["intermediate_size"] + vis_num_layers = vision_cfg["num_hidden_layers"] + vis_num_heads = vision_cfg["num_attention_heads"] + vis_head_dim = vision_cfg["head_dim"] + vis_q_dim = vis_num_heads * vis_head_dim + vis_kv_heads = vision_cfg.get("num_key_value_heads", vis_num_heads) + vis_kv_dim = vis_kv_heads * vis_head_dim + vis_pos_emb_size = vision_cfg.get("position_embedding_size", 10240) + vis_patch_size = vision_cfg.get("patch_size", 16) + num_channels = 3 # RGB + patch_flat = num_channels * vis_patch_size * vis_patch_size + + # VisionEntry: input_proj is a linear [patch_flat, vis_hidden] transposed to [vis_hidden, patch_flat] + shapes["model.vision_tower.patch_embedder.input_proj.weight"] = [vis_hidden, patch_flat] + # pos_emb_param MaxText shape (N, 2, D) -> transpose(1,0,2) -> HF (2, N, D) + shapes["model.vision_tower.patch_embedder.position_embedding_table"] = [2, vis_pos_emb_size, vis_hidden] + shapes["model.vision_tower.std_scale"] = [vis_hidden] + shapes["model.vision_tower.std_bias"] = [vis_hidden] + # Vision projector: [vis_hidden, hidden_size] -> reshape_kernel -> [hidden_size, vis_hidden] + shapes["model.embed_vision.embedding_projection.weight"] = [hidden_size, vis_hidden] + + for i in range(vis_num_layers): + vis_prefix = f"model.vision_tower.encoder.layers.{i}" + shapes[f"{vis_prefix}.self_attn.q_proj.linear.weight"] = [vis_q_dim, vis_hidden] + shapes[f"{vis_prefix}.self_attn.k_proj.linear.weight"] = [vis_kv_dim, vis_hidden] + 
shapes[f"{vis_prefix}.self_attn.v_proj.linear.weight"] = [vis_kv_dim, vis_hidden] + shapes[f"{vis_prefix}.self_attn.o_proj.linear.weight"] = [vis_hidden, vis_q_dim] + shapes[f"{vis_prefix}.self_attn.q_norm.weight"] = [vis_head_dim] + shapes[f"{vis_prefix}.self_attn.k_norm.weight"] = [vis_head_dim] + shapes[f"{vis_prefix}.input_layernorm.weight"] = [vis_hidden] + shapes[f"{vis_prefix}.post_attention_layernorm.weight"] = [vis_hidden] + shapes[f"{vis_prefix}.pre_feedforward_layernorm.weight"] = [vis_hidden] + shapes[f"{vis_prefix}.post_feedforward_layernorm.weight"] = [vis_hidden] + shapes[f"{vis_prefix}.mlp.gate_proj.linear.weight"] = [vis_intermediate, vis_hidden] + shapes[f"{vis_prefix}.mlp.up_proj.linear.weight"] = [vis_intermediate, vis_hidden] + shapes[f"{vis_prefix}.mlp.down_proj.linear.weight"] = [vis_hidden, vis_intermediate] + + return shapes + + def GEMMA2_HF_WEIGHTS_TO_SHAPE(config): """Returns mapping between HuggingFace weights path and weights shape. @@ -787,6 +918,8 @@ def MIXTRAL_HF_WEIGHTS_TO_SHAPE(config): "gemma3-4b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-12b": GEMMA3_HF_WEIGHTS_TO_SHAPE, "gemma3-27b": GEMMA3_HF_WEIGHTS_TO_SHAPE, + "gemma4-26b": GEMMA4_HF_WEIGHTS_TO_SHAPE, + "gemma4-31b": GEMMA4_HF_WEIGHTS_TO_SHAPE, "qwen2.5-1.5b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen2.5-7b": QWEN_HF_WEIGHTS_TO_SHAPE, "qwen2.5-14b": QWEN_HF_WEIGHTS_TO_SHAPE, diff --git a/src/maxtext/checkpoint_conversion/utils/param_mapping.py b/src/maxtext/checkpoint_conversion/utils/param_mapping.py index a8b1635933..219b63823b 100644 --- a/src/maxtext/checkpoint_conversion/utils/param_mapping.py +++ b/src/maxtext/checkpoint_conversion/utils/param_mapping.py @@ -2240,7 +2240,7 @@ def GEMMA4_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False "params-token_embedder-embedding": f"{text_base}.embed_tokens.weight", "params-decoder-decoder_norm-scale": f"{text_base}.norm.weight", } - if maxtext_config.use_multimodal and vcfg: + if vcfg: nvis = vcfg.get("num_hidden_layers", 0) mapping.update( { @@ -2577,6 +2577,14 @@ def reshape_moe_wo(input_tensor, target_shape): # input_tensor: [E, H, FF], target: [E, FF, H] return input_tensor.transpose(0, 2, 1) + def moe_gate_up_hook(weight_list, target_shape): + # Inverse of split_moe_wi0/wi1: fuse MaxText wi_0, wi_1 → HF experts.gate_up_proj. + # weight_list: [wi_0, wi_1], each [..., H, FF] + # Returns: [..., 2*FF, H] + wi_0 = jnp.asarray(weight_list[0]) + wi_1 = jnp.asarray(weight_list[1]) + return jnp.swapaxes(jnp.concatenate([wi_0, wi_1], axis=-1), -2, -1) + hooks["params-token_embedder-embedding"] = pad_hf_embedding_layer hooks["params-decoder-decoder_norm-scale"] = scale_rmsnorm_layer # REMOVED: logits_dense-kernel hook (handled by logits_via_embedding: True) @@ -2631,7 +2639,7 @@ def reshape_moe_wo(input_tensor, target_shape): # of norm_keys means they perfectly default to the identity mapping. 
vcfg = config.get("vision_config", {}) - if maxtext_config.use_multimodal and vcfg: + if vcfg: nvis = vcfg.get("num_hidden_layers", 0) def reshape_vision_patch(x, target_shape): @@ -2680,8 +2688,13 @@ def reshape_pos_emb(x, target_shape): hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer # Add these specialized 3D tensor hooks inside the loop - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 + if saving_to_hf and num_experts > 1: + wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0" + wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1" + hooks[(wi0_key, wi1_key)] = moe_gate_up_hook + else: + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo # Remainder sub-layer prefixes @@ -2700,8 +2713,13 @@ def reshape_pos_emb(x, target_shape): hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer # Add these specialized 3D tensor hooks inside the loop - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 + if saving_to_hf and num_experts > 1: + wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0" + wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1" + hooks[(wi0_key, wi1_key)] = moe_gate_up_hook + else: + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo else: for i in range(nlayers): @@ -2717,8 +2735,13 @@ def reshape_pos_emb(x, target_shape): hooks[f"{prefix}-{key}"] = scale_rmsnorm_layer # Add these specialized 3D tensor hooks inside the loop - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 - hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 + if saving_to_hf and num_experts > 1: + wi0_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0" + wi1_key = f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1" + hooks[(wi0_key, wi1_key)] = moe_gate_up_hook + else: + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_0"] = split_moe_wi0 + hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wi_1"] = split_moe_wi1 hooks[f"{prefix}-mlp-moe_block-MoeBlock_0-wo"] = reshape_moe_wo return hooks diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index ec5c3e657f..26724fe863 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -99,11 +99,19 @@ def validate_and_filter_param_map_keys(param_map_keys, maxtext_state_keys): # 1 Validate: every maxtext state key must be covered by param map missing_keys = maxtext_state_keys - flattened_map_keys if missing_keys: + hint = "" + ckpt_has_scanned = any("scanned_blocks" in k for k in missing_keys) + map_has_scanned = any("scanned_blocks" in k for k in flattened_map_keys) + if ckpt_has_scanned and not map_has_scanned: + hint = "\nHint: checkpoint keys contain 'scanned_blocks' but param_map does not — try scan_layers=True." + elif map_has_scanned and not ckpt_has_scanned: + hint = "\nHint: param_map contains 'scanned_blocks' keys but checkpoint does not — try scan_layers=False." 
raise ValueError( "maxtext_state_dict must be a subset of flattened param_map" + f"\nparam map\n{param_map_keys}" + f"\nmaxtext:\n{maxtext_state_keys}" + f"\nmissing keys:\n{missing_keys}" + + hint ) # 2 Filter: param map may have extra keys @@ -172,9 +180,7 @@ def _process(hf_path, processed_slice, output_weights, current_hook_fns, hf_shap # If hook is unsepecified, use identity if current_hook_fns: processed_slice = apply_hook_fns(processed_slice, target_hf_shape, current_hook_fns) - numpy_slice = convert_jax_weight_to_numpy(processed_slice, save_dtype).squeeze() - if numpy_slice.shape != tuple(target_hf_shape): - raise ValueError(f"Shape mismatch for {hf_path}: Expect {target_hf_shape}, got {numpy_slice.shape}") + numpy_slice = convert_jax_weight_to_numpy(processed_slice, save_dtype).reshape(target_hf_shape) output_weights.append((hf_path, numpy_slice)) diff --git a/tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh b/tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh index da23fac8fe..a3785dcfdf 100644 --- a/tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh +++ b/tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh @@ -4,7 +4,7 @@ set -ex idx=$(date +%Y-%m-%d-%H-%M) MODEL_NAME='gemma4-26b' -export MODEL_VARIATION='26b' +export MODEL_VARIATION='26b-it' TOKENIZER_PATH='google/gemma-4-26b-a4b-it' # To convert the multimodal model, make sure the use_multimodal is set to be true USE_MULTIMODAL=false @@ -30,7 +30,9 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MA export MAXTEXT_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/${idx}/0/items - +# Run forward pass logit checker to validate the converted checkpoint. +# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, +# not for performance purpose. 
if [ ${USE_MULTIMODAL} == true ]; then # Set the shared multimodal prompt and image TEST_PROMPT='Describe image <|image|>' @@ -55,13 +57,13 @@ if [ ${USE_MULTIMODAL} == true ]; then dtype=float32 \ matmul_precision=highest \ per_device_batch_size=1 \ - vision_output_length=280 \ attention=dot_product \ prompt="${TEST_PROMPT}" \ image_path=${TEST_IMAGE} \ --max_kl_div=0.03 \ --golden_logits_path=${GOLDEN_LOGITS_PATH} else + echo "=== Running MaxText Forward Pass Logit Checker ===" python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ @@ -70,6 +72,10 @@ else scan_layers=${USE_SCAN_LAYERS} \ per_device_batch_size=1 \ dtype=float32 \ + wi_tile_fwd_embed_dim=512 \ + wi_tile_fwd_mlp_dim=512 \ + wo_tile_fwd_embed_dim=512 \ + wo_tile_fwd_mlp_dim=512 \ --max_kl_div=0.03 \ --run_hf_model=true \ --hf_model_path=${HF_MODEL} diff --git a/tests/end_to_end/tpu/gemma4/26b/convert_gemma4_pt.sh b/tests/end_to_end/tpu/gemma4/26b/convert_gemma4_pt.sh index 4a24621342..40b49b84b8 100644 --- a/tests/end_to_end/tpu/gemma4/26b/convert_gemma4_pt.sh +++ b/tests/end_to_end/tpu/gemma4/26b/convert_gemma4_pt.sh @@ -4,7 +4,7 @@ set -ex idx=$(date +%Y-%m-%d-%H-%M) MODEL_NAME='gemma4-26b' -export MODEL_VARIATION='26b' +export MODEL_VARIATION='26b-pt' TOKENIZER_PATH='google/gemma-4-26b-a4b' # To convert the multimodal model, make sure the use_multimodal is set to be true USE_MULTIMODAL=false @@ -29,7 +29,9 @@ python3 -m maxtext.checkpoint_conversion.to_maxtext "${MAXTEXT_CONFIGS_DIR:-${MA export MAXTEXT_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/${idx}/0/items - +# Run forward pass logit checker to validate the converted checkpoint. +# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, +# not for performance purpose. if [ ${USE_MULTIMODAL} == true ]; then # Set the shared multimodal prompt and image TEST_PROMPT='Describe image <|image|>' @@ -54,13 +56,17 @@ if [ ${USE_MULTIMODAL} == true ]; then dtype=float32 \ matmul_precision=highest \ per_device_batch_size=1 \ - vision_output_length=280 \ attention=dot_product \ + wi_tile_fwd_embed_dim=512 \ + wi_tile_fwd_mlp_dim=512 \ + wo_tile_fwd_embed_dim=512 \ + wo_tile_fwd_mlp_dim=512 \ prompt="${TEST_PROMPT}" \ image_path=${TEST_IMAGE} \ --max_kl_div=0.03 \ --golden_logits_path=${GOLDEN_LOGITS_PATH} else + echo "=== Running MaxText Forward Pass Logit Checker ===" python3 -m tests.utils.forward_pass_logit_checker "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ tokenizer_path=${TOKENIZER_PATH} \ load_parameters_path=${MAXTEXT_CKPT_PATH} \ @@ -69,6 +75,10 @@ else scan_layers=${USE_SCAN_LAYERS} \ per_device_batch_size=1 \ dtype=float32 \ + wi_tile_fwd_embed_dim=512 \ + wi_tile_fwd_mlp_dim=512 \ + wo_tile_fwd_embed_dim=512 \ + wo_tile_fwd_mlp_dim=512 \ --max_kl_div=0.03 \ --run_hf_model=true \ --hf_model_path=${HF_MODEL} diff --git a/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh new file mode 100755 index 0000000000..476aed964a --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/26b/test_gemma4_to_hf.sh @@ -0,0 +1,103 @@ +#!/bin/bash + +# This script is both an end-to-end test and documentation for converting a +# Gemma4-26B MaxText checkpoint to Hugging Face format. Can be run on a v5p-8. + +# The flow of this script is as follows: +# 1. 
Convert a MaxText checkpoint to a Hugging Face model checkpoint. +# 2. Run a forward pass check to compare the logits and KL divergence between +# the converted checkpoint and the original HF model. + +# Pre-requisites: +# 1. Set HF_TOKEN environment variable to your Hugging Face access token. +# export HF_TOKEN= +# 2. Provide a MaxText-format Gemma4-26B checkpoint via CKPT_PATH. +# One can be produced with tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh. + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) +MODEL_NAME='gemma4-26b' +export MODEL_VARIATION='26b-it' +# To convert the multimodal model, set USE_MULTIMODAL=true +USE_MULTIMODAL=false +# Set USE_SCAN_LAYERS=true if the checkpoint was trained with scanned layers +USE_SCAN_LAYERS=true + +# Installing torch for deps in forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Non-Googlers: point MODEL_BUCKET to a GCS bucket you own. +export MODEL_BUCKET=gs://maxtext-gemma/gemma4 +# Path to a pre-existing MaxText checkpoint for gemma4-26b. Must match USE_SCAN_LAYERS. +# Run tests/end_to_end/tpu/gemma4/26b/convert_gemma4.sh to produce one. +export CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/unscanned/0/items + +# Path to the original HF model weights for logit comparison. +export HF_MODEL=google/gemma-4-26b-a4b-it + +export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} + +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + model_name=${MODEL_NAME} \ + hf_access_token=${HF_TOKEN} \ + load_parameters_path=${CKPT_PATH} \ + base_output_directory=${LOCAL_PATH} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} + +# Run forward pass logit checker to validate the converted checkpoint. +# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, +# not for performance purpose. 
+if [ "${USE_MULTIMODAL}" == true ]; then + TEST_PROMPT='Describe image <|image|>' + TEST_IMAGE='tests/assets/test_image.jpg' + export GOLDEN_LOGITS_PATH=/tmp/golden_gemma4_26b_vision.pickle + + python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ + --model-id=${HF_MODEL} \ + --output-path=${GOLDEN_LOGITS_PATH} \ + --prompts="${TEST_PROMPT}" \ + --image-paths=${TEST_IMAGE} \ + --hf-model-path=${LOCAL_PATH} \ + --output-format=pickle + + echo "=== Running MaxText Forward Pass Logit Checker ===" + python3 -m tests.utils.forward_pass_logit_checker \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + tokenizer_path=${HF_MODEL} \ + load_parameters_path=${CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} \ + dtype=float32 \ + wi_tile_fwd_embed_dim=512 \ + wi_tile_fwd_mlp_dim=512 \ + wo_tile_fwd_embed_dim=512 \ + wo_tile_fwd_mlp_dim=512 \ + matmul_precision=highest \ + per_device_batch_size=1 \ + attention=dot_product \ + prompt="${TEST_PROMPT}" \ + image_path=${TEST_IMAGE} \ + --max_kl_div=0.1 \ + --golden_logits_path=${GOLDEN_LOGITS_PATH} +else + echo "=== Running MaxText Forward Pass Logit Checker ===" + python3 -m tests.utils.forward_pass_logit_checker \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + tokenizer_path=${HF_MODEL} \ + load_parameters_path=${CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} \ + per_device_batch_size=1 \ + dtype=float32 \ + wi_tile_fwd_embed_dim=512 \ + wi_tile_fwd_mlp_dim=512 \ + wo_tile_fwd_embed_dim=512 \ + wo_tile_fwd_mlp_dim=512 \ + --max_kl_div=0.1 \ + --run_hf_model=true \ + --hf_model_path=${LOCAL_PATH} +fi diff --git a/tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh b/tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh index 6a78ef4c32..31ed580773 100644 --- a/tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh +++ b/tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh @@ -4,7 +4,7 @@ set -ex idx=$(date +%Y-%m-%d-%H-%M) MODEL_NAME='gemma4-31b' -export MODEL_VARIATION='31b' +export MODEL_VARIATION='31b-it' TOKENIZER_PATH='google/gemma-4-31b-it' # To convert the multimodal model, make sure the use_multimodal is set to be true USE_MULTIMODAL=false diff --git a/tests/end_to_end/tpu/gemma4/31b/convert_gemma4_pt.sh b/tests/end_to_end/tpu/gemma4/31b/convert_gemma4_pt.sh index 664ec0df60..68802a390b 100644 --- a/tests/end_to_end/tpu/gemma4/31b/convert_gemma4_pt.sh +++ b/tests/end_to_end/tpu/gemma4/31b/convert_gemma4_pt.sh @@ -4,7 +4,7 @@ set -ex idx=$(date +%Y-%m-%d-%H-%M) MODEL_NAME='gemma4-31b' -export MODEL_VARIATION='31b' +export MODEL_VARIATION='31b-pt' TOKENIZER_PATH='google/gemma-4-31b' # To convert the multimodal model, make sure the use_multimodal is set to be true USE_MULTIMODAL=false diff --git a/tests/end_to_end/tpu/gemma4/31b/test_gemma4_to_hf.sh b/tests/end_to_end/tpu/gemma4/31b/test_gemma4_to_hf.sh new file mode 100755 index 0000000000..5e12ca51a1 --- /dev/null +++ b/tests/end_to_end/tpu/gemma4/31b/test_gemma4_to_hf.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +# This script is both an end-to-end test and documentation for converting a +# Gemma4-31B MaxText checkpoint to Hugging Face format. Can be run on a v5p-8. + +# The flow of this script is as follows: +# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint. +# 2. 
Run a forward pass check to compare the logits and KL divergence between +# the converted checkpoint and the original HF model. + +# Pre-requisites: +# 1. Set HF_TOKEN environment variable to your Hugging Face access token. +# export HF_TOKEN= +# 2. Provide a MaxText-format Gemma4-31B checkpoint via CKPT_PATH. +# One can be produced with tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh. + +set -ex +idx=$(date +%Y-%m-%d-%H-%M) +MODEL_NAME='gemma4-31b' +export MODEL_VARIATION='31b-it' +# To convert the multimodal model, set USE_MULTIMODAL=true +USE_MULTIMODAL=false +# Set USE_SCAN_LAYERS=true if the checkpoint was trained with scanned layers +USE_SCAN_LAYERS=false + +# Installing torch for deps in forward_pass_logit_checker.py +python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu + +# Non-Googlers: point MODEL_BUCKET to a GCS bucket you own. +export MODEL_BUCKET=gs://maxtext-gemma/gemma4 +# Path to a pre-existing MaxText checkpoint for gemma4-31b. Must match USE_SCAN_LAYERS. +# Run tests/end_to_end/tpu/gemma4/31b/convert_gemma4.sh to produce one. +export CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/converted/unscanned/0/items +# Path to the original HF model weights for logit comparison. +export HF_MODEL=google/gemma-4-31b-it + +export LOCAL_PATH=./tmp/hf/${MODEL_NAME}/${idx} + +python3 -m maxtext.checkpoint_conversion.to_huggingface \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + model_name=${MODEL_NAME} \ + hf_access_token=${HF_TOKEN} \ + load_parameters_path=${CKPT_PATH} \ + base_output_directory=${LOCAL_PATH} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} + +# Run forward pass logit checker to validate the converted checkpoint. +# The *_tile_fwd_*_dim flags are for reducing vmem usage to fit into v5p chips, +# not for performance purpose. +if [ "${USE_MULTIMODAL}" == true ]; then + TEST_PROMPT='Describe image <|image|>' + TEST_IMAGE='tests/assets/test_image.jpg' + export GOLDEN_LOGITS_PATH=/tmp/golden_gemma4_31b_vision.pickle + + python3 -m tests.assets.logits_generation.generate_hf_golden_logits \ + --model-id=${HF_MODEL} \ + --output-path=${GOLDEN_LOGITS_PATH} \ + --prompts="${TEST_PROMPT}" \ + --image-paths=${TEST_IMAGE} \ + --hf-model-path=${LOCAL_PATH} \ + --output-format=pickle + + echo "=== Running MaxText Forward Pass Logit Checker ===" + python3 -m tests.utils.forward_pass_logit_checker \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + tokenizer_path=${HF_MODEL} \ + load_parameters_path=${CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} \ + dtype=float32 \ + matmul_precision=highest \ + per_device_batch_size=1 \ + attention=dot_product \ + prompt="${TEST_PROMPT}" \ + image_path=${TEST_IMAGE} \ + --max_kl_div=0.1 \ + --golden_logits_path=${GOLDEN_LOGITS_PATH} +else + echo "=== Running MaxText Forward Pass Logit Checker ===" + python3 -m tests.utils.forward_pass_logit_checker \ + "${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}"/base.yml \ + tokenizer_path=${HF_MODEL} \ + load_parameters_path=${CKPT_PATH} \ + model_name=${MODEL_NAME} \ + use_multimodal=${USE_MULTIMODAL} \ + scan_layers=${USE_SCAN_LAYERS} \ + per_device_batch_size=1 \ + dtype=float32 \ + --max_kl_div=0.1 \ + --run_hf_model=true \ + --hf_model_path=${LOCAL_PATH} +fi
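
As a sanity check on the fused gate_up hook introduced in param_mapping.py above, the following minimal standalone sketch (toy dimensions, illustration only) shows that concatenating wi_0 and wi_1 on the last axis and swapping the last two axes produces the HF experts.gate_up_proj layout [E, 2*FF, H] expected by GEMMA4_HF_WEIGHTS_TO_SHAPE:

    import jax.numpy as jnp

    # Toy dimensions, chosen arbitrarily for illustration.
    E, H, FF = 4, 8, 16
    wi_0 = jnp.ones((E, H, FF))   # MaxText per-expert gate proj: [E, H, FF]
    wi_1 = jnp.ones((E, H, FF))   # MaxText per-expert up proj:   [E, H, FF]

    # Same transform as moe_gate_up_hook: concat on the FF axis,
    # then swap the trailing (H, 2*FF) axes into (2*FF, H).
    gate_up = jnp.swapaxes(jnp.concatenate([wi_0, wi_1], axis=-1), -2, -1)
    assert gate_up.shape == (E, 2 * FF, H)  # matches shapes["...experts.gate_up_proj"]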