From a7d1170cf43615027708abe0347152e8062b0cb0 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Date: Sat, 14 Feb 2026 01:54:53 +0000 Subject: [PATCH 1/5] [Megatron Export] Add Qwen3-VL export/import mapping Add Megatron Core export/import mapping for Qwen3-VL (Qwen3VLForConditionalGeneration). Handles the model.language_model. weight prefix and supports both dense and MoE variants. Signed-off-by: Hung-Yueh mv test_mcore_qwen3vl.py to tests/gpu_megatron/torch/export/ Signed-off-by: Hung-Yueh Chiang --- CHANGELOG.rst | 1 + docs/source/deployment/3_unified_hf.rst | 1 + modelopt/torch/export/plugins/mcore_common.py | 6 + .../torch/export/plugins/mcore_qwen3vl.py | 120 +++++++ .../torch/export/test_mcore_qwen3vl.py | 306 ++++++++++++++++++ 5 files changed, 434 insertions(+) create mode 100644 modelopt/torch/export/plugins/mcore_qwen3vl.py create mode 100644 tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62f2b0041cb..8fd414ec5fc 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,7 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants. 
0.44 (2026-05-18) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index 9124164b576..6664f987f72 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -61,6 +61,7 @@ Models: * Llama 4, 3.x (FP8, NVFP4) * Qwen 3, 2.5 (FP8, NVFP4) * Qwen 3 MoE (FP8, NVFP4) + * Qwen 3-VL (FP8, NVFP4) * Deepseek R1/V3 (NVFP4) * Mixtral 8x7B (FP8, NVFP4) * Medusa (FP8) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index d5bab9b4ece..660e4eac96d 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,6 +39,10 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) +from .mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, @@ -54,6 +58,7 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, "GptOssForCausalLM": gptoss_causal_lm_export, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -66,4 +71,5 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, "GptOssForCausalLM": gptoss_causal_lm_import, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py new file mode 100644 index 00000000000..40eb99adb50 --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. + +Qwen3-VL model structure differs from Qwen3: +- Language model weights are under `model.language_model.` prefix +- Visual encoder weights are under `model.visual.` prefix + +This module handles the language model conversion for PTQ/QAT workflows. +Visual components are typically kept in full precision. 
+ +HuggingFace Qwen3-VL-8B structure: +- model.language_model.embed_tokens.weight +- model.language_model.layers.{L}.input_layernorm.weight +- model.language_model.layers.{L}.self_attn.q_proj.weight +- model.language_model.layers.{L}.self_attn.k_proj.weight +- model.language_model.layers.{L}.self_attn.v_proj.weight +- model.language_model.layers.{L}.self_attn.q_norm.weight +- model.language_model.layers.{L}.self_attn.k_norm.weight +- model.language_model.layers.{L}.self_attn.o_proj.weight +- model.language_model.layers.{L}.post_attention_layernorm.weight +- model.language_model.layers.{L}.mlp.gate_proj.weight +- model.language_model.layers.{L}.mlp.up_proj.weight +- model.language_model.layers.{L}.mlp.down_proj.weight +- model.language_model.norm.weight +- lm_head.weight +""" + +from .mcore_custom import ( + COL_ETP, + COL_TP, + REPLICATE, + ROW_ETP, + ROW_TP, + CustomModuleMapping, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) + +# Import rules: HuggingFace -> Megatron Core +qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = { + # Embeddings - note the language_model prefix + "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE), + # Output layer (lm_head is at root level, not under language_model) + "output_layer": NameRemapping("lm_head.", COL_TP), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE), + # Attention - QKV projection (merged) + "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP), + # Attention - output projection + "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP), + # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K) + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE), + # MLP - pre-MLP layernorm (post_attention_layernorm in HF) + "pre_mlp_layernorm": NameRemapping( + "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE + ), + # MLP - gate_proj + up_proj merged into linear_fc1 + "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP), + # MLP - down_proj as linear_fc2 + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP), + # MoE support (for Qwen3-VL MoE variants like 30B-A3B) + "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE), + "local_experts.linear_fc1": GatedMLPMerging( + "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP + ), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP + ), +} + +# Export rules: Megatron Core -> HuggingFace +qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = { + # Embeddings + "word_embeddings": NameRemapping("model.language_model.embed_tokens."), + # Final layer norm + "final_layernorm": NameRemapping("model.language_model.norm."), + # Output layer + "output_layer": NameRemapping("lm_head."), + # Attention - input layernorm + "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."), + # Attention - QKV projection (sliced back to separate q/k/v) + "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."), + # Attention - output projection + "linear_proj": 
NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), + # Attention - Q/K layer norms + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), + # MLP - pre-MLP layernorm + "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), + # MLP - linear_fc1 sliced back to gate_proj + up_proj + "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), + # MLP - down_proj + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), + # MoE support + "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), + "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj." + ), +} \ No newline at end of file diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py new file mode 100644 index 00000000000..3f57cb9c478 --- /dev/null +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unit tests for Qwen3-VL Megatron Core export/import plugin.""" + +import pytest + +from modelopt.torch.export.plugins.mcore_custom import ( + COL_TP, + REPLICATE, + ROW_TP, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) +from modelopt.torch.export.plugins.mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) + + +# All mcore keys that a dense (non-MoE) Qwen3-VL model should have +DENSE_MCORE_KEYS = { + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", +} + +# Additional MoE keys +MOE_MCORE_KEYS = { + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", +} + + +class TestQwen3VLRegistration: + """Test that Qwen3-VL is registered in the global mapping.""" + + def test_registered_in_export_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_export_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping + assert ( + all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_export + ) + + def test_registered_in_import_mapping(self): + from modelopt.torch.export.plugins.mcore_common import ( + all_mcore_hf_import_mapping, + ) + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping + assert ( + all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_import + ) + + +class TestQwen3VLImportMapping: + """Test the HuggingFace -> Megatron Core import mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_language_model_prefix(self): + """Qwen3-VL uses model.language_model. prefix (not model.).""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_import[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + """lm_head is at root level, not under language_model.""" + mapping = qwen3vl_causal_lm_import["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." 
+ + def test_qkv_uses_merging(self): + assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) + + def test_mlp_uses_gated_merging(self): + assert isinstance( + qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging + ) + + @pytest.mark.parametrize( + "key", + [ + "input_layernorm", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "final_layernorm", + ], + ) + def test_layernorms_are_replicated(self, key): + """Layernorms should use REPLICATE (empty func_kwargs).""" + mapping = qwen3vl_causal_lm_import[key] + assert isinstance(mapping, NameRemapping) + assert mapping.func_kwargs == REPLICATE + + @pytest.mark.parametrize( + "key,expected_kwargs", + [ + ("word_embeddings", COL_TP), + ("output_layer", COL_TP), + ("linear_proj", ROW_TP), + ], + ) + def test_tp_sharding(self, key, expected_kwargs): + mapping = qwen3vl_causal_lm_import[key] + assert mapping.func_kwargs == expected_kwargs + + +class TestQwen3VLExportMapping: + """Test the Megatron Core -> HuggingFace export mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_language_model_prefix(self): + """Export paths should also use model.language_model. prefix.""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_export[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + mapping = qwen3vl_causal_lm_export["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." 
+ + def test_qkv_uses_slicing(self): + assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) + + def test_mlp_uses_gated_slicing(self): + assert isinstance( + qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing + ) + + def test_export_has_no_parallel_config(self): + """Export mappings should not have parallel configs.""" + for key in ["word_embeddings", "final_layernorm", "output_layer", + "input_layernorm", "linear_proj"]: + mapping = qwen3vl_causal_lm_export[key] + assert "parallel_config" not in mapping.func_kwargs + + +class TestQwen3VLImportExportSymmetry: + """Test that import and export mappings are consistent.""" + + def test_same_mcore_keys(self): + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3vl_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc2", + "router", + ], + ) + def test_matching_hf_prefixes(self, key): + """Import and export should map to the same HF prefix.""" + imp = qwen3vl_causal_lm_import[key] + exp = qwen3vl_causal_lm_export[key] + assert imp.target_name_or_prefix == exp.target_name_or_prefix, ( + f"{key}: import prefix '{imp.target_name_or_prefix}' != " + f"export prefix '{exp.target_name_or_prefix}'" + ) + + def test_qkv_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_qkv"] + exp = qwen3vl_causal_lm_export["linear_qkv"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + def test_mlp_fc1_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_fc1"] + exp = qwen3vl_causal_lm_export["linear_fc1"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + +class TestQwen3VLvsQwen3Difference: + """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix.""" + + def test_same_keys_as_qwen3(self): + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_export, + qwen3_causal_lm_import, + ) + + assert set(qwen3vl_causal_lm_import.keys()) == set( + qwen3_causal_lm_import.keys() + ) + assert set(qwen3vl_causal_lm_export.keys()) == set( + qwen3_causal_lm_export.keys() + ) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", + ], + ) + def test_vl_adds_language_model_prefix(self, key): + """Qwen3-VL should have 'language_model.' 
inserted after 'model.'.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix + qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix + expected = qwen3_prefix.replace("model.", "model.language_model.", 1) + assert qwen3vl_prefix == expected, ( + f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" + ) + + def test_output_layer_same(self): + """lm_head is at root level for both Qwen3 and Qwen3-VL.""" + from modelopt.torch.export.plugins.mcore_qwen import ( + qwen3_causal_lm_import, + ) + + assert ( + qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix + == qwen3_causal_lm_import["output_layer"].target_name_or_prefix + ) From 36da6deac2d5d419a460766e4de0799051ca7e63 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 13 May 2026 21:39:56 +0000 Subject: [PATCH 2/5] fix: ruff formatting and PT006 parametrize tuple fix Signed-off-by: Hung-Yueh Chiang --- tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index 3f57cb9c478..c0d4cf9bb07 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -142,7 +142,7 @@ def test_layernorms_are_replicated(self, key): assert mapping.func_kwargs == REPLICATE @pytest.mark.parametrize( - "key,expected_kwargs", + ("key", "expected_kwargs"), [ ("word_embeddings", COL_TP), ("output_layer", COL_TP), From e8101a7f8a14cc82a7998a9d987c2195239506b0 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Wed, 13 May 2026 23:14:20 +0000 Subject: [PATCH 3/5] fix: apply ruff formatting to mcore_qwen3vl plugin and test files Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- modelopt/torch/export/plugins/mcore_common.py | 5 +-- .../torch/export/plugins/mcore_qwen3vl.py | 2 +- .../torch/export/test_mcore_qwen3vl.py | 33 ++++++++----------- 3 files changed, 15 insertions(+), 25 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index 660e4eac96d..15395b7a1e5 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,10 +39,7 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) -from .mcore_qwen3vl import ( - qwen3vl_causal_lm_export, - qwen3vl_causal_lm_import, -) +from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 40eb99adb50..4dc3c63f4a2 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -117,4 +117,4 @@ "local_experts.linear_fc2": NameRemapping( "model.language_model.layers.{}.mlp.experts.{}.down_proj." 
), -} \ No newline at end of file +} diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index c0d4cf9bb07..c7a5efc47d4 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -121,9 +121,7 @@ def test_qkv_uses_merging(self): assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) def test_mlp_uses_gated_merging(self): - assert isinstance( - qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging - ) + assert isinstance(qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging) @pytest.mark.parametrize( "key", @@ -192,14 +190,17 @@ def test_qkv_uses_slicing(self): assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) def test_mlp_uses_gated_slicing(self): - assert isinstance( - qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing - ) + assert isinstance(qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing) def test_export_has_no_parallel_config(self): """Export mappings should not have parallel configs.""" - for key in ["word_embeddings", "final_layernorm", "output_layer", - "input_layernorm", "linear_proj"]: + for key in [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + ]: mapping = qwen3vl_causal_lm_export[key] assert "parallel_config" not in mapping.func_kwargs @@ -208,9 +209,7 @@ class TestQwen3VLImportExportSymmetry: """Test that import and export mappings are consistent.""" def test_same_mcore_keys(self): - assert set(qwen3vl_causal_lm_import.keys()) == set( - qwen3vl_causal_lm_export.keys() - ) + assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3vl_causal_lm_export.keys()) @pytest.mark.parametrize( "key", @@ -256,12 +255,8 @@ def test_same_keys_as_qwen3(self): qwen3_causal_lm_import, ) - assert set(qwen3vl_causal_lm_import.keys()) == set( - qwen3_causal_lm_import.keys() - ) - assert set(qwen3vl_causal_lm_export.keys()) == set( - qwen3_causal_lm_export.keys() - ) + assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3_causal_lm_import.keys()) + assert set(qwen3vl_causal_lm_export.keys()) == set(qwen3_causal_lm_export.keys()) @pytest.mark.parametrize( "key", @@ -290,9 +285,7 @@ def test_vl_adds_language_model_prefix(self, key): qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix expected = qwen3_prefix.replace("model.", "model.language_model.", 1) - assert qwen3vl_prefix == expected, ( - f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" - ) + assert qwen3vl_prefix == expected, f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" def test_output_layer_same(self): """lm_head is at root level for both Qwen3 and Qwen3-VL.""" From aecbbfae4db2034e62b950532e81fcc63ddbb82d Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Thu, 14 May 2026 16:03:07 +0000 Subject: [PATCH 4/5] fix: collapse single-item imports in test_mcore_qwen3vl per ruff Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../torch/export/test_mcore_qwen3vl.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index c7a5efc47d4..a9b8ddd5a0f 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -32,7 +32,6 @@ qwen3vl_causal_lm_import, ) - # All mcore keys that a 
dense (non-MoE) Qwen3-VL model should have DENSE_MCORE_KEYS = { "word_embeddings", @@ -60,9 +59,7 @@ class TestQwen3VLRegistration: """Test that Qwen3-VL is registered in the global mapping.""" def test_registered_in_export_mapping(self): - from modelopt.torch.export.plugins.mcore_common import ( - all_mcore_hf_export_mapping, - ) + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_export_mapping assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping assert ( @@ -71,9 +68,7 @@ def test_registered_in_export_mapping(self): ) def test_registered_in_import_mapping(self): - from modelopt.torch.export.plugins.mcore_common import ( - all_mcore_hf_import_mapping, - ) + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_import_mapping assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping assert ( @@ -278,9 +273,7 @@ def test_same_keys_as_qwen3(self): ) def test_vl_adds_language_model_prefix(self, key): """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_import, - ) + from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix @@ -289,9 +282,7 @@ def test_vl_adds_language_model_prefix(self, key): def test_output_layer_same(self): """lm_head is at root level for both Qwen3 and Qwen3-VL.""" - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_import, - ) + from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import assert ( qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix From 80495e6bad87688e77687889a2fdfa359b2233c6 Mon Sep 17 00:00:00 2001 From: Hung-Yueh Chiang Date: Thu, 14 May 2026 16:36:49 +0000 Subject: [PATCH 5/5] refactor: derive Qwen3-VL mcore mapping from Qwen3 via prefix rewrite Replace the hand-written dict literals in mcore_qwen3vl.py with a helper that derives the VL mapping from qwen3_causal_lm_import/export by inserting 'language_model.' after 'model.' in every prefix. lm_head. (root-level) is left unchanged. Remove TestQwen3VLvsQwen3Difference since it now tests the implementation against itself. Note visual encoder (model.visual.*) is intentionally excluded from the mapping. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Hung-Yueh Chiang --- .../torch/export/plugins/mcore_qwen3vl.py | 127 +++++------------- .../torch/export/test_mcore_qwen3vl.py | 49 ------- 2 files changed, 30 insertions(+), 146 deletions(-) diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py index 4dc3c63f4a2..2f35b1291e0 100644 --- a/modelopt/torch/export/plugins/mcore_qwen3vl.py +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -15,106 +15,39 @@ """Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. -Qwen3-VL model structure differs from Qwen3: -- Language model weights are under `model.language_model.` prefix -- Visual encoder weights are under `model.visual.` prefix +Qwen3-VL differs from Qwen3 in one structural way: language-model weights live +under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight`` +remains at the root level. The mappings below are derived automatically from +the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every +prefix that starts with ``model.``. 
-This module handles the language model conversion for PTQ/QAT workflows. -Visual components are typically kept in full precision. +Note: the visual encoder (``model.visual.*``) is intentionally excluded — this +mapping covers only the language-model decoder used for quantization and export. -HuggingFace Qwen3-VL-8B structure: -- model.language_model.embed_tokens.weight -- model.language_model.layers.{L}.input_layernorm.weight -- model.language_model.layers.{L}.self_attn.q_proj.weight -- model.language_model.layers.{L}.self_attn.k_proj.weight -- model.language_model.layers.{L}.self_attn.v_proj.weight -- model.language_model.layers.{L}.self_attn.q_norm.weight -- model.language_model.layers.{L}.self_attn.k_norm.weight -- model.language_model.layers.{L}.self_attn.o_proj.weight -- model.language_model.layers.{L}.post_attention_layernorm.weight -- model.language_model.layers.{L}.mlp.gate_proj.weight -- model.language_model.layers.{L}.mlp.up_proj.weight -- model.language_model.layers.{L}.mlp.down_proj.weight -- model.language_model.norm.weight -- lm_head.weight +Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json """ -from .mcore_custom import ( - COL_ETP, - COL_TP, - REPLICATE, - ROW_ETP, - ROW_TP, - CustomModuleMapping, - GatedMLPMerging, - GatedMLPSlicing, - NameRemapping, - QKVMerging, - QKVSlicing, -) +from .mcore_custom import CustomModuleMapping +from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import -# Import rules: HuggingFace -> Megatron Core -qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = { - # Embeddings - note the language_model prefix - "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP), - # Final layer norm - "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE), - # Output layer (lm_head is at root level, not under language_model) - "output_layer": NameRemapping("lm_head.", COL_TP), - # Attention - input layernorm - "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE), - # Attention - QKV projection (merged) - "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP), - # Attention - output projection - "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP), - # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K) - "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE), - "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE), - # MLP - pre-MLP layernorm (post_attention_layernorm in HF) - "pre_mlp_layernorm": NameRemapping( - "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE - ), - # MLP - gate_proj + up_proj merged into linear_fc1 - "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP), - # MLP - down_proj as linear_fc2 - "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP), - # MoE support (for Qwen3-VL MoE variants like 30B-A3B) - "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE), - "local_experts.linear_fc1": GatedMLPMerging( - "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP - ), - "local_experts.linear_fc2": NameRemapping( - "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP - ), -} -# Export rules: Megatron Core -> HuggingFace -qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = { - # Embeddings - "word_embeddings": 
NameRemapping("model.language_model.embed_tokens."), - # Final layer norm - "final_layernorm": NameRemapping("model.language_model.norm."), - # Output layer - "output_layer": NameRemapping("lm_head."), - # Attention - input layernorm - "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."), - # Attention - QKV projection (sliced back to separate q/k/v) - "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."), - # Attention - output projection - "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), - # Attention - Q/K layer norms - "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), - "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), - # MLP - pre-MLP layernorm - "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), - # MLP - linear_fc1 sliced back to gate_proj + up_proj - "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), - # MLP - down_proj - "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), - # MoE support - "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), - "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), - "local_experts.linear_fc2": NameRemapping( - "model.language_model.layers.{}.mlp.experts.{}.down_proj." - ), -} +def _with_language_model_prefix( + mapping: dict[str, CustomModuleMapping], +) -> dict[str, CustomModuleMapping]: + """Derive a VL mapping from a base Qwen3 mapping. + + Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to + ``model.language_model.``. Prefixes that do not start with + ``model.`` (e.g. ``lm_head.``) are left unchanged. + """ + result = {} + for key, m in mapping.items(): + prefix = m.target_name_or_prefix + if prefix.startswith("model."): + prefix = "model.language_model." 
+ prefix[len("model.") :] + result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=m.func_kwargs) + return result + + +qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import) +qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py index a9b8ddd5a0f..f5f62058bf4 100644 --- a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -239,52 +239,3 @@ def test_mlp_fc1_matching_prefix(self): imp = qwen3vl_causal_lm_import["linear_fc1"] exp = qwen3vl_causal_lm_export["linear_fc1"] assert imp.target_name_or_prefix == exp.target_name_or_prefix - - -class TestQwen3VLvsQwen3Difference: - """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix.""" - - def test_same_keys_as_qwen3(self): - from modelopt.torch.export.plugins.mcore_qwen import ( - qwen3_causal_lm_export, - qwen3_causal_lm_import, - ) - - assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3_causal_lm_import.keys()) - assert set(qwen3vl_causal_lm_export.keys()) == set(qwen3_causal_lm_export.keys()) - - @pytest.mark.parametrize( - "key", - [ - "word_embeddings", - "final_layernorm", - "input_layernorm", - "linear_qkv", - "linear_proj", - "q_layernorm", - "k_layernorm", - "pre_mlp_layernorm", - "linear_fc1", - "linear_fc2", - "router", - "local_experts.linear_fc1", - "local_experts.linear_fc2", - ], - ) - def test_vl_adds_language_model_prefix(self, key): - """Qwen3-VL should have 'language_model.' inserted after 'model.'.""" - from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import - - qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix - qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix - expected = qwen3_prefix.replace("model.", "model.language_model.", 1) - assert qwen3vl_prefix == expected, f"{key}: expected '{expected}', got '{qwen3vl_prefix}'" - - def test_output_layer_same(self): - """lm_head is at root level for both Qwen3 and Qwen3-VL.""" - from modelopt.torch.export.plugins.mcore_qwen import qwen3_causal_lm_import - - assert ( - qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix - == qwen3_causal_lm_import["output_layer"].target_name_or_prefix - )
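Reviewer's note (not part of the series): for readers inspecting the mapping tables without running a conversion, the short sketch below expands the `{}` layer placeholders in `qwen3vl_causal_lm_import` into concrete Hugging Face weight prefixes. It only reads the `target_name_or_prefix` attribute that the tests above already exercise and fills the placeholder with plain `str.format`; the real substitution and weight handling live in the `CustomModuleMapping` machinery in `mcore_custom` and the unified Megatron export path, so treat this purely as an illustration of the naming convention introduced by this series, not as the conversion API.

```python
# Illustration only: expand the per-layer placeholders in the Qwen3-VL import
# mapping into concrete HF weight prefixes. The actual conversion is driven by
# the CustomModuleMapping machinery in mcore_custom; this just mirrors the
# prefix layout for manual inspection.
from modelopt.torch.export.plugins.mcore_qwen3vl import qwen3vl_causal_lm_import

layer = 3  # arbitrary decoder-layer index, chosen for illustration

for mcore_key in ("word_embeddings", "linear_qkv", "linear_fc1", "output_layer"):
    prefix = qwen3vl_causal_lm_import[mcore_key].target_name_or_prefix
    # Per-layer entries carry one "{}" placeholder for the layer index;
    # global entries (embeddings, lm_head) have none.
    resolved = prefix.format(layer) if "{}" in prefix else prefix
    print(f"{mcore_key:>16} -> {resolved}")

# Expected output for a dense Qwen3-VL checkpoint:
#  word_embeddings -> model.language_model.embed_tokens.
#       linear_qkv -> model.language_model.layers.3.self_attn.
#       linear_fc1 -> model.language_model.layers.3.mlp.
#     output_layer -> lm_head.
```

The MoE entries (`local_experts.linear_fc1` / `linear_fc2`) carry a second `{}` for the expert index and follow the same pattern under `model.language_model.layers.{}.mlp.experts.{}.`.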