Skip to content

Add Qwen3VL MCore Export support from PR 895#1482

Open
hychiang-git wants to merge 4 commits into
mainfrom
hungyueh/pr-895
Open

Add Qwen3VL MCore Export support from PR 895#1482
hychiang-git wants to merge 4 commits into
mainfrom
hungyueh/pr-895

Conversation

@hychiang-git
Copy link
Copy Markdown
Contributor

@hychiang-git hychiang-git commented May 13, 2026

This PR is duplicated from PR #895. Since the original branch source is not available now, we create a new branch where we can update this PR.

What does this PR do?

new feature:

Overview: Add Qwen3-VL (Vision-Language) model support to the Megatron Core export/import
plugin, enabling HuggingFace-to-mcore weight conversion for PTQ/QAT/QAD workflows

Details

Qwen3-VL has a different weight structure from Qwen3 text-only models:

  • Language model weights are under model.language_model. prefix (not model.)
  • Visual encoder weights are under model.visual. prefix
  • The lm_head is at root level, not nested under language_model

This PR adds:

  • mcore_qwen3vl.py: Import/export weight mapping rules between HuggingFace
    Qwen3VLForConditionalGeneration and Megatron Core, handling the language_model prefix for
    all decoder layers, QKV merging/slicing, gated MLP merging/slicing, Q/K layer norms.
  • mcore_common.py: Registers Qwen3VLForConditionalGeneration in
    all_mcore_hf_export_mapping and all_mcore_hf_import_mapping.

Usage

  • Import Qwen3-VL from HuggingFace to MCore, and export the MCore model
#!/usr/bin/env python3
"""Minimal example: Load Qwen3-VL with visual encoder from HF + language model via mcore mapping.

This script demonstrates the two-step loading process for Qwen3-VL:
  1. Visual encoder: loaded from HuggingFace directly in Qwen3VLModel.__init__
  2. Language model: imported via import_mcore_gpt_from_hf (uses mcore_qwen3vl.py mapping)

Usage (single GPU):
    python load_qwen3vl_example.py

Usage (multi-GPU with TP=2):
    torchrun --nproc_per_node=2 load_qwen3vl_example.py \
        --tensor-model-parallel-size 2

Requirements:
    pip install torch "transformers>=4.45,<5" flash-attn nvidia-modelopt safetensors
"""

import os
import sys

import torch

# Add Megatron-LM to path (adjust as needed)
MEGATRON_PATH = os.environ.get(
    "MEGATRON_PATH",
    os.path.join(os.path.dirname(__file__), "Megatron-LM-public"),
)
if os.path.isdir(MEGATRON_PATH):
    sys.path.insert(0, MEGATRON_PATH)

# For single-GPU runs (no torchrun), set distributed env vars
if "MASTER_ADDR" not in os.environ:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29501"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"

# Qwen3-VL-8B architecture args — injected into sys.argv so Megatron's
# argparse picks them up correctly (args_defaults gets overridden).
QWEN3VL_8B_ARGS = [
    "--num-layers", "36",
    "--hidden-size", "4096",
    "--ffn-hidden-size", "12288",
    "--num-attention-heads", "32",
    "--group-query-attention",
    "--num-query-groups", "8",
    "--seq-length", "4096",
    "--max-position-embeddings", "32768",
    "--norm-epsilon", "1e-6",
    "--swiglu",
    "--bf16",
    "--untie-embeddings-and-output-weights",
    "--position-embedding-type", "rope",
    "--rotary-base", "1000000",
    "--normalization", "RMSNorm",
    "--qk-layernorm",
    "--disable-bias-linear",
    "--img-h", "384",
    "--img-w", "384",
    "--micro-batch-size", "1",
    "--tokenizer-type", "HuggingFaceTokenizer",
    "--tokenizer-model", "Qwen/Qwen3-VL-8B-Instruct",
    "--no-load-rng",
    "--no-load-optim",
    "--no-gradient-accumulation-fusion",
    "--padded-vocab-size", "151936",
]

# Only inject defaults for args not already on the command line
for arg in QWEN3VL_8B_ARGS:
    if arg.startswith("--") and arg not in sys.argv:
        idx = QWEN3VL_8B_ARGS.index(arg)
        # Check if next element is a value (not a flag)
        if idx + 1 < len(QWEN3VL_8B_ARGS) and not QWEN3VL_8B_ARGS[idx + 1].startswith("--"):
            sys.argv.extend([arg, QWEN3VL_8B_ARGS[idx + 1]])
        else:
            sys.argv.append(arg)


def main():
    # ── 1. Initialize Megatron distributed environment ──────────────────
    from megatron.core.enums import ModelType
    from megatron.training import get_args, get_model, initialize_megatron

    def extra_args(parser):
        group = parser.add_argument_group("Qwen3-VL example")
        group.add_argument(
            "--hf-model-name",
            type=str,
            default="Qwen/Qwen3-VL-8B-Instruct",
            help="HuggingFace model name or local path",
        )
        return parser

    initialize_megatron(
        extra_args_provider=extra_args,
        args_defaults={"no_load_rng": True, "no_load_optim": True},
    )
    args = get_args()

    # ── 2. Build Qwen3VLModel (visual encoder auto-loaded from HF) ─────
    #
    # Qwen3VLModel.__init__ does:
    #   self.visual = Qwen3VLVisionEncoder(hf_model_name=...)
    #       → calls AutoModelForVision2Seq.from_pretrained(hf_model_name)
    #       → extracts hf_model.visual, deletes the rest
    #       → visual encoder weights are loaded ✓
    #
    #   self.language_model = GPTModel(config, layer_spec, ...)
    #       → creates empty MCore GPTModel with random weights
    #       → language model weights NOT loaded yet ✗
    from copy import deepcopy

    from megatron.core import parallel_state
    from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
    from megatron.core.models.multimodal.qwen3_vl_model import Qwen3VLModel
    from megatron.training.arguments import core_transformer_config_from_args
    from megatron.training.utils import print_rank_0, unwrap_model

    language_config = core_transformer_config_from_args(args)
    # Local spec uses TorchNorm which doesn't support persist_layer_norm
    language_config.persist_layer_norm = False
    # Use local spec (not TE spec) for HF weight import compatibility.
    # TE spec fuses layernorm+linear into TELayerNormColumnParallelLinear
    # which expects different state_dict keys than plain HF weights.
    language_layer_spec = get_gpt_layer_local_spec(
        num_experts=args.num_experts,
        moe_grouped_gemm=args.moe_grouped_gemm,
        qk_layernorm=args.qk_layernorm,
    )

    vision_config = deepcopy(language_config)
    vision_config.context_parallel_size = 1

    def model_provider(
        pre_process=True, post_process=True, parallel_output=True,
        config=None, pg_collection=None, vp_stage=None,
    ):
        model = Qwen3VLModel(
            language_transformer_config=language_config,
            language_transformer_layer_spec=language_layer_spec,
            language_vocab_size=args.padded_vocab_size,
            language_max_sequence_length=args.max_position_embeddings,
            vision_transformer_config=vision_config,
            hf_model_name=args.hf_model_name,
            parallel_output=parallel_output,
            language_position_embedding_type=args.position_embedding_type,
            language_rotary_percent=args.rotary_percent,
            language_rotary_base=args.rotary_base,
            pre_process=pre_process,
            post_process=post_process,
            add_encoder=parallel_state.is_pipeline_first_stage(),
            add_decoder=True,
            img_h=args.img_h,
            img_w=args.img_w,
        )
        return model

    model = get_model(model_provider, ModelType.encoder_or_decoder, wrap_with_ddp=False)
    unwrapped = unwrap_model(model)[0]

    print_rank_0(f"Visual encoder : {type(unwrapped.visual)}")
    print_rank_0(f"Language model : {type(unwrapped.language_model)}")

    # At this point:
    #   unwrapped.visual         → loaded from HF ✓
    #   unwrapped.language_model → random weights ✗

    # ── 3. Import language model weights via mcore mapping ──────────────
    #
    # import_mcore_gpt_from_hf internally:
    #   1. Reads config.json → finds "Qwen3VLForConditionalGeneration"
    #   2. Looks up qwen3vl_causal_lm_import from mcore_qwen3vl.py
    #   3. Maps HF weight names to mcore names:
    #      "model.language_model.layers.{L}.self_attn.q/k/v_proj" → "linear_qkv" (QKVMerging)
    #      "model.language_model.layers.{L}.mlp.gate/up_proj"     → "linear_fc1" (GatedMLPMerging)
    #      "model.language_model.embed_tokens"                     → "word_embeddings"
    #      "model.language_model.norm"                              → "final_layernorm"
    #      "lm_head"                                                → "output_layer"
    from modelopt.torch.export import import_mcore_gpt_from_hf

    workspace_dir = os.environ.get("WORKSPACE_DIR", "/tmp/mcore_workspace")
    import_dtype = torch.bfloat16 if args.bf16 else torch.float16

    print_rank_0(f"Importing language model weights from {args.hf_model_name} ...")
    import_mcore_gpt_from_hf(
        model=unwrapped.language_model,  # only the GPTModel, NOT the full VLM
        pretrained_model_path=args.hf_model_name,
        workspace_dir=workspace_dir,
        dtype=import_dtype,
    )
    print_rank_0("Language model weights imported ✓")

    # Now both components are loaded:
    #   unwrapped.visual         → from HF directly ✓
    #   unwrapped.language_model → from HF via mcore mapping ✓

    # ── 4. Verify: print parameter stats ────────────────────────────────
    if unwrapped.visual is not None:
        visual_params = sum(p.numel() for p in unwrapped.visual.parameters())
        print_rank_0(f"Visual encoder params : {visual_params:,}")
    lm_params = sum(p.numel() for p in unwrapped.language_model.parameters())
    print_rank_0(f"Language model params : {lm_params:,}")
    total = sum(p.numel() for p in unwrapped.parameters())
    print_rank_0(f"Total params          : {total:,}")

    # ── 5. Export back to HF format ────────────────────────────────────
    #
    # Two-step process:
    #   Step A: Export language model weights via mcore mapping (reverse of step 3)
    #   Step B: Copy visual encoder weights from original HF checkpoint
    import shutil
    from glob import glob

    from modelopt.torch.export import export_mcore_gpt_to_hf
    from safetensors import safe_open
    from safetensors.torch import save_file

    export_dir = os.environ.get("EXPORT_DIR", "/tmp/qwen3vl-exported")
    os.makedirs(export_dir, exist_ok=True)

    # Step A: Export language model (mcore → HF weight names)
    # Internally uses qwen3vl_causal_lm_export from mcore_qwen3vl.py:
    #   "linear_qkv" → q_proj/k_proj/v_proj (QKVSlicing)
    #   "linear_fc1" → gate_proj/up_proj     (GatedMLPSlicing)
    #   etc.
    print_rank_0(f"Exporting language model to {export_dir} ...")
    export_mcore_gpt_to_hf(
        model=unwrapped.language_model,  # only the GPTModel
        pretrained_model_name_or_path=args.hf_model_name,
        export_dir=export_dir,
        dtype=torch.bfloat16,
    )
    print_rank_0("Language model exported ✓")

    # Step B: Copy visual encoder from original HF checkpoint
    # (Only rank 0 does file I/O)
    if torch.distributed.get_rank() == 0:
        # Resolve HF model to local cache path
        hf_local_path = args.hf_model_name
        if not os.path.isdir(hf_local_path):
            from huggingface_hub import snapshot_download
            hf_local_path = snapshot_download(
                repo_id=args.hf_model_name, local_files_only=True,
            )

        # Extract visual weights from original HF safetensors
        visual_state_dict = {}
        for sf_file in glob(os.path.join(hf_local_path, "*.safetensors")):
            with safe_open(sf_file, framework="pt", device="cpu") as f:
                for key in f.keys():
                    if key.startswith("visual") or key.startswith("model.visual"):
                        visual_state_dict[key] = f.get_tensor(key)

        if visual_state_dict:
            print_rank_0(f"Found {len(visual_state_dict)} visual tensors")

            # Load exported language model weights
            all_weights = {}
            for sf_file in glob(os.path.join(export_dir, "model*.safetensors")):
                with safe_open(sf_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        all_weights[key] = f.get_tensor(key)

            # Merge visual + language model weights
            all_weights.update(visual_state_dict)
            merged_file = os.path.join(export_dir, "model.safetensors")
            save_file(all_weights, merged_file)
            print_rank_0(f"Merged checkpoint saved to {merged_file}")

            # Clean up old shard files
            for sf_file in glob(os.path.join(export_dir, "model*.safetensors")):
                if sf_file != merged_file:
                    os.remove(sf_file)
            for json_file in glob(os.path.join(export_dir, "model-*.json")):
                os.remove(json_file)

        # Copy VLM-specific config files
        for fname in ["preprocessor_config.json", "processor_config.json",
                       "chat_template.json"]:
            src = os.path.join(hf_local_path, fname)
            dst = os.path.join(export_dir, fname)
            if os.path.exists(src) and not os.path.exists(dst):
                shutil.copy(src, dst)
                print_rank_0(f"Copied {fname}")

        # Print export size
        total_size = sum(
            os.path.getsize(os.path.join(dp, f))
            for dp, _, fns in os.walk(export_dir) for f in fns
        )
        if total_size >= 1024**3:
            size_str = f"{total_size / 1024**3:.2f} GB"
        else:
            size_str = f"{total_size / 1024**2:.2f} MB"
        print_rank_0(f"Export size: {size_str} ({export_dir})")

    print_rank_0("Done! Full VLM exported to HF format.")


if __name__ == "__main__":
    main()

Testing

  • Verified round-trip import/export with Qwen3-VL-8B-Instruct with the example usage above
  • Unit tests in tests/unit/torch/export/test_mcore_qwen3vl.py
    covering:
    • Registration in global export/import mappings
    • Import mapping: dense keys, model.language_model.
      prefix, lm_head. at root, QKVMerging, GatedMLPMerging, REPLICATE
      for layernorms, TP sharding configs
    • Export mapping: QKVSlicing, GatedMLPSlicing, no
      parallel_config
    • Import/export symmetry: same mcore keys, matching HF
      prefixes
    • Qwen3-VL vs Qwen3 difference: same keys, VL adds
      language_model. prefix, lm_head unchanged

Before your PR is "Ready for review"

  • Is this change backward compatible?: Yes, additive only
  • Did you write any new necessary tests?: Yes, tests/unit/torch/export/test_mcore_qwen3vl.py
  • Did you add or update any necessary documentation? Yes, see docs/source/deployment/3_unified_hf.rst
  • Did you update Changelog? Yes, see CHANGELOG.rst

Additional Information

Companion Megatron-LM PR adds Qwen3VLModel, Qwen3VLDataset, and pretrain_qwenvl.py. Please see this PR NVIDIA/Megatron-LM#3444

Summary by CodeRabbit

  • New Features

    • Added Megatron Core export/import mapping support for Qwen3‑VL, handling language‑model weight prefixes and supporting dense and MoE variants.
  • Documentation

    • Added Qwen3‑VL (FP8 / NVFP4) to the deployment support matrix for TensorRT‑LLM.
  • Tests

    • Added GPU tests verifying registration, import/export mapping symmetry, and expected layer transformation behavior for Qwen3‑VL.

Review Change Stack

@hychiang-git hychiang-git requested a review from a team as a code owner May 13, 2026 19:21
@hychiang-git hychiang-git requested a review from ChenhanYu May 13, 2026 19:21
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 13, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 13, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 88a99c87-0e1f-4dd2-84c7-cd84a6749211

📥 Commits

Reviewing files that changed from the base of the PR and between ff1152f and e8101a7.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_common.py
  • modelopt/torch/export/plugins/mcore_qwen3vl.py
  • tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/export/plugins/mcore_common.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • modelopt/torch/export/plugins/mcore_qwen3vl.py
  • tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py

📝 Walkthrough

Walkthrough

Adds bidirectional Megatron Core ↔ Hugging Face weight mappings for Qwen3-VL, registers them as a plugin, includes tests validating import/export symmetry and prefix rules, and updates changelog and deployment docs for Qwen 3‑VL support.

Changes

Qwen3-VL Megatron Core Integration

Layer / File(s) Summary
Mapping Module Definition
modelopt/torch/export/plugins/mcore_qwen3vl.py
New module introduces qwen3vl_causal_lm_import and qwen3vl_causal_lm_export dictionaries describing HF↔Megatron Core weight conversion, including prefix adjustments (model.language_model. vs root lm_head.), QKV/MLP merged-projection handling, and MoE expert slicing/routing.
Plugin Registration and Wiring
modelopt/torch/export/plugins/mcore_common.py
Imports the new Qwen3-VL mapping functions and registers Qwen3VLForConditionalGeneration in both global export and import mapping dictionaries, wiring the model type to the new conversion handlers.
Test Suite and Validation
tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py
Comprehensive test suite with registration checks, key-presence validation for dense and MoE parameters, prefix behavior verification, transformation type checks for QKV/MLP, export-mapping func_kwargs checks, symmetry validation between import/export, and comparative assertions against Qwen3.
Documentation Updates
CHANGELOG.rst, docs/source/deployment/3_unified_hf.rst
Changelog entry documenting Qwen3-VL Megatron Core support and deployment documentation noting TensorRT-LLM compatibility with FP8 and NVFP4 quantization formats.

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (5 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately summarizes the main change: adding Qwen3VL MCore export support, which is directly reflected in the new module, mappings, tests, and documentation updates.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Security Anti-Patterns ✅ Passed No security anti-patterns found. All PR files reviewed against SECURITY.md with no issues detected.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
📝 Generate docstrings
  • Create stacked PR
  • Commit on current branch
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch hungyueh/pr-895

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@CHANGELOG.rst`:
- Line 136: Move the changelog entry "Add Megatron Core export/import mapping
for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models..."
out of the 0.42 (2026-03-10) section and place it under the current
unreleased/0.45 section header in CHANGELOG.rst, preserving the existing
formatting and inline code markup; ensure you remove the duplicate from 0.42 and
verify the entry appears exactly once under the 0.45 (unreleased) section.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 7f30ba6d-8de4-4386-b197-7f8189e56d24

📥 Commits

Reviewing files that changed from the base of the PR and between 62401e1 and 2423ae7.

📒 Files selected for processing (5)
  • CHANGELOG.rst
  • docs/source/deployment/3_unified_hf.rst
  • modelopt/torch/export/plugins/mcore_common.py
  • modelopt/torch/export/plugins/mcore_qwen3vl.py
  • tests/gpu_megatron/torch/export/test_mcore_qwen3vl.py

Comment thread CHANGELOG.rst Outdated
Comment thread CHANGELOG.rst Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented May 13, 2026

PR Preview Action v1.8.1

QR code for preview link

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1482/

Built to branch gh-pages at 2026-05-13 23:19 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@hychiang-git
Copy link
Copy Markdown
Contributor Author

/claude review

@claude
Copy link
Copy Markdown

claude Bot commented May 13, 2026

Claude Review Summary

Small, additive PR that clones the Qwen3 mcore mapping with model.model.language_model. substitution and registers Qwen3VLForConditionalGeneration. Mechanically straightforward; the risk is concentrated in a couple of places.

Findings

  • CRITICAL: 0
  • IMPORTANT: 2
  • SUGGESTION: 1

Most impactful

  1. MoE arch not registered (mcore_qwen3vl.py) — the file ships MoE rules (router, local_experts.*) and the comment claims support for "Qwen3-VL MoE variants like 30B-A3B", but only the dense Qwen3VLForConditionalGeneration arch is wired up in mcore_common.py. Either add the MoE arch entry (mirroring Qwen3MoeForCausalLM) or drop the MoE rules to avoid dead code that implies unsupported behavior.

  2. lm_head placement — root-level lm_head. is inherited from mcore_qwen.py, but for several recent *ForConditionalGeneration VLMs in transformers, lm_head lives at model.language_model.lm_head.. If that holds for the Qwen3-VL release you target, import silently misses the tensor and export writes to the wrong key. PR description says round-trip was verified, so this may be fine — but worth a one-time safe_open(...).keys() confirmation.

  3. Test placementtest_mcore_qwen3vl.py is dict-shape inspection only, doesn't need GPU or Megatron, but lives in tests/gpu_megatron/. Belongs in tests/unit/torch/export/ so it runs in the fast pre-merge lane (matches what the PR description said). Also note these tests assert "the dict we wrote equals the dict we wrote" — a small integration test against a real HF state-dict snapshot would catch the lm_head issue above.

Risk: Low-to-moderate. Code is purely additive, no existing arch behavior changes. Worst case is a broken Qwen3-VL round-trip that only manifests at runtime — which is exactly why the test placement matters.

"local_experts.linear_fc2": NameRemapping(
"model.language_model.layers.{}.mlp.experts.{}.down_proj."
),
} No newline at end of file
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Compatibility] Only Qwen3VLForConditionalGeneration is registered in mcore_common.py, but this mapping intentionally includes MoE rules (router, local_experts.linear_fc1, local_experts.linear_fc2) "for Qwen3-VL MoE variants like 30B-A3B" (see line 81). Following the Qwen3 precedent (Qwen3ForCausalLM + Qwen3MoeForCausalLM are both registered to the same dict in mcore_common.py:55-56), the MoE arch (likely Qwen3VLMoeForConditionalGeneration) needs its own entry — otherwise the dense arch lookup will fail and MoE checkpoints won't dispatch here. Either:

  1. Add "Qwen3VLMoeForConditionalGeneration": qwen3vl_causal_lm_{import,export} to mcore_common.py, or
  2. If MoE support is out of scope for this PR, drop the MoE-only entries (router, local_experts.*) here and the comment on line 81 to avoid implying support that isn't wired up.

Right now the MoE rules are dead code.


from modelopt.torch.export.plugins.mcore_custom import (
COL_TP,
REPLICATE,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[SUGGESTION] This test file only inspects the static dict mappings — no GPU, no Megatron-Core, and no actual model are required to run it. mcore_custom.py and mcore_qwen3vl.py import megatron lazily via import_plugin("megatron") so the imports work without it. Placing this in tests/gpu_megatron/ means the only signal protecting these mappings runs in the most expensive CI lane. Consider moving to tests/unit/torch/export/test_mcore_qwen3vl.py (which is what the PR description originally said) so the registration/symmetry checks run in the fast pre-merge lane.

Note also that what's tested here is essentially "the dict literal we wrote matches the dict literal we wrote" — these checks won't catch the real correctness risk, which is whether the HF key layout (especially lm_head. at root vs. model.language_model.lm_head.) actually matches the published Qwen3-VL checkpoint. A small integration test that loads a tiny config and asserts no missing/unexpected keys against an HF state-dict snapshot would be far more valuable.

# Final layer norm
"final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE),
# Output layer (lm_head is at root level, not under language_model)
"output_layer": NameRemapping("lm_head.", COL_TP),
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[IMPORTANT Export] Worth double-checking against the published Qwen3-VL checkpoint: in recent transformers (≥4.45), several *ForConditionalGeneration VLMs (including the Qwen2.5-VL / Qwen3-VL families) moved lm_head into the inner language model — i.e. the safetensors key is model.language_model.lm_head.weight, not lm_head.weight at root. If that's the case for the Qwen3-VL release you're targeting, both output_layer mappings (here and line 98) will silently fail to find the tensor on import and write to the wrong location on export, and tie_word_embeddings interaction will also be off.

The PR description says you've round-tripped Qwen3-VL-8B-Instruct, so this may already be verified — but the Qwen3 mapping (mcore_qwen.py:35) inherited a root-level lm_head. from a different architecture pattern, and copying it without checking is the most likely place this PR could be wrong. Worth grepping the actual safetensors keys (safe_open(...).keys()) once and confirming.

@codecov
Copy link
Copy Markdown

codecov Bot commented May 13, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.77%. Comparing base (229ba61) to head (e8101a7).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1482   +/-   ##
=======================================
  Coverage   76.77%   76.77%           
=======================================
  Files         473      474    +1     
  Lines       51418    51422    +4     
=======================================
+ Hits        39476    39481    +5     
+ Misses      11942    11941    -1     
Flag Coverage Δ
unit 52.55% <100.00%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Add Megatron Core export/import mapping for Qwen3-VL
(Qwen3VLForConditionalGeneration). Handles the model.language_model.
weight prefix and supports both dense and MoE variants.

Signed-off-by: Hung-Yueh <hungyuehc@nvidia.com>

mv test_mcore_qwen3vl.py to tests/gpu_megatron/torch/export/

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
hychiang-git and others added 3 commits May 13, 2026 21:39
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants