diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 62f2b0041cb..8fd414ec5fc 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,7 @@ Changelog - Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md `_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model. - Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md `__ for usage. - DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``. +- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants. 0.44 (2026-05-18) ^^^^^^^^^^^^^^^^^ diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst index 9124164b576..6664f987f72 100644 --- a/docs/source/deployment/3_unified_hf.rst +++ b/docs/source/deployment/3_unified_hf.rst @@ -61,6 +61,7 @@ Models: * Llama 4, 3.x (FP8, NVFP4) * Qwen 3, 2.5 (FP8, NVFP4) * Qwen 3 MoE (FP8, NVFP4) + * Qwen 3-VL (FP8, NVFP4) * Deepseek R1/V3 (NVFP4) * Mixtral 8x7B (FP8, NVFP4) * Medusa (FP8) diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py index d5bab9b4ece..15395b7a1e5 100644 --- a/modelopt/torch/export/plugins/mcore_common.py +++ b/modelopt/torch/export/plugins/mcore_common.py @@ -39,6 +39,7 @@ qwen25_causal_lm_export, qwen25_causal_lm_import, ) +from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import all_mcore_hf_export_mapping: dict[str, Any] = { "DeepseekV2ForCausalLM": deepseek_causal_lm_export, @@ -54,6 +55,7 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_export, "Qwen2ForCausalLM": qwen25_causal_lm_export, "GptOssForCausalLM": gptoss_causal_lm_export, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export, } all_mcore_hf_import_mapping: dict[str, Any] = { @@ -66,4 +68,5 @@ "Qwen3MoeForCausalLM": qwen3_causal_lm_import, "Qwen2ForCausalLM": qwen25_causal_lm_import, "GptOssForCausalLM": gptoss_causal_lm_import, + "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import, } diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py new file mode 100644 index 00000000000..2f35b1291e0 --- /dev/null +++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py @@ -0,0 +1,53 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models. + +Qwen3-VL differs from Qwen3 in one structural way: language-model weights live +under ``model.language_model.`` instead of ``model.``, while ``lm_head.weight`` +remains at the root level. The mappings below are derived automatically from +the Qwen3 mappings by inserting ``language_model.`` after ``model.`` for every +prefix that starts with ``model.``. + +Note: the visual encoder (``model.visual.*``) is intentionally excluded — this +mapping covers only the language-model decoder used for quantization and export. + +Reference: https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct/blob/main/model.safetensors.index.json +""" + +from .mcore_custom import CustomModuleMapping +from .mcore_qwen import qwen3_causal_lm_export, qwen3_causal_lm_import + + +def _with_language_model_prefix( + mapping: dict[str, CustomModuleMapping], +) -> dict[str, CustomModuleMapping]: + """Derive a VL mapping from a base Qwen3 mapping. + + Rewrites every ``target_name_or_prefix`` that starts with ``model.`` to + ``model.language_model.``. Prefixes that do not start with + ``model.`` (e.g. ``lm_head.``) are left unchanged. + """ + result = {} + for key, m in mapping.items(): + prefix = m.target_name_or_prefix + if prefix.startswith("model."): + prefix = "model.language_model." + prefix[len("model.") :] + result[key] = type(m)(target_name_or_prefix=prefix, func_kwargs=m.func_kwargs) + return result + + +qwen3vl_causal_lm_import = _with_language_model_prefix(qwen3_causal_lm_import) +qwen3vl_causal_lm_export = _with_language_model_prefix(qwen3_causal_lm_export) diff --git a/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py new file mode 100644 index 00000000000..f5f62058bf4 --- /dev/null +++ b/tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py @@ -0,0 +1,241 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for Qwen3-VL Megatron Core export/import plugin.""" + +import pytest + +from modelopt.torch.export.plugins.mcore_custom import ( + COL_TP, + REPLICATE, + ROW_TP, + GatedMLPMerging, + GatedMLPSlicing, + NameRemapping, + QKVMerging, + QKVSlicing, +) +from modelopt.torch.export.plugins.mcore_qwen3vl import ( + qwen3vl_causal_lm_export, + qwen3vl_causal_lm_import, +) + +# All mcore keys that a dense (non-MoE) Qwen3-VL model should have +DENSE_MCORE_KEYS = { + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", +} + +# Additional MoE keys +MOE_MCORE_KEYS = { + "router", + "local_experts.linear_fc1", + "local_experts.linear_fc2", +} + + +class TestQwen3VLRegistration: + """Test that Qwen3-VL is registered in the global mapping.""" + + def test_registered_in_export_mapping(self): + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_export_mapping + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping + assert ( + all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_export + ) + + def test_registered_in_import_mapping(self): + from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_import_mapping + + assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping + assert ( + all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"] + is qwen3vl_causal_lm_import + ) + + +class TestQwen3VLImportMapping: + """Test the HuggingFace -> Megatron Core import mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys()) + + def test_language_model_prefix(self): + """Qwen3-VL uses model.language_model. prefix (not model.).""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_import[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + """lm_head is at root level, not under language_model.""" + mapping = qwen3vl_causal_lm_import["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_merging(self): + assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging) + + def test_mlp_uses_gated_merging(self): + assert isinstance(qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging) + + @pytest.mark.parametrize( + "key", + [ + "input_layernorm", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "final_layernorm", + ], + ) + def test_layernorms_are_replicated(self, key): + """Layernorms should use REPLICATE (empty func_kwargs).""" + mapping = qwen3vl_causal_lm_import[key] + assert isinstance(mapping, NameRemapping) + assert mapping.func_kwargs == REPLICATE + + @pytest.mark.parametrize( + ("key", "expected_kwargs"), + [ + ("word_embeddings", COL_TP), + ("output_layer", COL_TP), + ("linear_proj", ROW_TP), + ], + ) + def test_tp_sharding(self, key, expected_kwargs): + mapping = qwen3vl_causal_lm_import[key] + assert mapping.func_kwargs == expected_kwargs + + +class TestQwen3VLExportMapping: + """Test the Megatron Core -> HuggingFace export mapping.""" + + def test_has_all_dense_keys(self): + assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_has_all_moe_keys(self): + assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys()) + + def test_language_model_prefix(self): + """Export paths should also use model.language_model. prefix.""" + prefix_keys = [ + "word_embeddings", + "final_layernorm", + "input_layernorm", + "linear_qkv", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc1", + "linear_fc2", + ] + for key in prefix_keys: + mapping = qwen3vl_causal_lm_export[key] + assert "model.language_model." in mapping.target_name_or_prefix, ( + f"{key}: expected 'model.language_model.' prefix, " + f"got '{mapping.target_name_or_prefix}'" + ) + + def test_output_layer_at_root(self): + mapping = qwen3vl_causal_lm_export["output_layer"] + assert mapping.target_name_or_prefix == "lm_head." + + def test_qkv_uses_slicing(self): + assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing) + + def test_mlp_uses_gated_slicing(self): + assert isinstance(qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing) + + def test_export_has_no_parallel_config(self): + """Export mappings should not have parallel configs.""" + for key in [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + ]: + mapping = qwen3vl_causal_lm_export[key] + assert "parallel_config" not in mapping.func_kwargs + + +class TestQwen3VLImportExportSymmetry: + """Test that import and export mappings are consistent.""" + + def test_same_mcore_keys(self): + assert set(qwen3vl_causal_lm_import.keys()) == set(qwen3vl_causal_lm_export.keys()) + + @pytest.mark.parametrize( + "key", + [ + "word_embeddings", + "final_layernorm", + "output_layer", + "input_layernorm", + "linear_proj", + "q_layernorm", + "k_layernorm", + "pre_mlp_layernorm", + "linear_fc2", + "router", + ], + ) + def test_matching_hf_prefixes(self, key): + """Import and export should map to the same HF prefix.""" + imp = qwen3vl_causal_lm_import[key] + exp = qwen3vl_causal_lm_export[key] + assert imp.target_name_or_prefix == exp.target_name_or_prefix, ( + f"{key}: import prefix '{imp.target_name_or_prefix}' != " + f"export prefix '{exp.target_name_or_prefix}'" + ) + + def test_qkv_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_qkv"] + exp = qwen3vl_causal_lm_export["linear_qkv"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix + + def test_mlp_fc1_matching_prefix(self): + imp = qwen3vl_causal_lm_import["linear_fc1"] + exp = qwen3vl_causal_lm_export["linear_fc1"] + assert imp.target_name_or_prefix == exp.target_name_or_prefix