
Add Gemma4 MaxText→vLLM Weight Converter (torchax)#3794

Merged
copybara-service[bot] merged 1 commit into main from hengtaoguo-gemma4
May 6, 2026
Conversation


@hengtaoguo hengtaoguo commented May 1, 2026

Description

Original author: @khatwanimohit @aireenmei #3677
Refactor/test: @hengtaoguo

Extracts the Gemma4-specific weight conversion logic from bench_weight_sync.py into a proper converter class, following the same base/model-specific split established for Qwen3.

New: gemma4_moe.py

  • Adds Gemma4MaxTextToVLLMConverter, inheriting from BaseMaxTextToVLLMConverter
  • Supports gemma4-26b (MoE: 128 routed + 1 shared expert)
  • Handles Gemma4's scanned-block layout (6 slots × N reps, local + global attention)
  • Overrides convert() to add the _convert_norms step and dispatch MoE vs. dense MLP
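The base/model-specific split and the scanned-block unstacking described above can be pictured with a minimal sketch. All class, method, and key names below are illustrative assumptions, not the actual gemma4_moe.py API:

```python
import numpy as np

# Hypothetical sketch of the converter split described above; names and
# shapes are illustrative assumptions, not the real gemma4_moe.py API.
class BaseMaxTextToVLLMConverterSketch:
  """Model-agnostic conversion steps shared by all converters."""

  def convert(self, params: dict) -> dict:
    out = {}
    out.update(self._convert_attention(params))
    out.update(self._convert_mlp(params))
    return out

  def _convert_attention(self, params):
    return {}  # placeholder

  def _convert_mlp(self, params):
    return {}  # placeholder


class Gemma4ConverterSketch(BaseMaxTextToVLLMConverterSketch):
  SLOTS = 6  # scanned-block layout: 6 slots per repetition

  def convert(self, params: dict) -> dict:
    # Override to append the extra norm-conversion step.
    out = super().convert(params)
    out.update(self._convert_norms(params))
    return out

  def _convert_norms(self, params):
    return {}  # placeholder

  def unstack_scanned(self, stacked: np.ndarray) -> dict:
    """Flatten a (slots, reps, ...) scanned weight into per-layer keys."""
    slots, reps = stacked.shape[:2]
    return {
        f"layers.{r * slots + s}": stacked[s, r]
        for r in range(reps)
        for s in range(slots)
    }
```

The point of the sketch is only the shape of the design: the subclass overrides `convert()` to add a step, and scanned weights stacked as (6 slots × N reps) are flattened into per-layer flat keys.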

Updated: validate_converter.py

  • Imports Gemma4MaxTextToVLLMConverter and dispatches on gemma4-* model names
  • Adds gemma4-26b entry to vllm_model_name_mapping
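The dispatch on model names can be sketched roughly as follows; the function name and returned labels are assumptions, not the actual validate_converter.py code:

```python
# Illustrative sketch of dispatching on the model-name prefix, as
# described above; the function name and labels are assumptions.
def pick_converter_name(model_name: str) -> str:
  if model_name.startswith("gemma4-"):
    return "Gemma4MaxTextToVLLMConverter"
  if model_name.startswith("qwen3-"):
    return "Qwen3MaxTextToVLLMConverter"
  raise ValueError(f"No torchax converter registered for {model_name!r}")
```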

Notes:

  1. Set env var MODEL_IMPL_TYPE=vllm to force the torchax-backed vLLM model for Gemma4 (default "auto" resolves to "flax_nnx" in newer tpu-inference, which uses a nested Flax state incompatible with the flat-key converter output)
  2. Gemma4 prompts need to start with <bos>, for example: prompt="<bos>Paris is"
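A tiny guard for note 2 could prepend the token when it is missing (a sketch; the helper name is an assumption, not part of this PR):

```python
# Sketch of a guard for note 2: Gemma4 prompts must start with <bos>.
# The helper name is an assumption, not part of the PR.
def ensure_bos(prompt: str, bos: str = "<bos>") -> str:
  return prompt if prompt.startswith(bos) else bos + prompt
```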

Tests

Tested with validate_converter, full logs:

export MODEL_IMPL_TYPE="vllm"
python -m maxtext.integration.vllm.torchax_converter.validate_converter src/maxtext/configs/base.yml model_name=gemma4-26b tokenizer_type=huggingface tokenizer_path=google/gemma-4-26b-a4b-it load_parameters_path=gs://maxtext-gemma/gemma4/26b/converted/2026-04-07-23-04/0/items run_name=gemma4_converter_validation per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=16 steps=1 scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 attention=dot_product remat_policy=custom decoder_layer_input=offload query_proj=offload key_proj=offload value_proj=offload ici_expert_parallelism=4 rollout_tensor_parallelism=4 hbm_utilization_vllm=0.8 async_scheduling=false prompt=\<bos\>Paris\ is hf_access_token=xxx 

# Also works with pathways ckpt at gs://hengtaoguo-maxtext-logs/checkpoints/gemma4-26b/scanned/gemma-4-26B-A4B-pathways/0/items
# python -m maxtext.checkpoint_conversion.to_maxtext src/maxtext/configs/base.yml model_name=gemma4-26b base_output_directory=gs://hengtaoguo-maxtext-logs/checkpoints/gemma4-26b/scanned/gemma-4-26B-A4B-pathways scan_layers=true hf_access_token=xxx weight_dtype=bfloat16 hardware=cpu skip_jax_distributed_system=True checkpoint_storage_use_ocdbt=False checkpoint_storage_use_zarr3=False
[RequestOutput(request_id=0, prompt='<bos>Paris is', prompt_token_ids=[2, 2, 50429, 563], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text=' the capital of France. It is accommodated in the', token_ids=[506, 5279, 529, 7001, 236761, 1030, 563, 98509, 528, 506], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]

Tuned checkpoint with chat template use_chat_template=True:

export MODEL_IMPL_TYPE="vllm"
python -m maxtext.integration.vllm.torchax_converter.validate_converter src/maxtext/configs/base.yml model_name=gemma4-26b tokenizer_type=huggingface tokenizer_path=google/gemma-4-26b-a4b-it load_parameters_path=gs://hengtaoguo-maxtext-logs/checkpoints/gemma4-26b/scanned/gemma-4-26B-A4B-it/0/items run_name=gemma4_converter_validation per_device_batch_size=1 max_prefill_predict_length=8 max_target_length=32 steps=1 scan_layers=true skip_jax_distributed_system=true weight_dtype=bfloat16 attention=dot_product remat_policy=custom decoder_layer_input=offload query_proj=offload key_proj=offload value_proj=offload ici_expert_parallelism=4 rollout_tensor_parallelism=4 hbm_utilization_vllm=0.8 async_scheduling=false prompt=\<bos\>Paris\ is hf_access_token=xxx use_chat_template=True
[RequestOutput(request_id=0, prompt='<bos><|turn>user\nWhere is Paris?<turn|>\n<|turn>model\n<|channel>thought\n<channel|>', prompt_token_ids=[2, 2, 105, 2364, 107, 10936, 563, 9079, 236881, 106, 107, 105, 4368, 107, 100, 45518, 107, 101], encoder_prompt=None, encoder_prompt_token_ids=None, prompt_logprobs=None, outputs=[CompletionOutput(index=0, text='"Paris" can-refer-to\'s own context depends on', token_ids=[236775, 50429, 236775, 740, 236772, 35380, 236772, 1071, 236789, 236751, 1852, 4403, 9796, 580], routed_experts=None, cumulative_logprob=None, logprobs=None, finish_reason=length, stop_reason=None)], finished=True, metrics=None, lora_request=None, num_cached_tokens=0)]

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

github-actions Bot commented May 1, 2026

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

github-actions Bot left a comment
## 📋 Review Summary

This Pull Request introduces a weight converter for the Gemma4 model to facilitate MaxText to vLLM conversion, specifically for the gemma4-26b MoE variant. The implementation is well-structured and follows the established patterns for weight conversion in the project, including proper JIT usage and memory management.

🔍 General Feedback

  • Naming Consistency: There is a critical naming discrepancy in the MoE weight keys (moe.per_expert_scale vs router.per_expert_scale) that should be resolved to ensure compatibility with vLLM's expectation.
  • Code Cleanup: A few minor items like unused arguments in JIT functions and commented-out debugging code in the validator should be addressed to maintain code quality.
  • Performance: The use of jax.jit for batch weight processing is a good practice for minimizing conversion overhead.

  • Comment thread: src/maxtext/integration/vllm/torchax_converter/validate_converter.py (outdated)
  • Comment thread: src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py
  • Comment thread: src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py
  • Comment thread: src/maxtext/integration/vllm/torchax_converter/gemma4_moe.py
  • Comment thread: src/maxtext/integration/vllm/torchax_converter/validate_converter.py (outdated)
  • Comment thread: src/maxtext/integration/vllm/torchax_converter/validate_converter.py (outdated)
@hengtaoguo hengtaoguo force-pushed the hengtaoguo-gemma4 branch from 1859f35 to 721a9ab Compare May 5, 2026 19:54

codecov Bot commented May 5, 2026

Codecov Report

❌ Patch coverage is 0% with 195 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...t/integration/vllm/torchax_converter/gemma4_moe.py 0.00% 177 Missing ⚠️
...ation/vllm/torchax_converter/validate_converter.py 0.00% 18 Missing ⚠️


@hengtaoguo hengtaoguo force-pushed the hengtaoguo-gemma4 branch 4 times, most recently from d55189f to 434727b Compare May 5, 2026 23:28
@hengtaoguo hengtaoguo force-pushed the hengtaoguo-gemma4 branch from a939844 to bcfa1b1 Compare May 5, 2026 23:36
@khatwanimohit khatwanimohit removed their assignment May 6, 2026
@copybara-service copybara-service Bot merged commit 7d6e1ca into main May 6, 2026
32 of 33 checks passed
@copybara-service copybara-service Bot deleted the hengtaoguo-gemma4 branch May 6, 2026 00:53