1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -24,6 +24,7 @@ Changelog
- Add support for ``active_params`` (for MoE models) and ``memory_mb`` constraints in Minitron pruning on top of the existing ``params`` constraint. You can also provide multiple constraints. See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details. The underlying utility functions ``mcore_param_count``, ``mcore_memory_footprint_mb``, and ``print_mcore_model_stats`` in ``modelopt.torch.nas.plugins.megatron_model_stats`` are also available for standalone use to compute parameter counts and memory footprints (weights + KV-cache + Mamba state) for any Megatron-Core model.
- Add ``--cast_mxfp4_to_nvfp4`` flag to ``examples/llm_ptq/hf_ptq.py`` for closed-form, bit-exact MXFP4 → NVFP4 weight conversion. Supports the GPT-OSS family (``openai/gpt-oss-20b``, ``openai/gpt-oss-120b``). See `examples/llm_ptq/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/llm_ptq#mxfp4--nvfp4-cast-for-gpt-oss>`__ for usage.
- DeepSeek PTQ (``examples/deepseek/ptq.py``) now defaults to native top-k calibration with post-hoc per-layer peer-max sync of expert ``input_quantizer.amax``; the all-experts path is preserved behind ``--calib_all_experts``.
- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL and supports both dense and MoE variants.

0.44 (2026-05-18)
^^^^^^^^^^^^^^^^^
1 change: 1 addition & 0 deletions docs/source/deployment/3_unified_hf.rst
@@ -61,6 +61,7 @@ Models:
* Llama 4, 3.x (FP8, NVFP4)
* Qwen 3, 2.5 (FP8, NVFP4)
* Qwen 3 MoE (FP8, NVFP4)
* Qwen 3-VL (FP8, NVFP4)
* Deepseek R1/V3 (NVFP4)
* Mixtral 8x7B (FP8, NVFP4)
* Medusa (FP8)
3 changes: 3 additions & 0 deletions modelopt/torch/export/plugins/mcore_common.py
@@ -39,6 +39,7 @@
qwen25_causal_lm_export,
qwen25_causal_lm_import,
)
from .mcore_qwen3vl import qwen3vl_causal_lm_export, qwen3vl_causal_lm_import

all_mcore_hf_export_mapping: dict[str, Any] = {
"DeepseekV2ForCausalLM": deepseek_causal_lm_export,
@@ -54,6 +55,7 @@
"Qwen3MoeForCausalLM": qwen3_causal_lm_export,
"Qwen2ForCausalLM": qwen25_causal_lm_export,
"GptOssForCausalLM": gptoss_causal_lm_export,
"Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export,
}

all_mcore_hf_import_mapping: dict[str, Any] = {
@@ -66,4 +68,5 @@
"Qwen3MoeForCausalLM": qwen3_causal_lm_import,
"Qwen2ForCausalLM": qwen25_causal_lm_import,
"GptOssForCausalLM": gptoss_causal_lm_import,
"Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import,
}
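
For reference, these registries are keyed by the Hugging Face architecture string, so selecting the new Qwen3-VL mapping at runtime reduces to a dict lookup. A minimal sketch, assuming the module path matches the file location above and using a hypothetical model id:

```python
# Sketch only: resolve the import mapping from a checkpoint's declared
# architecture. AutoConfig.architectures is a list of class names such as
# ["Qwen3VLForConditionalGeneration"].
from transformers import AutoConfig

from modelopt.torch.export.plugins.mcore_common import all_mcore_hf_import_mapping

config = AutoConfig.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")  # hypothetical id
arch = config.architectures[0]

mapping = all_mcore_hf_import_mapping.get(arch)
if mapping is None:
    raise KeyError(f"no Megatron-Core import mapping registered for {arch}")
```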
120 changes: 120 additions & 0 deletions modelopt/torch/export/plugins/mcore_qwen3vl.py
@@ -0,0 +1,120 @@
# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models.

The Qwen3-VL model structure differs from Qwen3's:
- Language model weights are under `model.language_model.` prefix
- Visual encoder weights are under `model.visual.` prefix

This module handles the language model conversion for PTQ/QAT workflows.
Visual components are typically kept in full precision.

HuggingFace Qwen3-VL-8B structure:
- model.language_model.embed_tokens.weight
- model.language_model.layers.{L}.input_layernorm.weight
- model.language_model.layers.{L}.self_attn.q_proj.weight
- model.language_model.layers.{L}.self_attn.k_proj.weight
- model.language_model.layers.{L}.self_attn.v_proj.weight
- model.language_model.layers.{L}.self_attn.q_norm.weight
- model.language_model.layers.{L}.self_attn.k_norm.weight
- model.language_model.layers.{L}.self_attn.o_proj.weight
- model.language_model.layers.{L}.post_attention_layernorm.weight
- model.language_model.layers.{L}.mlp.gate_proj.weight
- model.language_model.layers.{L}.mlp.up_proj.weight
- model.language_model.layers.{L}.mlp.down_proj.weight
- model.language_model.norm.weight
- lm_head.weight
"""

from .mcore_custom import (
COL_ETP,
COL_TP,
REPLICATE,
ROW_ETP,
ROW_TP,
CustomModuleMapping,
GatedMLPMerging,
GatedMLPSlicing,
NameRemapping,
QKVMerging,
QKVSlicing,
)

# Import rules: HuggingFace -> Megatron Core
qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = {
# Embeddings - note the language_model prefix
"word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP),
# Final layer norm
"final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE),
# Output layer (lm_head is at root level, not under language_model)
"output_layer": NameRemapping("lm_head.", COL_TP),
Inline review comment on the output_layer mapping:
[IMPORTANT Export] Worth double-checking against the published Qwen3-VL checkpoint: in recent transformers (≥4.45), several *ForConditionalGeneration VLMs (including the Qwen2.5-VL / Qwen3-VL families) moved lm_head into the inner language model — i.e. the safetensors key is model.language_model.lm_head.weight, not lm_head.weight at root. If that's the case for the Qwen3-VL release you're targeting, both output_layer mappings (here and line 98) will silently fail to find the tensor on import and write to the wrong location on export, and tie_word_embeddings interaction will also be off.

The PR description says you've round-tripped Qwen3-VL-8B-Instruct, so this may already be verified — but the Qwen3 mapping (mcore_qwen.py:35) inherited a root-level lm_head. from a different architecture pattern, and copying it without checking is the most likely place this PR could be wrong. Worth grepping the actual safetensors keys (safe_open(...).keys()) once and confirming.
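
A sketch of that check (the checkpoint path is hypothetical; for a sharded checkpoint the index file lists every tensor name, while safetensors.safe_open(...).keys() covers a single shard):

```python
# Sketch only: confirm where lm_head.weight actually lives in the published
# Qwen3-VL safetensors checkpoint before trusting the root-level mapping.
import json

with open("Qwen3-VL-8B-Instruct/model.safetensors.index.json") as f:
    keys = set(json.load(f)["weight_map"])

print("root lm_head:          ", "lm_head.weight" in keys)
print("nested lm_head:        ", "model.language_model.lm_head.weight" in keys)
print("tied embeddings (none):", not any(k.endswith("lm_head.weight") for k in keys))
```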

# Attention - input layernorm
"input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE),
# Attention - QKV projection (merged)
"linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP),
# Attention - output projection
"linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP),
# Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K)
"q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE),
"k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE),
# MLP - pre-MLP layernorm (post_attention_layernorm in HF)
"pre_mlp_layernorm": NameRemapping(
"model.language_model.layers.{}.post_attention_layernorm.", REPLICATE
),
# MLP - gate_proj + up_proj merged into linear_fc1
"linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP),
# MLP - down_proj as linear_fc2
"linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP),
# MoE support (for Qwen3-VL MoE variants like 30B-A3B)
"router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE),
"local_experts.linear_fc1": GatedMLPMerging(
"model.language_model.layers.{}.mlp.experts.{}.", COL_ETP
),
"local_experts.linear_fc2": NameRemapping(
"model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP
),
}

# Export rules: Megatron Core -> HuggingFace
qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = {
# Embeddings
"word_embeddings": NameRemapping("model.language_model.embed_tokens."),
# Final layer norm
"final_layernorm": NameRemapping("model.language_model.norm."),
# Output layer
"output_layer": NameRemapping("lm_head."),
# Attention - input layernorm
"input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."),
# Attention - QKV projection (sliced back to separate q/k/v)
"linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."),
# Attention - output projection
"linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj."),
# Attention - Q/K layer norms
"q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."),
"k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."),
# MLP - pre-MLP layernorm
"pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."),
# MLP - linear_fc1 sliced back to gate_proj + up_proj
"linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."),
# MLP - down_proj
"linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."),
# MoE support
"router": NameRemapping("model.language_model.layers.{}.mlp.gate."),
"local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."),
"local_experts.linear_fc2": NameRemapping(
"model.language_model.layers.{}.mlp.experts.{}.down_proj."
),
}
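
For intuition about the GatedMLPMerging / GatedMLPSlicing pairs above, here is a conceptual round-trip sketch. It assumes the plain tensor-parallel-size-1 layout in which Megatron-Core's linear_fc1 stacks gate_proj on top of up_proj; the real helpers also deal with TP sharding, which this ignores:

```python
import torch

hidden, ffn = 4096, 12288
gate = torch.randn(ffn, hidden)  # HF mlp.gate_proj.weight
up = torch.randn(ffn, hidden)    # HF mlp.up_proj.weight

# Import direction (GatedMLPMerging): merge gate/up into one linear_fc1 weight.
fc1 = torch.cat([gate, up], dim=0)  # shape (2 * ffn, hidden)

# Export direction (GatedMLPSlicing): slice linear_fc1 back apart.
gate_out, up_out = torch.split(fc1, ffn, dim=0)
assert torch.equal(gate_out, gate) and torch.equal(up_out, up)
```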