From b33e047f9fb7df768a0f473cd969a26401591fd4 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 21:02:04 +0000 Subject: [PATCH 1/8] [Fix] Forward rms_norm_type to DenseDecoderLayer in Dense.build_layers Dense.build_layers omitted rms_norm_type when constructing DenseDecoderLayer, so a zero_centered dense model silently used the default RMSNorm for the per-layer input/post norms. MoE.build_layers already forwards it. No-op for existing dense models, which use the default rms_norm_type. --- xtuner/v1/model/dense/dense.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 08e7dce51..247009c80 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -149,6 +149,7 @@ def build_layers(self, config: TransformerConfig) -> nn.ModuleDict: mlp_bias=config.mlp_bias, hidden_act=config.hidden_act, rms_norm_eps=config.rms_norm_eps, + rms_norm_type=config.rms_norm_type, attention_config=attention_config, generate_config=config.generate_config, rope_scaling_cfg=config.rope_scaling_cfg, From d410b24d4ee61839cb6c84b1006c83a92bf9f2d2 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 21:02:15 +0000 Subject: [PATCH 2/8] [Feature] Add Qwen3.5-4B dense VLM support Qwen3.5-4B is a dense VLM whose text tower is a hybrid of GatedDeltaNet linear attention and gated MHA full attention (every 4th layer), reusing the existing Dense body's per-layer attention dispatch. Adds the dense text tower (Qwen3_5_VLTextDense + config), the Qwen3_5_VLDense4BConfig compose config (reusing the existing Qwen3.5 vision/projector towers), and registration. MTP is deferred: the checkpoint's mtp.* keys are skipped on load (matching HF) and not re-saved. See docs/design/model/qwen3_5_dense_4b.md. 
--- ci/config/qwen3_5_dense4B.py | 73 +++++++++ docs/design/model/qwen3_5_dense_4b.md | 152 ++++++++++++++++++ xtuner/v1/model/__init__.py | 4 +- xtuner/v1/model/compose/qwen3_5/__init__.py | 3 +- .../model/compose/qwen3_5/qwen3_5_config.py | 24 ++- .../model/compose/qwen3_vl/modeling_vision.py | 5 + xtuner/v1/model/dense/qwen3_5_text.py | 98 +++++++++++ 7 files changed, 351 insertions(+), 8 deletions(-) create mode 100644 ci/config/qwen3_5_dense4B.py create mode 100644 docs/design/model/qwen3_5_dense_4b.md create mode 100644 xtuner/v1/model/dense/qwen3_5_text.py diff --git a/ci/config/qwen3_5_dense4B.py b/ci/config/qwen3_5_dense4B.py new file mode 100644 index 000000000..1a223438c --- /dev/null +++ b/ci/config/qwen3_5_dense4B.py @@ -0,0 +1,73 @@ +# Smoke SFT config for Qwen3.5-VL Dense 4B's text tower — drives the full +# from_hf -> model.forward -> loss -> backward -> optimizer step -> FSDP shard/reduce +# chain on a real checkpoint so the port can be verified end-to-end. Loss should drop +# monotonically into a plausible SFT range within the first ~50 steps. +# +# Usage (single node, 8 GPUs): +# +# export QWEN3_5_DENSE_4B_PATH=/path/to/Qwen3.5-4B +# export ALPACA_PATH=/path/to/alpaca +# torchrun --nproc-per-node=8 -m xtuner.v1.train.cli.sft --config ci/config/qwen3_5_dense4B.py +# +# This smokes the text-only path on alpaca (matches the convention of +# ``ci/config/qwen3_moe_30BA3.py``); the full VLM compose forward — image embed +# injection through ``_prepare_llm_inputs`` — is owned by +# ``tests/model/test_qwen3_5_dense.py::test_model_forward_bitwise_reduced_layers``. 
+import os + +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.datasets import FTDPTokenizeFnConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.loss.ce_loss import CELossConfig +from xtuner.v1.model.dense.qwen3_5_text import Qwen3_5_VLTextDense4BConfig +from xtuner.v1.train import TrainerConfig + + +QWEN3_5_DENSE_4B_PATH = os.environ["QWEN3_5_DENSE_4B_PATH"] +ALPACA_PATH = os.environ["ALPACA_PATH"] + + +model_cfg = Qwen3_5_VLTextDense4BConfig() +optim_cfg = AdamWConfig(lr=6e-05) +lr_cfg = LRConfig(lr_type="cosine", lr_min=1e-6) +fsdp_cfg = FSDPConfig( + cpu_offload=False, +) + +dataset_config = [ + { + "dataset": DatasetConfig(name="alpaca", anno_path=ALPACA_PATH, sample_ratio=1.0), + "tokenize_fn": FTDPTokenizeFnConfig(max_length=16384), + }, +] + +dataloader_config = DataloaderConfig( + pack_max_length=16384, +) + +loss_cfg = CELossConfig() + + +trainer = TrainerConfig( + load_from=QWEN3_5_DENSE_4B_PATH, + model_cfg=model_cfg, + optim_cfg=optim_cfg, + fsdp_cfg=fsdp_cfg, + dataset_cfg=dataset_config, + dataloader_cfg=dataloader_config, + lr_cfg=lr_cfg, + loss_cfg=loss_cfg, + tokenizer_path=QWEN3_5_DENSE_4B_PATH, + global_batch_size=16, + total_epoch=1, + work_dir="/tmp/qwen3_5_dense4B", + seed=0, + # XTuner defers Qwen3.5 MTP weights, so the ckpt has ~15 ``mtp.*`` keys with no + # matching XTuner params; loosen strict load to skip them. Same reasoning as the + # save_hf round-trip test scoping out ``mtp.*``. + strict_load=False, +) diff --git a/docs/design/model/qwen3_5_dense_4b.md b/docs/design/model/qwen3_5_dense_4b.md new file mode 100644 index 000000000..bee88d1e7 --- /dev/null +++ b/docs/design/model/qwen3_5_dense_4b.md @@ -0,0 +1,152 @@ +# Qwen3.5-4B (dense VLM) integration design + +## 1. What the model is + +`Qwen/Qwen3.5-4B` — `model_type: "qwen3_5"`, architecture +`Qwen3_5ForConditionalGeneration`. 
A **vision-language model** whose text tower is +a **dense hybrid** of linear (GatedDeltaNet) and full (gated MHA) attention. + +| Tower | Shape | +|-------|-------| +| Vision (`model.visual.*`) | Qwen3VL-style: `patch_embed.proj`, `pos_embed`, `blocks.{N}` (attn qkv/proj, mlp linear_fc1/fc2, norm1/norm2), `merger`. depth 24, hidden 1024, `out_hidden_size` 2560, `num_position_embeddings` 2304, `deepstack_visual_indexes: []` | +| Projector (`model.visual.merger`) | vision_hidden 1024 → text_hidden 2560 | +| Text (`model.language_model.*`) | dense, 32 layers, hidden 2560, head_dim 256, 16 q-heads, 4 kv-heads, `intermediate_size` 9216 | + +Text per-layer attention follows `full_attention_interval = 4`: layers where +`(i+1) % 4 == 0` (3, 7, 11, …) are `full_attention`; the rest are +`linear_attention`. + +- **full_attention layer**: gated MHA — `q_proj` emits `head_dim*2` then chunks + into `(query, gate)`, and the attention output is multiplied by + `sigmoid(gate)` before `o_proj`. Has `q_norm`/`k_norm`. **No sliding window** + (global attention). → maps to `MHAConfig(with_gate=True, qk_norm=True)`. +- **linear_attention layer**: `Qwen3_5GatedDeltaNet` — `in_proj_qkv`, + `in_proj_z`, `in_proj_b`, `in_proj_a`, depthwise `conv1d`, `A_log`, `dt_bias`, + gated RMSNorm `norm`, `out_proj`. → maps to `GatedDeltaNetConfig`. +- **RoPE**: `mrope_interleaved: true`, `mrope_section: [11, 11, 10]`, + `rope_type: "default"`, `rope_theta: 1e7`, `partial_rotary_factor: 0.25`. +- **MTP**: 1 dense full-attention MTP layer in the checkpoint (15 `mtp.*` + weights). **Deferred** (see §4). +- `tie_word_embeddings: true`. + +## 2. Reuse map — what already exists + +The MoE sibling (`Qwen/Qwen3.5-…-A3B`) is **already ported**, so almost every +building block exists: + +- **Hybrid dispatch** — `Dense.build_layers` (`xtuner/v1/model/dense/dense.py`) + already selects `config.linear_attention` (`GatedDeltaNetConfig`) for + `linear_attention` layers and `config.attention` for `full_attention`. 
No body + change needed for the hybrid. +- **Gated MHA / GatedDeltaNet / gated-deltanet ops** — exist + (`xtuner/v1/module/attention/{mha,gated_deltanet}.py`, `xtuner/v1/ops/gated_deltanet/`). +- **Vision + projector** — `Qwen3_5_VisionConfig` / `Qwen3_5_ProjectorConfig` + (`xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py`) already exist with + `deepstack_visual_indexes: []`; only need the 4B dims. +- **Compose base** — `Qwen3_5_BaseConfig` exists; its `from_hf`/`save_hf` come + from `Qwen3VLBaseConfig` / `BaseComposeModel`. + +The **only genuinely new code** is the *dense* hybrid text tower + its config, +plus a 4B compose config and registration. + +## 3. New code (kept inside the three seams) + +### 3.1 Text tower — `xtuner/v1/model/dense/qwen3_5_text.py` + +- `Qwen3_5_VLTextDense(Qwen3VLTextDense)` — reuse the deepstack-aware dense + forward (`xtuner/v1/model/dense/qwen3vl_text.py`); deepstack is inert here + (`deepstack_visual_indexes: []`). Override only `to_hf_key_list`: + - prefix `model.` for `layers.*` / `embed_tokens` (the config's `hf_key_mapping` then nests it under `model.language_model.`); + - for `linear_attention` layers, rename `self_attn` → `linear_attn` + (same rule as the MoE tower, driven by `config.layers_type[idx]`); + - top-level `norm.` → `model.norm.` (likewise remapped to `model.language_model.norm.` by `hf_key_mapping`); + - `tie_word_embeddings`: redirect `lm_head` → `embed_tokens`. + - No MoE fusion / `safetensors_to_params` needed (plain dense MLP). +- `Qwen3_5_VLTextDenseConfig(TransformerConfig)` — `layers_type` computed + (`full` every 4th), `attention = MHAConfig(with_gate=True, qk_norm=True, + head_dim=256, num_attention_heads=16, num_key_value_heads=4)` **without** + `sliding_window`, `linear_attention = GatedDeltaNetConfig(...)`, + `rope_parameters_cfg` for interleaved partial-rotary mrope, `rms_norm_type` + per parity check (§5). +- `Qwen3_5_VLTextDense4BConfig` — hard-coded 4B dims. 
+ +### 3.2 Compose config — `compose/qwen3_5/qwen3_5_config.py` + +- Add `Qwen3_5_VLDense4BConfig(Qwen3_5_BaseConfig)`: vision `Qwen3_5_VisionConfig(depth=24, hidden_size=1024, out_hidden_size=2560)`, + projector `Qwen3_5_ProjectorConfig(vision_hidden_size=1024, text_hidden_size=2560)`, + `text_config = Qwen3_5_VLTextDense4BConfig(...)`. +- Widen `Qwen3_5_BaseConfig.text_config` type from `MoEConfig` to the `Qwen3_5TextConfig` union + (`Qwen3_5_VLTextDenseConfig | Qwen3_5_VLTextMoEConfig | Qwen3_5_VLTextMoESplitConfig`) so a dense text config is + accepted. (Small, localized interface widening — no behavior change.) + +### 3.3 Registration — `xtuner/v1/model/__init__.py` + +Import `Qwen3_5_VLDense4BConfig`, add a `model_mapping` alias +(`"qwen3.5-vl-4b"`), extend `__all__`. Consistent with the existing +hard-coded `qwen3_5` family (no `get_model_config_from_hf` dispatch is added for +this family today). + +## 4. MTP — deferred (decision) + +Baseline ships **without** MTP. The 15 `mtp.*` checkpoint keys are not built, so +`from_hf` reports them as unexpected (matching HF, which lists +`_keys_to_ignore_on_load_unexpected = [r"^mtp.*"]`), and `save_hf` does not +re-emit them. The round-trip test is scoped to non-`mtp.*` keys. Adding MTP to +the Dense path (mirroring `MoE.build_mtp_block` + forward/loss integration) is a +follow-up commit. + +## 5. Parity result (decoder-layer bitwise) + +Verified empirically (single GPU, bf16) against HF `Qwen3_5ForConditionalGeneration`'s +`language_model`. Findings: + +- **GatedDeltaNet (linear) layers — bitwise (0.0)** out of the box: XTuner and HF + share the same `fla` / `causal_conv1d` kernels. +- **RoPE — bitwise (0.0)**: `rope_type="qwen3_vl"` reproduces HF's interleaved + partial-rotary mrope exactly. RMSNorm `zero_centered` matches HF's + `output * (1 + weight)`. Gated-MHA q/gate chunk + `sigmoid(gate)` matches. +- **Full-attention layers** only diverged because XTuner defaults to **flash + attention** while HF eager upcasts softmax to fp32. 
XTuner's own + `eager_attention` op already matches HF's `eager_attention_forward` **bitwise**. + +So the only thing needed for bitwise parity is forcing the eager attention path. +That is what the **`XTUNER_HF_IMPL`** switch does (§5.1). + +Discovered + fixed along the way: `Dense.build_layers` did not forward +`rms_norm_type` to `DenseDecoderLayer` (MoE did), so a `zero_centered` dense model +would silently use the default RMSNorm. Fixed generally (no-op for existing +`"default"` dense models). + +Parity outcome: with `XTUNER_HF_IMPL=true`, all 32 decoder layers (linear + full) +and the final-norm hidden state match HF **bitwise (max diff 0.0)**; logits match +bitwise. On the default flash path, the full-model forward matches HF within +`1e-2`. + +### 5.1 `XTUNER_HF_IMPL` switch (ops-level) + +New env var `XTUNER_HF_IMPL` (`xtuner/v1/utils/misc.py`) selects HF-exact op +implementations, patched **only at the ops layer** (never the decoder-layer +forward): + +- `xtuner/v1/ops/attn_imp.py::get_attn_impl_fn` returns `eager_attention` + regardless of the configured backend. +- `xtuner/v1/ops/rms_norm/__init__.py::get_rms_norm_fn` forces the native torch + path (over triton). + +Both read the env var live so a test can toggle it per model instance. `mha.py` +selects its attention op through `get_attn_impl_fn`. + +## 6. Tests — `tests/model/test_qwen3_5_dense.py` + +- Decoder-layer bitwise parity: one linear + one full layer vs HF. +- `save_hf` round-trip (byte-equal, non-`mtp.*` keys). +- Checkpoint path from env var (e.g. `QWEN3_5_DENSE_4B_PATH`). + +## 7. Commit plan (stacked) + +1. Dense text tower + config + compose 4B config + registration (baseline, + no MTP). +2. Baseline tests (decoder-layer parity + save_hf round-trip). +3. (later) MTP support in the Dense path. +4. (later) §8 optimizations: EP n/a (dense), SP, torch.compile, fp8, + activation offload — each with a comparison test. 
diff --git a/xtuner/v1/model/__init__.py b/xtuner/v1/model/__init__.py index 32cc7fa69..b8d794300 100644 --- a/xtuner/v1/model/__init__.py +++ b/xtuner/v1/model/__init__.py @@ -11,7 +11,7 @@ InternVL3P5MoE30BA3Config, InternVLBaseConfig, ) -from .compose.qwen3_5 import Qwen3_5_VLMoE35BA3Config +from .compose.qwen3_5 import Qwen3_5_VLDense4BConfig, Qwen3_5_VLMoE35BA3Config from .compose.qwen3_vl import ( Qwen3VLDense4BConfig, Qwen3VLDense8BConfig, @@ -38,6 +38,7 @@ "internvl-3.5-8b-hf": InternVL3P5Dense8BConfig(), "internvl-3.5-1b-hf": InternVL3P5Dense1BConfig(), "internvl-3.5-30b-a3b-hf": InternVL3P5MoE30BA3Config(), + "qwen3.5-vl-4b": Qwen3_5_VLDense4BConfig(), } @@ -101,4 +102,5 @@ def get_model_config_from_hf(model_path: Path): "DEFAULT_FLOAT8_CFG", "XTunerBaseModelConfig", "Qwen3_5_VLMoE35BA3Config", + "Qwen3_5_VLDense4BConfig", ] diff --git a/xtuner/v1/model/compose/qwen3_5/__init__.py b/xtuner/v1/model/compose/qwen3_5/__init__.py index 27a424d88..53822e6ee 100644 --- a/xtuner/v1/model/compose/qwen3_5/__init__.py +++ b/xtuner/v1/model/compose/qwen3_5/__init__.py @@ -1,7 +1,8 @@ -from .qwen3_5_config import Qwen3_5_VLMoE35BA3Config, Qwen3_5_VLMoE35BA3SplitConfig +from .qwen3_5_config import Qwen3_5_VLDense4BConfig, Qwen3_5_VLMoE35BA3Config, Qwen3_5_VLMoE35BA3SplitConfig __all__ = [ + "Qwen3_5_VLDense4BConfig", "Qwen3_5_VLMoE35BA3Config", "Qwen3_5_VLMoE35BA3SplitConfig", ] diff --git a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py index 33b33c69c..5b8f56f8d 100644 --- a/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py +++ b/xtuner/v1/model/compose/qwen3_5/qwen3_5_config.py @@ -1,8 +1,9 @@ -from xtuner.v1.model.moe.moe import MoEConfig -from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig +from xtuner.v1.model.dense.qwen3_5_text import Qwen3_5_VLTextDense4BConfig, Qwen3_5_VLTextDenseConfig +from xtuner.v1.model.moe.qwen3_5_text import Qwen3_5_VLTextMoE35BA3BConfig, 
Qwen3_5_VLTextMoEConfig from xtuner.v1.model.moe.qwen3_5_text_split import ( Qwen3_5_VLTextMoE35BA3BSplitConfig, Qwen3_5_VLTextMoE397BA17BSplitConfig, + Qwen3_5_VLTextMoESplitConfig, ) from xtuner.v1.utils import get_logger @@ -20,10 +21,13 @@ class Qwen3_5_ProjectorConfig(Qwen3VLProjectorConfig): deepstack_visual_indexes: list[int] = [] +Qwen3_5TextConfig = Qwen3_5_VLTextDenseConfig | Qwen3_5_VLTextMoEConfig | Qwen3_5_VLTextMoESplitConfig + + class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): vision_config: Qwen3_5_VisionConfig projector_config: Qwen3_5_ProjectorConfig - text_config: MoEConfig + text_config: Qwen3_5TextConfig image_token_id: int = 248056 video_token_id: int = 248057 @@ -34,13 +38,21 @@ class Qwen3_5_BaseConfig(Qwen3VLBaseConfig): class Qwen3_5_VLMoE35BA3Config(Qwen3_5_BaseConfig): vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() - text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BConfig() + text_config: Qwen3_5_VLTextMoE35BA3BConfig = Qwen3_5_VLTextMoE35BA3BConfig() + + +class Qwen3_5_VLDense4BConfig(Qwen3_5_BaseConfig): + vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig(depth=24, hidden_size=1024, intermediate_size=4096) + projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig(vision_hidden_size=1024, text_hidden_size=2560) + text_config: Qwen3_5_VLTextDense4BConfig = Qwen3_5_VLTextDense4BConfig() class Qwen3_5_VLMoE35BA3SplitConfig(Qwen3_5_BaseConfig): vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig() projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig() - text_config: MoEConfig = Qwen3_5_VLTextMoE35BA3BSplitConfig(hf_key_mapping={r"^model\.": "model.language_model."}) + text_config: Qwen3_5_VLTextMoE35BA3BSplitConfig = Qwen3_5_VLTextMoE35BA3BSplitConfig( + hf_key_mapping={r"^model\.": "model.language_model."} + ) class Qwen3_5TimeSeriesMoE35BA3Config(Qwen3_5_VLMoE35BA3Config): @@ -51,6 +63,6 @@ class 
Qwen3_5TimeSeriesMoE35BA3Config(Qwen3_5_VLMoE35BA3Config): class Qwen3_5_VLMoE397BA17SplitConfig(Qwen3_5_BaseConfig): vision_config: Qwen3_5_VisionConfig = Qwen3_5_VisionConfig(fully_shard=False) projector_config: Qwen3_5_ProjectorConfig = Qwen3_5_ProjectorConfig(text_hidden_size=4096, fully_shard=False) - text_config: MoEConfig = Qwen3_5_VLTextMoE397BA17BSplitConfig( + text_config: Qwen3_5_VLTextMoE397BA17BSplitConfig = Qwen3_5_VLTextMoE397BA17BSplitConfig( hf_key_mapping={r"^model\.": "model.language_model."} ) diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py index 5be2ee5f6..b3f70eadb 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py @@ -83,6 +83,11 @@ class Qwen3VLVisionRotaryEmbedding(nn.Module): def __init__(self, dim: int, theta: float = 10000.0) -> None: super().__init__() + # Mirror HF's Qwen3_5VisionRotaryEmbedding: keep `dim` and `theta` so the buffer can be + # rebuilt after a meta-build → `to_empty` round-trip (the init-computed `inv_freq` would + # otherwise be left garbage). See `DeterministicDDPTestCase.materialize_submodule`. 
+ self.dim = dim + self.theta = theta inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) diff --git a/xtuner/v1/model/dense/qwen3_5_text.py b/xtuner/v1/model/dense/qwen3_5_text.py new file mode 100644 index 000000000..e56d2aced --- /dev/null +++ b/xtuner/v1/model/dense/qwen3_5_text.py @@ -0,0 +1,98 @@ +import re +from typing import Literal + +from pydantic import Field, computed_field + +from xtuner.v1.model.base import HFSaveCfg, TransformerConfig +from xtuner.v1.module.attention import GatedDeltaNetConfig, MHAConfig +from xtuner.v1.module.rope import RopeParametersConfig + +from .qwen3vl_text import Qwen3VLTextDense + + +class Qwen3_5_VLTextDense(Qwen3VLTextDense): + def to_hf_key_list(self, key: str) -> list[str]: + # Emit the standalone language-model layout (``model.<...>``). The VLM nesting + # under ``model.language_model.`` is applied by the config's ``hf_key_mapping`` + # so this tower stays unaware of how it is composed into the VLM. + if self.config.tie_word_embeddings and "lm_head" in key: + key = key.replace("lm_head", "embed_tokens") + + if "layers" in key: + # HF stores the GatedDeltaNet under ``linear_attn`` while XTuner keeps the + # generic ``self_attn`` attribute name regardless of attention type, so the + # rename is driven by the per-layer ``layers_type``. + layer_idx = int(re.findall(r"layers\.(\d+)\.", key)[0]) + if self.config.layers_type[layer_idx] == "linear_attention": + key = key.replace("self_attn", "linear_attn") + + if "layers" in key or "embed_tokens" in key: + key = "model." 
+ key + + if key.startswith("norm."): + return [key.replace("norm.", "model.norm.")] + else: + return [key] + + +class Qwen3_5_VLTextDenseConfig(TransformerConfig): + rms_norm_type: Literal["default", "zero_centered"] = "zero_centered" + # The dense text tower emits the standalone ``model.<...>`` layout; this remaps it + # to the VLM's ``model.language_model.<...>`` namespace on both load and save. + hf_key_mapping: dict[str, str] | None = {r"^model\.": "model.language_model."} + # Qwen3.5 keeps the GatedDeltaNet gated-RMSNorm weight and the per-head decay + # parameter ``A_log`` in fp32; the rest of the model runs in bf16. + hf_save_cfg: HFSaveCfg = HFSaveCfg( + fp32_keys_pattern=[ + r"model\.language_model\.layers\.\d+\.linear_attn\.norm\.weight", + r"model\.language_model\.layers\.\d+\.linear_attn\.A_log", + ], + ) + + @computed_field + def layers_type(self) -> list[Literal["full_attention", "linear_attention"]]: + # ``full_attention_interval`` == 4: every 4th layer (idx 3, 7, ...) is full + # attention, the rest are linear (GatedDeltaNet). 
+ return ["full_attention" if (i + 1) % 4 == 0 else "linear_attention" for i in range(self.num_hidden_layers)] + + def build(self) -> Qwen3_5_VLTextDense: + return Qwen3_5_VLTextDense(self) + + +class Qwen3_5_VLTextDense4BConfig(Qwen3_5_VLTextDenseConfig): + vocab_size: int = 248320 + max_position_embeddings: int = 262144 + pad_token_id: int | None = None + eos_token_id: int = 248044 + num_hidden_layers: int = 32 + hidden_size: int = 2560 + intermediate_size: int = 9216 + rms_norm_eps: float = 1e-6 + hidden_act: str = "silu" + tie_word_embeddings: bool = True + attention: MHAConfig = MHAConfig( + with_gate=True, + num_attention_heads=16, + num_key_value_heads=4, + head_dim=256, + qk_norm=True, + rms_norm_eps=1e-6, + rms_norm_type="zero_centered", + ) + linear_attention: GatedDeltaNetConfig = GatedDeltaNetConfig( + num_value_heads=32, + num_key_heads=16, + key_head_dim=128, + value_head_dim=128, + conv_kernel_dim=4, + hidden_act="silu", + rms_norm_eps=1e-6, + ) + rope_parameters_cfg: RopeParametersConfig = Field( + default_factory=lambda: RopeParametersConfig( + rope_theta=10000000.0, + rope_type="qwen3_vl", + mrope_section=[11, 11, 10], + partial_rotary_factor=0.25, + ) + ) From 3cbf32d01f3317654b8b6de1a9ff04146c3670d2 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 21:02:25 +0000 Subject: [PATCH 3/8] [Feature] Add XTUNER_HF_IMPL op-level switch for HF-exact parity XTUNER_HF_IMPL selects HuggingFace-exact op implementations at the ops layer so decoder layers can be aligned bitwise against transformers: get_attn_impl_fn forces the eager attention path (XTuner's eager_attention matches HF's eager_attention_forward bitwise, fp32 softmax + dense causal mask), and get_rms_norm_fn forces the native torch path over triton. Both read the env var live so tests can toggle it per model instance. Not for production training. 
--- .../model/compose/qwen3_vl/modeling_vision.py | 6 +-- xtuner/v1/module/attention/gated_deltanet.py | 9 ++-- xtuner/v1/module/attention/mha.py | 4 +- xtuner/v1/ops/__init__.py | 2 +- xtuner/v1/ops/attn_imp.py | 46 ++++++++++++++-- xtuner/v1/ops/gated_deltanet/__init__.py | 54 +++++++++++++++++++ xtuner/v1/ops/rms_norm/__init__.py | 4 +- xtuner/v1/utils/__init__.py | 2 + xtuner/v1/utils/misc.py | 5 ++ 9 files changed, 115 insertions(+), 17 deletions(-) create mode 100644 xtuner/v1/ops/gated_deltanet/__init__.py diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py index b3f70eadb..d9b599b78 100644 --- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py +++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py @@ -7,7 +7,7 @@ from typing_extensions import override from .qwen3_vl_config import Qwen3VLVisionConfig from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_device, get_torch_device_module, init_params -from xtuner.v1.ops.attn_imp import attn_impl_mapping, AttnOpOutputs +from xtuner.v1.ops.attn_imp import AttnOpOutputs, get_attn_impl_fn import torch.nn.functional as F from pathlib import Path from xtuner.v1.model import BaseModel @@ -130,9 +130,7 @@ def __init__(self, config: Qwen3VLVisionConfig) -> None: self.scale = self.head_dim ** -0.5 self.config = config self.attention_dropout = 0.0 - self.attn_impl_func: Callable[..., AttnOpOutputs] = ( - attn_impl_mapping[config.attn_impl] # type: ignore[assignment] - ) + self.attn_impl_func: Callable[..., AttnOpOutputs] = get_attn_impl_fn(config.attn_impl) # type: ignore[assignment] def forward( self, diff --git a/xtuner/v1/module/attention/gated_deltanet.py b/xtuner/v1/module/attention/gated_deltanet.py index 8f23bf60a..5df011a35 100644 --- a/xtuner/v1/module/attention/gated_deltanet.py +++ b/xtuner/v1/module/attention/gated_deltanet.py @@ -15,8 +15,7 @@ from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all from xtuner.v1.utils import 
get_logger -from ...ops.gated_deltanet.causal_conv1d import causal_conv1d_fn -from ...ops.gated_deltanet.chunk_gated_delta_rule import chunk_gated_delta_rule +from ...ops.gated_deltanet import get_causal_conv1d_fn, get_chunk_gated_delta_rule_fn from ...ops.gated_deltanet.gen_seq_idx import gen_seq_idx from ...ops.gated_deltanet.rms_norm_gated import rms_norm_gated from ..linear import build_linear @@ -146,8 +145,10 @@ def __init__( A = torch.empty(self.num_v_heads).uniform_(0, 16) self.A_log = nn.Parameter(torch.log(A)) - self.causal_conv1d_fn = causal_conv1d_fn - self.chunk_gated_delta_rule = chunk_gated_delta_rule + # Resolved at build time so `XTUNER_HF_IMPL` selects the HF-native fla call patterns + # (same convention as `get_attn_impl_fn`). + self.causal_conv1d_fn = get_causal_conv1d_fn() + self.chunk_gated_delta_rule = get_chunk_gated_delta_rule_fn() assert FusedRMSNormGated is not None, ( "FusedRMSNormGated is not available. Please install fla to use GatedDeltaNet by `pip install flash-linear-attention`." 
) diff --git a/xtuner/v1/module/attention/mha.py b/xtuner/v1/module/attention/mha.py index 48a0501b4..707e97805 100644 --- a/xtuner/v1/module/attention/mha.py +++ b/xtuner/v1/module/attention/mha.py @@ -15,7 +15,7 @@ from xtuner.v1.data_proto import SequenceContext from xtuner.v1.float8.config import Float8Config from xtuner.v1.module.rope import RopeScalingConfig -from xtuner.v1.ops import AttnOpOutputs, attn_impl_mapping, flash_attn_varlen_func, get_apply_rotary_emb +from xtuner.v1.ops import AttnOpOutputs, flash_attn_varlen_func, get_apply_rotary_emb, get_attn_impl_fn from xtuner.v1.ops.comm.all_to_all import ulysses_all_to_all from xtuner.v1.utils import XTUNER_DETERMINISTIC, get_device, get_logger, log_rank0 @@ -201,7 +201,7 @@ def __init__( ) self.apply_rotary_emb = get_apply_rotary_emb(fope_sep_head, enable_partial_rotary=enable_partial_rotary) # type: ignore - self.attn_impl_func: Callable[..., AttnOpOutputs] = attn_impl_mapping[attn_impl] # type: ignore[assignment] + self.attn_impl_func: Callable[..., AttnOpOutputs] = get_attn_impl_fn(attn_impl) # type: ignore[assignment] def prefilling( self, diff --git a/xtuner/v1/ops/__init__.py b/xtuner/v1/ops/__init__.py index 4a8d6907e..1e3caeb8e 100644 --- a/xtuner/v1/ops/__init__.py +++ b/xtuner/v1/ops/__init__.py @@ -4,7 +4,7 @@ __all__ = ["all_to_all_single_autograd", "ulysses_all_to_all"] from .act_fn import get_act_fn -from .attn_imp import AttnOpOutputs, attn_impl_mapping +from .attn_imp import AttnOpOutputs, attn_impl_mapping, get_attn_impl_fn from .flash_attn import flash_attn_varlen_func from .moe import group_gemm, permute, unpermute from .rms_norm import rms_norm, zero_centered_rms_norm diff --git a/xtuner/v1/ops/attn_imp.py b/xtuner/v1/ops/attn_imp.py index c37253f4f..a6712a695 100644 --- a/xtuner/v1/ops/attn_imp.py +++ b/xtuner/v1/ops/attn_imp.py @@ -1,3 +1,4 @@ +import os import traceback from functools import lru_cache from typing import TypedDict @@ -97,6 +98,14 @@ def 
_create_grouped_causal_mask(document_ids): return torch.where(final_mask, 0.0, float("-inf")) +def _create_grouped_full_mask(document_ids): + # Non-causal block-diagonal mask: every position attends to all positions in the same + # document/image (no causal triangle). Used by bidirectional attention (e.g. vision towers). + doc_matrix = document_ids.unsqueeze(2) == document_ids.unsqueeze(1) + + return torch.where(doc_matrix, 0.0, float("-inf")) + + def _create_windowed_grouped_causal_mask(document_ids, window_size): _, seq_len = document_ids.shape @@ -133,7 +142,7 @@ def mask_mod(b, h, q_idx, kv_idx): def eager_attention( - q, k, v, cu_seqlens_q, softmax_scale, window_size=(-1, -1), dropout_p=0.0, s_aux=None, **kwargs + q, k, v, cu_seqlens_q, softmax_scale, window_size=(-1, -1), dropout_p=0.0, s_aux=None, causal=True, **kwargs ) -> AttnOpOutputs: # TODO(HHA): Currently, the mask is recalculated each time, which is quite time-consuming. # It should be refactored to be calculated only once. @@ -148,11 +157,14 @@ def eager_attention( attn_weights = torch.matmul(q, k.transpose(2, 3)) * softmax_scale # type: ignore batch_document_ids = _get_document_ids_from_seq_lens(cu_seqlens_q) - if window_size == (-1, -1): - # Generate casual mask, the lower left corner is 0, and the other positions are -inf + if window_size != (-1, -1): + attention_mask = _create_windowed_grouped_causal_mask(batch_document_ids, window_size[0]) # type: ignore + elif causal: + # Causal block-diagonal mask: attend within the document, lower triangle only. attention_mask = _create_grouped_causal_mask(batch_document_ids) else: - attention_mask = _create_windowed_grouped_causal_mask(batch_document_ids, window_size[0]) # type: ignore + # Non-causal block-diagonal mask (e.g. bidirectional vision attention). 
+ attention_mask = _create_grouped_full_mask(batch_document_ids) attention_mask = attention_mask[None].to(attn_weights.dtype) # 1,1,seq,seq causal_mask = attention_mask[:, :, :, : k.shape[-2]] @@ -170,7 +182,9 @@ def eager_attention( scores = probs[..., :-1] # we drop the sink here attn_logits = combined_logits.detach() else: - scores = torch.softmax(attn_weights, dim=-1, dtype=attn_weights.dtype) + # Upcast softmax to fp32 (matching HF's eager_attention_forward). Even when the forward + # values round identically in bf16, the fp32 softmax is what makes the *backward* bitwise. + scores = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(attn_weights.dtype) attn_logits = attn_weights.detach() attn_scores = nn.functional.dropout(scores, p=dropout_p, training=True) @@ -258,3 +272,25 @@ def flash_attention(q, k, v, window_size=(-1, -1), s_aux=None, **kwargs) -> Attn "flash_attention": flash_attention, "flex_attention": flex_attention, } + + +def get_attn_impl_fn(attn_impl: str): + """Resolve the attention op implementation for a configured ``attn_impl``. + + When ``XTUNER_HF_IMPL`` is set, the eager implementation is forced regardless + of the configured backend. XTuner's ``eager_attention`` matches HuggingFace's + ``eager_attention_forward`` bitwise (same fp32 softmax and dense causal mask), + which is what allows decoder layers to be aligned against transformers; the + fused flash/flex kernels are numerically close but not bitwise equal. + + Args: + attn_impl (str): The configured backend key in ``attn_impl_mapping``. + + Returns: + Callable[..., AttnOpOutputs]: The selected attention op. + """ + # Read the env var live (rather than a cached constant) so tests can toggle the + # HF-parity path per model instance within a single process. 
+ if os.getenv("XTUNER_HF_IMPL") == "true": + return eager_attention + return attn_impl_mapping[attn_impl] diff --git a/xtuner/v1/ops/gated_deltanet/__init__.py b/xtuner/v1/ops/gated_deltanet/__init__.py new file mode 100644 index 000000000..81e8c2665 --- /dev/null +++ b/xtuner/v1/ops/gated_deltanet/__init__.py @@ -0,0 +1,54 @@ +"""GatedDeltaNet op-level dispatchers. + +`XTUNER_HF_IMPL` controls which implementations XTuner's `GatedDeltaNet` module uses, +mirroring how `xtuner/v1/ops/attn_imp.py::get_attn_impl_fn` and the rms_norm selector +switch between fast / fused paths and HF-exact paths. Under `XTUNER_HF_IMPL=true`: + +* `chunk_gated_delta_rule` is the canonical `fla.ops.gated_delta_rule.chunk_gated_delta_rule` + (same callable HF's `Qwen3_5GatedDeltaNet` uses), bypassing XTuner's + `torch.library.custom_op` wrap. +* `causal_conv1d_fn` is the high-level `causal_conv1d.causal_conv1d_fn` adapted to XTuner's + channel-last call site: the adapter transposes to channel-first, calls the package wrapper + with ``seq_idx=None`` (HF's non-packed convention), and transposes back. XTuner's own wrap + binds the channel-last/seq_idx convention together via an internal transpose, which gives a + different backward op graph from HF's call pattern even though forward is bitwise. + +These switches are only meant for the bitwise-parity tests. Production / training stays on the +XTuner path (compile-friendly custom_op wraps + seq_idx-aware kernel dispatch). 
+""" + +import os + +from causal_conv1d import causal_conv1d_fn as _hf_causal_conv1d_fn +from fla.ops.gated_delta_rule import chunk_gated_delta_rule as _hf_chunk_gated_delta_rule + +from .causal_conv1d import causal_conv1d_fn as _xtuner_causal_conv1d_fn +from .chunk_gated_delta_rule import chunk_gated_delta_rule as _xtuner_chunk_gated_delta_rule + + +_TRUTHY = {"true", "1", "yes", "on"} + + +def _hf_impl_enabled() -> bool: + return os.getenv("XTUNER_HF_IMPL", "").strip().lower() in _TRUTHY + + +def _hf_causal_conv1d_adapter(x, weight, bias, activation, seq_idx): + # XTuner's GatedDeltaNet supplies `x` in channel-last (``(batch, seq, dim)``) and always + # passes a `seq_idx` tensor; HF's call site uses channel-first and `seq_idx=None` for + # non-packed batches. Adapt: transpose, drop seq_idx, transpose back. + x_cf = x.transpose(1, 2) + out = _hf_causal_conv1d_fn(x=x_cf, weight=weight, bias=bias, activation=activation, seq_idx=None) + return out.transpose(1, 2) + + +def get_chunk_gated_delta_rule_fn(): + if _hf_impl_enabled(): + return _hf_chunk_gated_delta_rule + return _xtuner_chunk_gated_delta_rule + + +def get_causal_conv1d_fn(): + if _hf_impl_enabled(): + return _hf_causal_conv1d_adapter + return _xtuner_causal_conv1d_fn diff --git a/xtuner/v1/ops/rms_norm/__init__.py b/xtuner/v1/ops/rms_norm/__init__.py index c89cd8596..26534e744 100644 --- a/xtuner/v1/ops/rms_norm/__init__.py +++ b/xtuner/v1/ops/rms_norm/__init__.py @@ -39,8 +39,10 @@ def get_rms_norm_fn() -> RMSNormProtocol: device = get_device() if device in ["cpu", "cuda"]: + # XTUNER_HF_IMPL forces the native torch path so the result matches HF bitwise. 
+ use_triton = os.getenv("XTUNER_USE_NATIVE_RMSNORM", "1") == "0" and os.getenv("XTUNER_HF_IMPL") != "true" # TODO: control triton rmsnorm by model config rather than env var - if os.getenv("XTUNER_USE_NATIVE_RMSNORM", "1") == "0" and device == "cuda": + if use_triton and device == "cuda": return _triton_rms_norm else: return native_rms_norm diff --git a/xtuner/v1/utils/__init__.py b/xtuner/v1/utils/__init__.py index 915e2a0c7..96f4defdd 100644 --- a/xtuner/v1/utils/__init__.py +++ b/xtuner/v1/utils/__init__.py @@ -10,6 +10,7 @@ from .logger import get_logger, log_format, log_rank0 from .misc import ( XTUNER_DETERMINISTIC, + XTUNER_HF_IMPL, FunctionEnum, SharedMemory, clean_param_name, @@ -51,6 +52,7 @@ "ParallelConfigException", "log_format", "XTUNER_DETERMINISTIC", + "XTUNER_HF_IMPL", "Config", "record_git_info", "is_hf_model_path", diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py index d03e28f5f..9a8639537 100644 --- a/xtuner/v1/utils/misc.py +++ b/xtuner/v1/utils/misc.py @@ -25,6 +25,11 @@ logger = get_logger() XTUNER_DETERMINISTIC = os.getenv("XTUNER_DETERMINISTIC") == "true" +# When set, ops select the implementation that matches HuggingFace transformers +# bitwise (e.g. eager attention with fp32 softmax, native rms_norm) instead of the +# fused/kernel fast paths. Used to align XTuner decoder layers against HF for +# numerical-parity testing; not for production training. +XTUNER_HF_IMPL = os.getenv("XTUNER_HF_IMPL") == "true" def set_deterministic(): From 73af707a705eae7f55a2e996303a1c6a6db53ae0 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 21:02:34 +0000 Subject: [PATCH 4/8] [Test] Add Qwen3.5-4B dense parity and save_hf tests Adds decoder-layer bitwise parity (all linear + full layers and final-norm hidden match HF bitwise under XTUNER_HF_IMPL), model forward parity on the default flash path (within 1e-2), and a save_hf round-trip over non-mtp keys (MTP is deferred). 
Reads the checkpoint from QWEN3_5_DENSE_4B_PATH. --- tests/model/test_qwen3_5_dense.py | 475 ++++++++++++++++++++++++++++++ xtuner/_testing/testcase.py | 134 +++++++++ xtuner/_testing/utils.py | 8 +- 3 files changed, 616 insertions(+), 1 deletion(-) create mode 100644 tests/model/test_qwen3_5_dense.py diff --git a/tests/model/test_qwen3_5_dense.py b/tests/model/test_qwen3_5_dense.py new file mode 100644 index 000000000..6c25cf550 --- /dev/null +++ b/tests/model/test_qwen3_5_dense.py @@ -0,0 +1,475 @@ +import json +import os +import tempfile +import unittest +from pathlib import Path + +import parametrize +import torch +import torch.distributed as dist +from packaging.version import Version +from safetensors import safe_open +from transformers import __version__ as transformers_version + +from xtuner.v1.config import FSDPConfig +from xtuner.v1.data_proto import SequenceContext +from xtuner.v1.model import Qwen3_5_VLDense4BConfig +from xtuner.v1.model.compose.qwen3_vl.modeling_vision import init_world_mesh +from xtuner.v1.model.dense.qwen3_5_text import Qwen3_5_VLTextDense4BConfig +from xtuner._testing import DeterministicDDPTestCase + + +# Qwen3.5-4B (dense, hybrid linear + full attention VLM) +QWEN3_5_DENSE_4B_PATH = os.environ["QWEN3_5_DENSE_4B_PATH"] + + +@unittest.skipIf( + Version(transformers_version) < Version("5.9.0"), + f"transformers >= 5.9.0 is required, but got {transformers_version}", +) +class TestQwen3_5_VLDense(DeterministicDDPTestCase): + @parametrize.parametrize("device,layer_idx", [("cuda", 3), ("cuda", 0)]) + def test_decoder_layer_bitwise_parity(self, device, layer_idx): + # One decoder layer (full=3 / linear=0) -> final norm -> lm_head -> CE loss -> backward. + # Under XTUNER_HF_IMPL (eager) the layer output, the loss, and the input gradient dL/dx must all + # match HF bitwise — i.e. forward AND backward parity at the layer level. 
Only this layer + norm + # + lm_head are materialized/loaded on GPU (XTuner builds the rest on meta; HF builds a standalone + # layer), so this runs at any model scale, including models too large to forward end-to-end. + import torch.nn.functional as F + from transformers import Qwen3_5Config, Qwen3_5ForConditionalGeneration + + from xtuner.v1.utils import HFCheckpointLoader + + self.create_pg(device) + with self.hf_impl(): + loader = HFCheckpointLoader(QWEN3_5_DENSE_4B_PATH) + + # ---- XTuner: build the full tower on meta, materialize only layer/norm/lm_head ---- + with torch.device("meta"): + cfg = Qwen3_5_VLTextDense4BConfig(compile_cfg=False) + model = cfg.build() + layer_type = cfg.layers_type[layer_idx] + is_linear = layer_type == "linear_attention" + xt_layer = model.layers[str(layer_idx)] + self.materialize_submodule(model, xt_layer, loader) + self.materialize_submodule(model, model.norm, loader) + self.materialize_submodule(model, model.lm_head, loader) + model.rotary_emb.to("cuda") # built on CPU with real buffers even under meta; no checkpoint weights + + # ---- HF: meta-build the full compose, then materialize only the layer/norm/lm_head we + # touch. Mirrors the XTuner side: `materialize_submodule` recovers each submodule's + # checkpoint prefix from its `named_modules` path, and honors HF's + # `_tied_weights_keys` so `lm_head.weight` is loaded from `embed_tokens.weight`. 
+ hf_cfg = Qwen3_5Config.from_pretrained(QWEN3_5_DENSE_4B_PATH) + hf_cfg._attn_implementation = "eager" + hf_cfg.text_config._attn_implementation = "eager" + hf_cfg.vision_config._attn_implementation = "eager" + with torch.device("meta"): + hf_compose = Qwen3_5ForConditionalGeneration(hf_cfg).eval() + hf_layer = hf_compose.model.language_model.layers[layer_idx] + hf_norm = hf_compose.model.language_model.norm + hf_lm_head = hf_compose.lm_head + self.materialize_submodule(hf_compose, hf_layer, loader) + self.materialize_submodule(hf_compose, hf_norm, loader) + self.materialize_submodule(hf_compose, hf_lm_head, loader) + + seq = 16 + ids = torch.randint(0, 1000, (1, seq), device="cuda") + seq_ctx = SequenceContext.from_input_ids(input_ids=(ids,)) + seq_ctx.to("cuda") + cos, sin = model.rotary_emb( + torch.empty(1, seq, cfg.hidden_size, device="cuda", dtype=torch.bfloat16), seq_ctx.position_ids + ) + labels = torch.randint(0, cfg.vocab_size, (seq,), device="cuda") + base = torch.randn(1, seq, cfg.hidden_size, device="cuda", dtype=torch.bfloat16) + + # full attention needs a causal mask; the linear (GatedDeltaNet) layer ignores it. + attn_mask = ( + None + if is_linear + else torch.triu( + torch.full((seq, seq), float("-inf"), device="cuda", dtype=torch.bfloat16), diagonal=1 + )[None, None] + ) + + x_hf = base.clone().requires_grad_(True) + o_hf = hf_layer(x_hf, position_embeddings=(cos, sin), attention_mask=attn_mask) + o_hf = o_hf[0] if isinstance(o_hf, tuple) else o_hf + loss_hf = F.cross_entropy(hf_lm_head(hf_norm(o_hf)).reshape(-1, cfg.vocab_size), labels) + loss_hf.backward() + + x_xt = base.clone().requires_grad_(True) + o_xt = xt_layer(x_xt, (cos, sin), seq_ctx) + loss_xt = F.cross_entropy(F.linear(model.norm(o_xt), model.lm_head.weight).reshape(-1, cfg.vocab_size), labels) + loss_xt.backward() + + # forward (layer output), loss, and backward (dL/dx) must all be bitwise. 
+ out_diff = (o_hf.float() - o_xt.float().reshape(o_hf.shape)).abs().max().item() + loss_diff = (loss_hf.float() - loss_xt.float()).abs().item() + grad_diff = (x_hf.grad.float() - x_xt.grad.float()).abs().max().item() # type: ignore[union-attr] + self.assertEqual(out_diff, 0.0, f"layer {layer_idx} [{layer_type}] output not bitwise: max diff {out_diff}") + self.assertEqual(loss_diff, 0.0, f"layer {layer_idx} [{layer_type}] loss not bitwise: {loss_diff}") + self.assertEqual(grad_diff, 0.0, f"layer {layer_idx} [{layer_type}] dL/dx not bitwise: max diff {grad_diff}") + dist.barrier() + + @parametrize.parametrize("device", [("cuda",)]) + def test_vision_tower_bitwise_parity(self, device): + # Vision tower (patch_embed + blocks + merger) bitwise vs HF, eager on both sides. Loads ONLY the + # vision tower on each side (standalone; weights are the checkpoint's `model.visual.*`), not the + # full VLM, so it stays memory-light for large models. Vision attention is non-causal: under + # XTUNER_HF_IMPL the vision tower routes to eager_attention with causal=False. Pos-embed patch + # is required. + self.create_pg(device) + self._patch_fast_pos_embed_interpolate() + from transformers import Qwen3_5Config, Qwen3_5ForConditionalGeneration + + from xtuner.v1.utils import HFCheckpointLoader + + raw_data = { + "id": 3, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "tests/resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}, + }, + {"type": "text", "text": "\n描述图片"}, + ], + }, + {"role": "assistant", "content": "狗。"}, + ], + } + inputs = self._tokenize_qwen3vl(raw_data) + pixel_values = inputs["pixel_values"] + image_grid_thw = inputs["image_grid_thw"] + + with self.hf_impl(): + loader = HFCheckpointLoader(QWEN3_5_DENSE_4B_PATH) + + # ---- HF: meta-build the full compose, then materialize only the vision tower. 
The + # `Qwen3_5VisionRotaryEmbedding.inv_freq` buffer (persistent=False, init-computed) is + # rebuilt inside `materialize_submodule` from its stored `dim`/`theta`. ---- + hf_cfg = Qwen3_5Config.from_pretrained(QWEN3_5_DENSE_4B_PATH) + hf_cfg._attn_implementation = "eager" + hf_cfg.text_config._attn_implementation = "eager" + hf_cfg.vision_config._attn_implementation = "eager" + with torch.device("meta"): + hf_compose = Qwen3_5ForConditionalGeneration(hf_cfg).eval() + hf_vision = hf_compose.model.visual + self.materialize_submodule(hf_compose, hf_vision, loader) + hf_pv = pixel_values.clone().requires_grad_(True) + hf_merged = hf_vision(hf_pv, grid_thw=image_grid_thw).pooler_output + hf_merged.sum().backward() + + # ---- XTuner: same flow as HF — meta-build the compose, then materialize only + # vision_tower + multi_modal_projector. XTuner splits HF's `visual` into vision_tower + # (patches -> merged hidden) + projector (merger MLP -> text dim); HF's pooler_output is + # post-merger so both are needed. 
---- + with torch.device("meta"): + xt_compose = Qwen3_5_VLDense4BConfig(compile_cfg=False).build() + xt_vision = xt_compose.vision_tower + xt_projector = xt_compose.multi_modal_projector + self.materialize_submodule(xt_compose, xt_vision, loader) + self.materialize_submodule(xt_compose, xt_projector, loader) + self.assertEqual(xt_vision.blocks[0].attn.attn_impl_func.__name__, "eager_attention") + xt_pv = pixel_values.clone().requires_grad_(True) + xt_merged, xt_deepstack = xt_vision(xt_pv, image_grid_thw) + xt_merged, _ = xt_projector(xt_merged, xt_deepstack) + xt_merged.sum().backward() + + out_diff = (hf_merged.float() - xt_merged.float().reshape(hf_merged.shape)).abs().max().item() + grad_diff = (hf_pv.grad.float() - xt_pv.grad.float()).abs().max().item() # type: ignore[union-attr] + self.assertEqual(out_diff, 0.0, f"vision tower output not bitwise: max diff {out_diff}") + self.assertEqual(grad_diff, 0.0, f"vision tower dL/d(pixel_values) not bitwise: max diff {grad_diff}") + dist.barrier() + + @parametrize.parametrize("device", [("cuda",)]) + def test_vl_forward_parity(self, device): + # Whole-model (compose VLM) forward + backward parity vs HF on an image prompt — the VLM is the + # real model, so this is the end-to-end integration check across vision + projector + text. + # Eager both sides (XTUNER_HF_IMPL + HF eager): forward logits + loss are bitwise; the backward + # dL/d(pixel_values) — which backprops through the whole text tower + projector + vision — is + # asserted bitwise as well (see the note above the grad assertion; per-layer backward parity is + # test_decoder_layer_bitwise_parity). Image only: the image case already runs the text tower, so + # a text-only case would just re-check the text forward.
+ import torch.nn.functional as F + + self.create_pg(device) + self._patch_fast_pos_embed_interpolate() + from transformers import Qwen3_5ForConditionalGeneration + + raw_data = { + "id": 3, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "tests/resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}, + }, + {"type": "text", "text": "\n描述图片"}, + ], + }, + {"role": "assistant", "content": "狗是棕色的。"}, + ], + } + inputs = self._tokenize_qwen3vl(raw_data) + input_ids, image_grid_thw, position_ids = inputs["input_ids"], inputs["image_grid_thw"], inputs["position_ids"] + base_pixels = inputs["pixel_values"] + + with self.hf_impl(): + hf = Qwen3_5ForConditionalGeneration.from_pretrained( + QWEN3_5_DENSE_4B_PATH, dtype=torch.bfloat16, attn_implementation="eager", device_map="cuda" + ).eval() + pv_hf = base_pixels.clone().requires_grad_(True) + hf_logits = hf( + input_ids=input_ids, + pixel_values=pv_hf, + image_grid_thw=image_grid_thw, + position_ids=position_ids, + use_cache=False, + ).logits + labels = torch.randint(0, hf_logits.size(-1), (hf_logits.size(1),), device="cuda") + loss_hf = F.cross_entropy(hf_logits.reshape(-1, hf_logits.size(-1)), labels) + loss_hf.backward() + del hf + torch.cuda.empty_cache() + + with torch.device("meta"): + model_cfg = Qwen3_5_VLDense4BConfig(compile_cfg=False) # text + vision -> eager via env + model = model_cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True) + model.from_hf(QWEN3_5_DENSE_4B_PATH) + model.eval() + + pv_xt = base_pixels.clone().requires_grad_(True) + seq_ctx = SequenceContext.from_input_ids(input_ids=(input_ids,)) + seq_ctx.to("cuda") + seq_ctx.image_grid_thw = image_grid_thw + seq_ctx.pixel_values = pv_xt + if position_ids is not None: + seq_ctx.position_ids = position_ids + xt_logits = model(seq_ctx=seq_ctx, loss_ctx=None)["logits"] + loss_xt = F.cross_entropy(xt_logits.reshape(-1, xt_logits.size(-1)), labels) + loss_xt.backward() + + 
logit_diff = (hf_logits.float() - xt_logits.float().reshape(hf_logits.shape)).abs().max().item() + loss_diff = (loss_hf.float() - loss_xt.float()).abs().item() + self.assertEqual(logit_diff, 0.0, f"VL logits not bitwise-equal: max diff {logit_diff}") + self.assertEqual(loss_diff, 0.0, f"VL loss not bitwise-equal: {loss_diff}") + # Backward (dL/d(pixel_values)) backprops through CE -> lm_head -> 32-layer LM -> + # masked_scatter -> projector -> 24-layer ViT. With `XTUNER_HF_IMPL` routing + # `GatedDeltaNet` to fla's high-level `chunk_gated_delta_rule` + HF-style + # `causal_conv1d_fn` (no XTuner custom_op wrap, no seq_idx-driven layout switch) and + # the `XTUNER_DETERMINISTIC` Triton autotune pin (`tests/conftest.py`), every backward + # op matches HF byte-for-byte — full e2e backward is bitwise. + grad_diff = (pv_hf.grad.float() - pv_xt.grad.float()).abs().max().item() # type: ignore[union-attr] + self.assertEqual(grad_diff, 0.0, f"VL dL/d(pixel_values) not bitwise: max diff {grad_diff}") + dist.barrier() + + @parametrize.parametrize("device", [("cuda",)]) + def test_model_forward_bitwise_reduced_layers(self, device): + # Whole-model bitwise parity that runs the real `compose.forward` / `Dense.forward` orchestration + # — embed_tokens, rotary_emb call site, the layer loop, `_prepare_llm_inputs` image-embed + # injection, final norm + lm_head, return packing — none of which are exercised by + # `test_decoder_layer_bitwise_parity` (that test inlines them by hand). To keep this runnable at + # any model scale we truncate `text_config.num_hidden_layers` to N: per-layer numerical + # correctness is already owned by the decoder-layer test, so this test does not need every + # layer present — it only needs the forward orchestration to run end-to-end. N=4 covers both + # `linear_attention` (idx 0-2) and `full_attention` (idx 3) under Qwen3.5's `(i+1)%4==0` + # pattern. Vision tower / projector are not truncated. 
`XTUNER_HF_IMPL` + `XTUNER_DETERMINISTIC` + # + HF eager → logits / loss / `dL/d(pixel_values)` all bitwise. + import torch.nn.functional as F + + N_TEXT_LAYERS = 4 + + self.create_pg(device) + self._patch_fast_pos_embed_interpolate() + from transformers import Qwen3_5Config, Qwen3_5ForConditionalGeneration + + raw_data = { + "id": 3, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "tests/resource/mscoco_dog_000000319154.jpg", "image_wh": [375, 500]}, + }, + {"type": "text", "text": "\n描述图片"}, + ], + }, + {"role": "assistant", "content": "狗是棕色的。"}, + ], + } + inputs = self._tokenize_qwen3vl(raw_data) + input_ids, image_grid_thw, position_ids = inputs["input_ids"], inputs["image_grid_thw"], inputs["position_ids"] + base_pixels = inputs["pixel_values"] + + with self.hf_impl(): + # ---- HF: truncate `text_config.num_hidden_layers` then `from_pretrained(config=...)`. HF + # silently ignores the unused layer 4..31 ckpt keys ("UNEXPECTED" warning) and skips the + # meta+materialize song-and-dance — keeps the test's HF side close to how a user would + # actually downscale a real model, and avoids the subtle lazy-init / tie_weights gotchas + # that meta+materialize would otherwise need to special-case. 
+ hf_cfg = Qwen3_5Config.from_pretrained(QWEN3_5_DENSE_4B_PATH) + hf_cfg._attn_implementation = "eager" + hf_cfg.text_config._attn_implementation = "eager" + hf_cfg.vision_config._attn_implementation = "eager" + hf_cfg.text_config.num_hidden_layers = N_TEXT_LAYERS + hf_cfg.text_config.layer_types = hf_cfg.text_config.layer_types[:N_TEXT_LAYERS] + hf = Qwen3_5ForConditionalGeneration.from_pretrained( + QWEN3_5_DENSE_4B_PATH, config=hf_cfg, dtype=torch.bfloat16, device_map="cuda" + ).eval() + + pv_hf = base_pixels.clone().requires_grad_(True) + hf_logits = hf( + input_ids=input_ids, + pixel_values=pv_hf, + image_grid_thw=image_grid_thw, + position_ids=position_ids, + use_cache=False, + ).logits + labels = torch.randint(0, hf_logits.size(-1), (hf_logits.size(1),), device="cuda") + loss_hf = F.cross_entropy(hf_logits.reshape(-1, hf_logits.size(-1)), labels) + loss_hf.backward() + del hf + torch.cuda.empty_cache() + + # ---- XTuner: same truncation. `layers_type` is a computed_field over `num_hidden_layers`, + # so updating one field is enough. compose.from_hf is strict=True by default; since the + # truncated model only expects layer 0..N-1 keys (and those exist in the ckpt), the + # missing-set is empty and load succeeds. The extra layer N..31 ckpt keys go unread. 
+ text_cfg = Qwen3_5_VLTextDense4BConfig().model_copy(update={"num_hidden_layers": N_TEXT_LAYERS}) + compose_cfg = Qwen3_5_VLDense4BConfig(text_config=text_cfg, compile_cfg=False) + with torch.device("meta"): + model = compose_cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True) + model.from_hf(QWEN3_5_DENSE_4B_PATH) + model.eval() + + pv_xt = base_pixels.clone().requires_grad_(True) + seq_ctx = SequenceContext.from_input_ids(input_ids=(input_ids,)) + seq_ctx.to("cuda") + seq_ctx.image_grid_thw = image_grid_thw + seq_ctx.pixel_values = pv_xt + if position_ids is not None: + seq_ctx.position_ids = position_ids + xt_logits = model(seq_ctx=seq_ctx, loss_ctx=None)["logits"] + loss_xt = F.cross_entropy(xt_logits.reshape(-1, xt_logits.size(-1)), labels) + loss_xt.backward() + + logit_diff = (hf_logits.float() - xt_logits.float().reshape(hf_logits.shape)).abs().max().item() + loss_diff = (loss_hf.float() - loss_xt.float()).abs().item() + grad_diff = (pv_hf.grad.float() - pv_xt.grad.float()).abs().max().item() # type: ignore[union-attr] + self.assertEqual(logit_diff, 0.0, f"reduced-layer VL logits not bitwise: max diff {logit_diff}") + self.assertEqual(loss_diff, 0.0, f"reduced-layer VL loss not bitwise: {loss_diff}") + self.assertEqual(grad_diff, 0.0, f"reduced-layer VL dL/d(pixel_values) not bitwise: max diff {grad_diff}") + dist.barrier() + + @parametrize.parametrize("device", [("cuda",)]) + def test_save_hf_round_trip(self, device): + # MTP is deferred for the dense port, so the 15 ``mtp.*`` checkpoint keys are + # neither loaded nor re-saved. The round-trip is therefore asserted over the + # non-``mtp.*`` keys only, and the saved index must contain no ``mtp.*`` key. 
+ self.create_pg(device) + + with torch.device("meta"): + model_cfg = Qwen3_5_VLDense4BConfig(compile_cfg=False) + model = model_cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True) + + fsdp_config = FSDPConfig(cpu_offload=False) + fsdp_mesh = init_world_mesh() + model.vision_tower.fsdp_mesh = fsdp_mesh + model.vision_tower.fsdp_config = fsdp_config + model.fully_shard(fsdp_config=fsdp_config) + + with tempfile.TemporaryDirectory() as tmpdir: + syncdir = [tmpdir] + dist.broadcast_object_list(syncdir, src=0) + tmpdir = Path(syncdir[0]) + model.from_hf(QWEN3_5_DENSE_4B_PATH) + model.save_hf(tmpdir) + + origin_hf_path = Path(QWEN3_5_DENSE_4B_PATH) + origin_index_path = origin_hf_path / "model.safetensors.index.json" + saved_index_path = tmpdir / "model.safetensors.index.json" + + if dist.get_rank() == 0: + with open(origin_index_path, "r") as f: + origin_index = json.load(f) + with open(saved_index_path, "r") as f: + saved_index = json.load(f) + + cache_save_fh: dict = {} + + for key in origin_index["weight_map"].keys(): + if key.startswith("mtp."): + self.assertNotIn(key, saved_index["weight_map"]) + continue + + origin_safetensor_name = origin_index["weight_map"][key] + saved_safetensor_name = saved_index["weight_map"][key] + + origin_sf_fh_name = str(origin_hf_path / origin_safetensor_name) + saved_sf_fh_name = str(tmpdir / saved_safetensor_name) + + if origin_sf_fh_name not in cache_save_fh: + cache_save_fh[origin_sf_fh_name] = safe_open(origin_sf_fh_name, framework="pt") + if saved_sf_fh_name not in cache_save_fh: + cache_save_fh[saved_sf_fh_name] = safe_open(saved_sf_fh_name, framework="pt") + + origin_tensor = cache_save_fh[origin_sf_fh_name].get_tensor(key) + saved_tensor = cache_save_fh[saved_sf_fh_name].get_tensor(key) + + self.assertTrue(torch.equal(origin_tensor, saved_tensor), f"Tensor mismatch for key: {key}") + + mtp_keys = [key for key in saved_index["weight_map"].keys() if key.startswith("mtp.")] + self.assertListEqual(mtp_keys, 
[]) + + safetensor_keys: list[str] = [] + for safetensor_path in tmpdir.glob("*.safetensors"): + fh = safe_open(str(safetensor_path), framework="pt") + safetensor_keys.extend(fh.keys()) + safetensor_keys.sort() + model_index_keys = list(saved_index["weight_map"].keys()) + model_index_keys.sort() + self.assertListEqual(safetensor_keys, model_index_keys) + + dist.barrier() + + def _patch_fast_pos_embed_interpolate(self) -> None: + # HF's fast_pos_embed_interpolate returns fp32; the reused XTuner vision forward adds + # pos_embeds without a cast, so cast the result back to the pos_embed dtype here to + # avoid an fp32/bf16 LayerNorm mismatch. + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5VisionModel + + from xtuner.v1.model.compose.qwen3_vl.modeling_vision import Qwen3VLVisionModel + + def _interp(self, grid_thw): + return Qwen3_5VisionModel.fast_pos_embed_interpolate(self, grid_thw).to(self.pos_embed.weight.dtype) + + Qwen3VLVisionModel.fast_pos_embed_interpolate = _interp + + def _tokenize_qwen3vl(self, raw_data) -> dict: + # Tokenize one Qwen3.5-VL image SFT sample and return device-resident model inputs. 
+ from transformers import AutoTokenizer + + from xtuner.v1.datasets import Qwen3VLTokenizeFnConfig + + tokenizer = AutoTokenizer.from_pretrained(QWEN3_5_DENSE_4B_PATH) + tokenize_fn = Qwen3VLTokenizeFnConfig(processor_path=QWEN3_5_DENSE_4B_PATH, add_vision_id=True).build(tokenizer) + tokenized = tokenize_fn(raw_data) + return { + "input_ids": torch.tensor(tokenized["input_ids"])[None].cuda(), + "labels": torch.tensor(tokenized["labels"])[None].cuda(), + "pixel_values": tokenized["pixel_values"].cuda(), + "image_grid_thw": tokenized["image_grid_thw"].cuda(), + "position_ids": tokenized["position_ids"].cuda(), + } + + @property + def world_size(self) -> int: + return int(os.getenv("XTUNER_TEST_WORLD_SIZE", "1")) diff --git a/xtuner/_testing/testcase.py b/xtuner/_testing/testcase.py index 39077e7f3..3d27d8c54 100644 --- a/xtuner/_testing/testcase.py +++ b/xtuner/_testing/testcase.py @@ -4,6 +4,9 @@ import threading import sys import os +import re +import contextlib +import inspect import unittest import traceback from .utils import enable_full_determinism @@ -98,3 +101,134 @@ def create_pg(self, device): ret = super().create_pg(device) os.environ["LOCAL_RANK"] = str(dist.get_rank() % torch.cuda.device_count()) return ret + + # ------------------------------------------------------------------ # + # HuggingFace bitwise-parity helpers (shared by model parity tests) # + # ------------------------------------------------------------------ # + @contextlib.contextmanager + def hf_impl(self): + """Force XTuner ops onto the HF-exact eager path (``XTUNER_HF_IMPL``) for the duration of the + block, restoring the previous value on exit. 
Build models *inside* the block so their + attention modules pick the eager op at construction time.""" + prev = os.environ.get("XTUNER_HF_IMPL") + os.environ["XTUNER_HF_IMPL"] = "true" + try: + yield + finally: + if prev is None: + os.environ.pop("XTUNER_HF_IMPL", None) + else: + os.environ["XTUNER_HF_IMPL"] = prev + + @staticmethod + def load_params_from_hf(module, loader, key_for=None) -> None: + """Copy every parameter of ``module`` from a HF checkpoint. + + ``key_for`` selects each parameter's checkpoint key: + + * ``None`` (default): ``module`` must expose ``to_hf_key_list`` (any XTuner ``BaseModel``); + ``module.to_hf_key_list(name)[0]`` is used. Lets a standalone XTuner tower load itself + without the caller re-passing ``module`` in a lambda. + * ``str``: treated as a prefix; ``key = key_for + name``. Convenient for an HF module + whose checkpoint keys are the prefix followed by the parameter name. + * ``Callable[[str], str]``: arbitrary mapping (used internally to apply ``hf_key_mapping``). + + ``loader`` (``HFCheckpointLoader``) reads only the safetensors shard holding each key, so a + single layer can be loaded without materializing the full model.""" + if key_for is None: + if not hasattr(module, "to_hf_key_list"): + raise ValueError( + f"module of type {type(module).__name__} has no `to_hf_key_list`; pass `key_for=`."
+ ) + get_key = lambda n: module.to_hf_key_list(n)[0] + elif isinstance(key_for, str): + prefix = key_for + get_key = lambda n: prefix + n + else: + get_key = key_for + for name, p in module.named_parameters(): + key = get_key(name) + tensor = loader.load(key) + assert tensor is not None, f"checkpoint key not found: {key}" + p.data.copy_(tensor.to(device=p.device, dtype=p.dtype)) + + @staticmethod + def xtuner_ckpt_key(model, param_name: str) -> str: + """Resolve the HF checkpoint key for an XTuner parameter: ``to_hf_key_list`` plus the config's + ``hf_key_mapping`` (the remap normally applied inside ``_init_load_spec``).""" + key = model.to_hf_key_list(param_name)[0] + for pattern, repl in (model.config.hf_key_mapping or {}).items(): + if re.search(pattern, key): + return re.sub(pattern, repl, key) + return key + + def materialize_submodule(self, model, submodule, loader, dtype=torch.bfloat16) -> None: + """Materialize a single submodule of a meta-built model on CUDA and load its weights from + the checkpoint. Works uniformly for XTuner ``BaseModel`` and HF ``PreTrainedModel``: + + * The submodule's path inside ``model`` is recovered by identity match against + ``model.named_modules()``. + * For XTuner the path is run through ``to_hf_key_list`` + the config's ``hf_key_mapping``; + for HF the path is *already* the state_dict prefix so it is used directly, with HF's + ``_tied_weights_keys`` honored (e.g. ``lm_head.weight`` redirected to the canonical + ``model.language_model.embed_tokens.weight`` when ``lm_head`` is loaded standalone). 
+ + Lets a single-layer test scale to any model size on both sides: build the parent model + under ``torch.device("meta")``, then materialize only the submodules you actually touch + (the tested layer, ``norm``, ``lm_head``).""" + path = next((name for name, mod in model.named_modules() if mod is submodule), None) + if path is None: + raise ValueError(f"submodule of type {type(submodule).__name__} not found in model") + prefix = f"{path}." if path else "" + submodule.to_empty(device="cuda") + submodule.to(dtype) + # Rebuild RoPE ``inv_freq`` buffers that ``to_empty`` left as garbage. Scope by + # ``inv_freq`` so we don't touch unrelated modules. Pull each rotary's ``__init__`` args + # generically via ``inspect.signature`` + the matching stored attrs (the convention HF and + # XTuner rotary modules follow), then re-instantiate on CPU and copy ``inv_freq`` back — + # the init formula lives in the rotary class itself, whichever variant it is. + for mod in submodule.modules(): + if not hasattr(mod, "inv_freq"): + continue + init_args: list | None = [] + for pname, param in inspect.signature(type(mod).__init__).parameters.items(): + if pname == "self": + continue + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + init_args = None + break + if hasattr(mod, pname): + init_args.append(getattr(mod, pname)) + elif param.default is not inspect.Parameter.empty: + init_args.append(param.default) + else: + # Required arg not recoverable from a stored attribute; skip. + init_args = None + break + if init_args is None: + continue + with torch.device("cpu"): + fresh = type(mod)(*init_args) + mod.inv_freq.data = fresh.inv_freq.data.to(device=mod.inv_freq.device, dtype=mod.inv_freq.dtype) + + tied = getattr(model, "_tied_weights_keys", None) or {} + is_xtuner = hasattr(model, "to_hf_key_list") + + def key_for_param(param_name: str) -> str: + full = prefix + param_name + # HF tied weights (e.g. 
lm_head.weight -> ...embed_tokens.weight) — the canonical key + # is the one present in the checkpoint. + if full in tied: + return tied[full] + # XTuner BaseModel sub-tower with its own `_hf_prefix` / `to_hf_key_list` (e.g. + # `vision_tower`, `multi_modal_projector`): use that directly — the deployment prefix is + # already baked in by the sub-tower, so no compose-level mapping is needed here. + if hasattr(submodule, "to_hf_key_list"): + return submodule.to_hf_key_list(param_name)[0] + if is_xtuner: + return self.xtuner_ckpt_key(model, full) + # HF: state_dict key is exactly the named_parameters path. + return full + + self.load_params_from_hf(submodule, loader, key_for=key_for_param) + diff --git a/xtuner/_testing/utils.py b/xtuner/_testing/utils.py index dbc1517a2..351a117c0 100644 --- a/xtuner/_testing/utils.py +++ b/xtuner/_testing/utils.py @@ -16,7 +16,13 @@ def enable_full_determinism(): os.environ["CUDA_LAUNCH_BLOCKING"] = "1" os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" - # torch.use_deterministic_algorithms(True, warn_only=True) + # Force PyTorch to pick the deterministic implementation for any op that has one. Without this, + # per-layer tests still hit deterministic-by-default ops, but full-model backwards (CE/softmax + # backward reduces, masked_scatter backward, scatter_add gradient accumulation, etc.) can + # silently fall into non-deterministic kernels and produce drift on the order of grad magnitude + # (~1.0 on massive-activation channels). `warn_only` so ops without a deterministic impl warn + # instead of erroring. + torch.use_deterministic_algorithms(True, warn_only=True) torch.set_deterministic_debug_mode(0) From ae0df4b9d2130711206e954646e99c73c547206e Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 22:54:40 +0000 Subject: [PATCH 5/8] [Fix] Select torch.compile fullgraph per layer type in Dense.fully_shard Hybrid dense models (e.g. 
Qwen3.5) mix gated-MHA full-attention layers with GatedDeltaNet linear-attention layers. GatedDeltaNet writes seq_ctx.seq_idx inside the activation-checkpoint region; compiling that layer with fullgraph=True turns the checkpoint into a HigherOrderOperator that rejects the side effect (torch._dynamo SideEffects). Pick fullgraph per layer type so linear layers compile with fullgraph=False (the write graph-breaks) while full-attention layers keep fullgraph=True. No-op for pure-MHA dense models (all full_attention). --- xtuner/v1/model/dense/dense.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py index 247009c80..ef1ad4c7e 100644 --- a/xtuner/v1/model/dense/dense.py +++ b/xtuner/v1/model/dense/dense.py @@ -240,8 +240,13 @@ def fully_shard( ) # __class__ without self attribute + # Linear-attention (GatedDeltaNet) layers write ``seq_ctx.seq_idx`` inside the + # checkpoint region; compiling the wrapped layer with ``fullgraph=True`` turns the + # checkpoint into a HigherOrderOperator that rejects that side effect. Such layers are + # still compiled, but with ``fullgraph=False`` so the write can graph-break. 
if self.compile_cfg: - layer.forward = torch.compile(layer.forward, fullgraph=True) + fullgraph = self.config.layers_type[layer_idx] != "linear_attention" + layer.forward = torch.compile(layer.forward, fullgraph=fullgraph) self.layers[str(layer_idx)] = layer self._fully_shard( From ccbc7b9dc653a19d37b0ca7125bf18adf9a63c1e Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Tue, 26 May 2026 23:03:57 +0000 Subject: [PATCH 6/8] [Docs] Update claude code add hf model skills --- .claude/skills/add_hf_model/SKILL.md | 767 +++++++++++++++++++++++++++ 1 file changed, 767 insertions(+) create mode 100644 .claude/skills/add_hf_model/SKILL.md diff --git a/.claude/skills/add_hf_model/SKILL.md b/.claude/skills/add_hf_model/SKILL.md new file mode 100644 index 000000000..ab06691d6 --- /dev/null +++ b/.claude/skills/add_hf_model/SKILL.md @@ -0,0 +1,767 @@ +--- +name: add_hf_model +description: > + Add support for a HuggingFace model in XTuner so it can be trained with the + full set of parallelism strategies and optimization switches. The user can + provide a HuggingFace Hub repo id, a local model directory, or a model family + already supported by `transformers`. This skill walks through the full path: + locating the reference implementation, classifying the model, splitting new + code across the model / module / ops layers, implementing the XTuner model + class and config, registering the entry point, validating bitwise numerical + parity with HuggingFace, writing regression tests, and enabling the training + optimizations (EP/SP, micro-batch, torch.compile, fp8, activation offload). +--- + +# Add a HuggingFace model to XTuner + +A complete handbook for porting any HuggingFace causal LM into XTuner with +correct training behavior. Covers four buckets: **Dense LLM**, **MoE LLM**, +**VLM / compose model**, and **`trust_remote_code` model**. 
Every step below is +grounded in concrete `file:line` references in this repository — read the cited +code before writing yours. + +> **Anti-spaghetti reminder** — XTuner separates *what the model is* (`*Config`), +> *what it computes* (`Dense` / `MoE` / `BaseComposeModel` subclass), and *how its +> weights map to HF safetensors* (`to_hf_key_list`, optional `safetensors_to_params` +> / `param_to_safetensor`). Keep new code inside those three seams. **Never** add +> `if model_name == "..."` branches to the base classes. + +--- + +## 0. Locate the source code + +The only goal of this step is to find the **HuggingFace reference implementation +you will port from**, and to note the `model_type` you will register under. Do +not pick attention/router/layer here — that happens naturally in §1 as you read +the code. + +1. If the user gives a **remote-code** model, the modeling code ships in the repo + as `modeling_*.py` — read it directly. +2. If the user gives **weights only**, confirm the target Python environment with + the user, then find the matching implementation inside that env's + `transformers` install. If no code exists for that `model_type`, stop and ask + the user. +3. Confirm the weights are **training-suitable** — at least bf16 — so training + precision can be aligned later. Flag fp8/int4-only checkpoints to the user. +4. Note two strings from the HF config — you will need both downstream: + ```python + from transformers import AutoConfig + cfg = AutoConfig.from_pretrained(path, trust_remote_code=True) + cfg.model_type # registration key (§4) + cfg.architectures[0] # the modeling class you are mirroring + ``` + +--- + +## 1. How to do the migration + +XTuner splits a model into three layers. You port by reusing what already exists +at each layer and only adding what is genuinely new — so the work is "read the HF +modeling code, then match each of its pieces to the XTuner component below". 
The +locations are listed plainly; you do not need extra guidance to map your model's +attention / router / layer onto the matching one. + +- **model layer** (`xtuner/v1/model`) — the model body. + - Dense body: `xtuner/v1/model/dense/dense.py` + - MoE body: `xtuner/v1/model/moe/moe.py` + - Multi-modal (compose) bodies: `xtuner/v1/model/compose/` — reuse the language + model and add only the ViT / projector pieces. +- **module layer** (`xtuner/v1/module`): + - Decoder layers: `DenseDecoderLayer`, `MoEDecoderLayer` — + `xtuner/v1/module/decoder_layer`. Small changes add a switch on the existing + layer; a genuinely novel architecture may add `{model_name}_decoder_layer`. + - Attention configs: `MHAConfig` (`xtuner/v1/module/attention/mha.py`), + `MLAConfig` (`.../mla.py`, DeepSeek-style latent), + `GatedDeltaNetConfig` (`.../gated_deltanet.py`, linear). + - Router configs (MoE): `GreedyRouterConfig` + (`xtuner/v1/module/router/greedy.py`, softmax top-k — Qwen3 MoE, GptOss), + `NoAuxRouterConfig` (`xtuner/v1/module/router/noaux_router.py`, grouped sigmoid + aux-loss-free — DeepSeek V3). + - Also: rope, rms_norm, etc. +- **ops layer** (`xtuner/v1/ops`) — kernels such as attention and rms_norm. + +### Existing models to copy from + +Pick the one whose attention + (router) match yours; the closer it is, the +smaller your diff. + +| Your model is… | Copy from | +|--------------------------------------|--------------------------------------| +| Dense LLM | `xtuner/v1/model/dense/qwen3.py` | +| MoE LLM | `xtuner/v1/model/moe/qwen3.py` | +| MoE with MLA + aux-loss-free router | `xtuner/v1/model/moe/deepseek_v3.py` | +| MoE needing weight-layout transform | `xtuner/v1/model/moe/gpt_oss.py` | +| Compose / VLM | `xtuner/v1/model/compose/qwen3_vl/` | + + +After you have a plan, **write an integration design doc into `docs/design/model`** +and confirm with the user that the decomposition is reasonable before +implementing. 
+ +### File layout for the new model + +| Bucket | Model file | Public re-export | Test file | +|----------------------|-------------------------------------------------------------|-----------------------------------------------|--------------------------------------------| +| Dense LLM | `xtuner/v1/model/dense/.py` | `xtuner/v1/model/__init__.py` | `tests/model/test__dense.py` | +| MoE LLM | `xtuner/v1/model/moe/.py` | `xtuner/v1/model/__init__.py` | `tests/model/test__moe.py` | +| Compose / VLM | `xtuner/v1/model/compose//...` | `xtuner/v1/model/__init__.py` | `tests/model/test_.py` | + +--- + +## 2. Implement the model class + +Subclass `Dense`, `MoE`, or `BaseComposeModel`. You typically only need to +override one method. + +### 2.1 `to_hf_key_list(self, key: str) -> list[str]` — **mandatory, abstract on `BaseModel`** + +Declared at `xtuner/v1/model/base.py`. Translates an **XTuner-side** parameter +name into one or more **HF-side** safetensors keys. Return-list cardinality: + +- **1 → 1** for plain weights (most params). +- **1 → N** for fused MoE experts (one fused param explodes into N per-expert + HF keys; see `xtuner/v1/model/moe/qwen3.py`). + +Patterns you almost always need (study `xtuner/v1/model/dense/qwen3.py` +and `xtuner/v1/model/moe/qwen3.py`): + +```python +# 1. Tied embeddings: redirect lm_head → embed_tokens before adding the model. prefix. +if self.config.tie_word_embeddings and "lm_head" in key: + key = key.replace("lm_head", "embed_tokens") + +# 2. Add the "model." prefix that HF wraps everything in. +if "layers" in key or "embed_tokens" in key: + key = "model." + key + +# 3. HF MoE nests experts/gate under .mlp.; XTuner does not. +if "layers" in key: + key = re.sub(r"layers\.(\d+)\.(experts|gate)", r"layers.\1.mlp.\2", key) + +# 4. Top-level norm. +if key.startswith("norm."): + return [key.replace("norm.", "model.norm.")] + +# 5. MoE expert fusion: one fused param → N HF keys. 
+if "fused_w1w3.weight" in key: + out = [] + for i in range(self.config.n_routed_experts): + out.append(key.replace("fused_w1w3.weight", f"{i}.gate_proj.weight")) + out.append(key.replace("fused_w1w3.weight", f"{i}.up_proj.weight")) + return out + +# 6. Model-specific buffers (e.g. FoPE: rotary_emb.sin_coef / cos_coef are persistent buffers). +if key.startswith("rotary_emb."): + return [key.replace("rotary_emb.", "model.rotary_emb.")] +``` + +A prefix-region remap that differs by *deployment* (e.g. a text tower nested under a +VLM as `model.language_model.`) does **not** belong here — put it in the config's +`hf_key_mapping`. See §5.1. + +### 2.2 `safetensors_to_params` / `param_to_safetensor` — **optional, only for layout mismatches** + +If the HF storage layout cannot be expressed as “same tensor, different name”, +you must transform the bytes. Canonical example: GptOss stores expert weights as +`(num_experts, hidden_size, expert_dim * 2)` but XTuner fuses them as +`(num_experts * 2 * expert_dim, hidden_size)` — see overrides at +`xtuner/v1/model/moe/gpt_oss.py`. + +Use this hook only when transposing/reshaping is unavoidable. Renames go in +`to_hf_key_list`; numerics go here. + +### 2.3 Do **not** override + +- `build_layers`, `build_embeddings`, `_init_weights` — the base handles them + via the config. Override only if the new model adds a layer type the base + cannot express (rare). +- `forward` — same reasoning. If you find yourself touching it, re-evaluate the + bucket choice; the problem is probably in the config (`attention`, `router`, + `first_k_dense_replace`). + +--- + +## 3. Implement the config + +Two layers: a **base config** that reads HF and a **size-specific subclass** per +released checkpoint. + +### 3.1 Base config — three required members + +Mirror `Qwen3MoEConfig` (`xtuner/v1/model/moe/qwen3.py`) or +`Qwen3DenseConfig` (`xtuner/v1/model/dense/qwen3.py`). + +1. **`build(self) -> `** — return `(self)`. + +2. 
**`@classmethod from_hf(cls, hf_path) -> Self`** — read the HF config and map + its fields one-to-one onto your config. The field set is whatever the live + config class declares (`TransformerConfig` in `xtuner/v1/model/base.py`, + `MoEConfig` in `xtuner/v1/model/moe/moe.py`); read it, don't work from a + hard-coded list. A few non-obvious traps: + - Use `RopeParametersConfig.from_hf_config(hf_config)` from + `xtuner/v1/module/rope/rope.py`. It already handles both the + `rope_parameters` dict (HF ≥ 5.2.0) and the legacy `rope_scaling` dict + (HF 4.57.x), plus YARN / FoPE / mrope special fields. **Do not** parse + rope yourself. + - For optional HF fields, use `getattr(hf_config, "", )` + rather than `hf_config.` — older checkpoints drop fields. + +3. **`@property hf_config(self) -> | None`** — the inverse of + `from_hf`: every field `from_hf` reads must be re-emitted so `save_hf` + round-trips. + +**Built-in vs. `trust_remote_code` — the only axis that changes `from_hf` / +`hf_config`:** a `trust_remote_code` model is structurally still Dense / MoE / +VLM; the *only* difference is whether `transformers` ships a built-in +`Config` class. Decide which case you are in and follow the matching column +— this is the single place that contract is defined. + +| | Built-in config (e.g. 
Qwen3) | `trust_remote_code` (no built-in config) | +|---|---|---| +| `from_hf` reads via | `.from_pretrained(hf_path)` | `AutoConfig.from_pretrained(hf_path, trust_remote_code=True)`, then `assert hf_config.model_type == ""` | +| `hf_config` returns | a populated `` re-emitting every field | `None` | +| `save_hf` weight-map | written from `hf_config` | same | +| `save_hf` side files | derived from `hf_config` | falls back to copying `config.json` / tokenizer / `*.py` from `self._hf_path` so the dir stays loadable with `trust_remote_code=True`; module-cache test in the matrix below | + +Compare `Qwen3MoEConfig.hf_config` (built-in) and `Qwen3MoEFoPEConfig.hf_config` +(returns `None`), both in `xtuner/v1/model/moe/qwen3.py`. **Do not** invent a +built-in HF config class for a remote-code model — return `None`. + +**Invariants:** + +- `from_hf` ↔ `hf_config` must round-trip. The `test_save_hf` test (§7) is the + enforcement mechanism — every key in the original index must appear in the + saved index with byte-equal tensors. +- The config must be **pickleable and side-effect free**. No env vars, no file + I/O, no `torch.cuda.*` at construction time. + +### 3.2 Size-specific subclasses + +Add one subclass per released checkpoint that hard-codes the published +dimensions. Mirror `Qwen3MoE30BA3Config` / `Qwen3MoE235BA22Config` in +`xtuner/v1/model/moe/qwen3.py`. Even if only one size ships, define at least one +— it is the entry point for `get_model_config` and for tests. + +--- + +## 4. Register the entry point + +Open `xtuner/v1/model/__init__.py` — all four edits are short: + +1. **Import** — bring in the new config classes. +2. **Alias** — add `"": ()` to `model_mapping`. Lookup + normalizes case and `-`/`_`, so `"my-model-7b"` and `"my_model_7b"` are the + same key. +3. **Dispatch** — add `elif cfg.model_type == "":` to + `get_model_config_from_hf` returning `.from_hf(model_path)`. + The catch-all `raise ValueError(...)` must stay last. +4. 
**`__all__`** — re-export the new public config names. + +**Never** key dispatch on parameter names, file paths, or architecture strings. +Only `model_type`. + +--- + +## 5. Compose / VLM specifics + +Compose models live under `xtuner/v1/model/compose/<model>/` and follow a +three-tower pattern. See `xtuner/v1/model/compose/base.py` and +`xtuner/v1/model/compose/qwen3_vl/qwen3_vl_config.py`. + +### 5.1 Config shape + +`BaseComposeConfig` (`compose/base.py`) carries three sub-configs: + +```python +vision_config: XTunerBaseModelConfig # e.g. Qwen3VLVisionConfig +projector_config: XTunerBaseModelConfig # e.g. Qwen3VLProjectorConfig +text_config: XTunerBaseModelConfig # any Dense/MoEConfig subclass +freeze_vision: bool = False +freeze_projector: bool = False +freeze_language: bool = False +``` + +Reuse existing text-tower configs where possible. `Qwen3VLDense*` delegates to +the Dense path via `xtuner/v1/model/dense/qwen3vl_text.py`. **Do not** copy-paste +the entire dense/MoE config — compose over them. + +**Reused text-tower key prefixes — use `hf_key_mapping`, not `to_hf_key_list`.** A +text tower reused inside a VLM keeps its *standalone* `to_hf_key_list` (it emits +`model.<...>` as if it were a plain LLM). The compose checkpoint nests it one level +deeper (`model.language_model.<...>`). Do **not** teach the text tower's +`to_hf_key_list` about that prefix — that leaks the compose context into a tower that +must also work standalone. Set the remap on the **text config** instead: +`hf_key_mapping = {r"^model\.": "model.language_model."}` (applied in +`_init_load_spec`). Keep *structural* renames in `to_hf_key_list`; the +deployment-dependent prefix region belongs in `hf_key_mapping`. + +### 5.2 Model shape + +`BaseComposeModel` (`compose/base.py`) builds three sub-models: +`vision_tower`, `multi_modal_projector`, `language_model`.
+ +- **`from_hf(hf_path)`** — delegates to each sub-model with + `strict=False` (vision keys are missing from the language-model state dict + and vice versa) and unions the missing-keys sets. +- **`save_hf(hf_dir)`** — saves each tower with a distinct prefix + (`"model-language"`, `"model-vision"`, `"model-projector"`) and merges the + three `weight_map`s into one `model.safetensors.index.json`. + +### 5.3 `hf_config` for VLMs + +Many VLMs ship without a stable HF top-level config class; `hf_config` returns +`None` (`compose/qwen3_vl/qwen3_vl_config.py`). In that case, `save_hf` +falls back to copying source files from `self._hf_path` (see §3.1). + +### 5.4 Vision-side quirks to watch for + +- `attn_impl` selectable per checkpoint: `"flash_attention" | "flex_attention" | + "eager_attention"`. +- `deepstack_visual_indexes` (e.g. `[8, 16, 24]`) — auxiliary supervision depths + on the vision tower; must round-trip. +- Specialized rope: `rope_type="qwen3_vl"`, `mrope_section=[24, 20, 20]` — + delegated to `RopeParametersConfig.from_hf_config`. +- The projector’s `torch.compile` path is enabled on torch ≥ 2.9.1 (see commit + `f6d74efb`). Tests must validate both compiled and eager outputs match. + +--- + +## 6. Baseline parity — no complex parallelism yet + +First make the model correct on the **simplest execution path**: single rank, no +EP/SP, no `torch.compile`, no fp8, no activation offload. Align precision from +the inside out — ops → decoder layer → full model — reusing XTuner's own modules +at every level. Only once this baseline holds do you add the parallel and +optimization features (§8). + +1. **Reuse, don't reimplement.** Use the existing XTuner modules — MHA, MLA, + GatedDeltaNet, GroupLinear, etc. Do not re-implement them with naive torch. + +2. 
**Route op differences through `XTUNER_HF_IMPL`.** This env var already exists + (`xtuner/v1/utils/misc.py`) and selects HF-exact ops via per-op selectors: + `xtuner/v1/ops/attn_imp.py::get_attn_impl_fn` (eager attention), + `xtuner/v1/ops/rms_norm/__init__.py::get_rms_norm_fn` (native torch rms_norm), + and `xtuner/v1/ops/gated_deltanet/__init__.py::{get_chunk_gated_delta_rule_fn, + get_causal_conv1d_fn}` (return the canonical fla / causal_conv1d wrappers HF + calls instead of XTuner's compile-friendly `torch.library.custom_op` wraps — + same forward output, but the backward op graph now matches HF). Selectors + read the env var **live** so a test can toggle it per model instance. If your + model introduces a new op whose fast path is not bitwise against HF (typical + for fla-backed linear-attention variants), add its HF-exact branch in the + corresponding op selector the same way. **You may only patch at the ops + level** — never take the shortcut of patching the entire + `XTunerDecoderLayer.forward`. + + Pair `XTUNER_HF_IMPL` with `XTUNER_DETERMINISTIC=true`. The latter gates an + autotune-pin block in `xtuner/v1/__init__.py` that monkey-patches + `triton.autotune` to lock every kernel to its first config and disable the + result cache. Without the pin, fla's `@triton.autotune` picks tiling / + reduction order per kernel and the choice drifts between runs, producing + 1 ULP per linear-attention layer that the chain rule amplifies to ~1.0 at + the model boundary (§12). `tests/conftest.py` sets the env var and runs + `import xtuner.v1` so the patch installs before fla is imported. The patch + lives on the `xtuner.v1.*` import chain (not in the conftest) because + `multiprocessing.spawn` children of `MultiProcessTestCase` don't go through + pytest's conftest — they re-import the test class top-level (which pulls + `xtuner.v1.model.*`) and the patch must run in those processes too. + Production training doesn't set the env var, so the block is inert there. + +3. 
**Decoder-layer bitwise parity — forward AND backward, both required.** Build + one XTuner `DecoderLayer` (one per distinct layer type — e.g. a linear and a + full layer for a hybrid model) and the matching HF layer, set `XTUNER_HF_IMPL` + and `XTUNER_DETERMINISTIC`, feed identical inputs, and align **bitwise** — not a + tolerance. **Backward is not optional**: a `single layer → final norm → lm_head + → CE loss → loss.backward()` test asserts the **layer output, the loss, and the + input gradient `dL/dx`** all bitwise-equal to HF — test one full and one linear + layer separately (reference: `test_decoder_layer_bitwise_parity`). A bitwise + *forward* does not imply a bitwise backward — the forward values can round + identically while the backward graph differs (see the fp32-softmax note in §12). + This is the **primary** path, required at all scales, and the only feasible one + for large models (it never loads the full model — see "load only what you use" + below). When you hand-build a single layer's + inputs, the two sides' `causal mask` / `seq_ctx` must match exactly — see §12. + **Load only what you use** so the test fits any model size: build the XTuner + tower on `meta` and materialize (`to_empty` + selective load) *only* the tested + layer + `norm` + `lm_head` via `HFCheckpointLoader.load(key)` (it reads just the + needed safetensors shard); build a *standalone* HF `DecoderLayer` + norm + + lm_head and load the same keys. Never `from_pretrained` the whole model for a + single-layer test. (`build_rotary_embedding` builds the rotary on CPU with real + buffers even under meta, so `model.rotary_emb` is usable without materializing + the layers.) For < 40B you may also read per-layer hidden states from one + full-model forward (HF `output_hidden_states` + XTuner `return_hidden_states`); + mind the off-by-one in §12. + +4. **Model config matches HF hyperparameters** — see the §3.1 built-in vs. 
+ remote-code contract for how `from_hf` / `hf_config` differ between the two. + +5. **Full-model forward parity for models < 40B**, bitwise against HF, on the + single-rank path (FSDP on or off; no EP / compile / fp8 yet). + +6. **`save_hf` / `from_hf` correctness** (round-trip; see §7). + +7. **Loss-convergence trace + engine test.** Once model forward and decoder-layer + forward/backward are aligned, record a loss-convergence trajectory and add an + engine test case. + +**Commit the baseline — do not stop to ask.** Committing is part of this +workflow: invoking this skill authorizes the in-workflow commits. Once parity +holds on the simple path and lint/tests pass, commit immediately so this correct +baseline is preserved before any parallel/optimization work begins (the §8 +sub-agents branch off this commit). Commit each logical step as you finish it +(§10) rather than batching everything to the end or waiting for a separate +go-ahead. + +Report back to the user: whether bitwise parity was achieved, and the residual +error (decoder-layer level and model level) once XTuner's internal components are +used. + +--- + +## 7. Baseline tests — add these first + +Before any parallel feature, add the most basic regression tests so the baseline +(§6) is locked and the §8 work has a fixed reference to measure against. Add +`tests/model/test_<model>_<bucket>.py`, mirroring +`tests/model/test_qwen3_moe.py` (MoE) or `tests/model/test_qwen3_dense.py` +(Dense). All tests extend `xtuner._testing.DeterministicDDPTestCase` for +deterministic distributed setup.
+ +Reference test cases: +- decoder-layer bitwise parity + save_hf round-trip: `tests/model/test_qwen3_5_dense.py` + (hybrid linear/full dense VLM; toggles `XTUNER_HF_IMPL` per case) +- model forward & save & load parity: `tests/model/test_qwen3_moe.py` +- engine training test: `tests/engine/test_moe_train_engine.py` + +### 7.1 Baseline test matrix + +All single execution path — no EP / dispatcher / compile / fp8 here; those are +§8 tests, written against this baseline. + +| Case | Dense | MoE | Compose | Notes | +|-------------------------------------------------|:-----:|:---:|:-------:|------------------------------------------------------------------------------------| +| Decoder-layer parity — output + loss + `dL/dx` | ✅ | ✅ | ✅ | **Required at all scales, bitwise.** Test one full and one linear layer *separately*. Standalone & memory-light: build the XTuner tower on `meta` and `materialize_xtuner_submodule` only {tested layer, `norm`, `lm_head`}; build a standalone HF `DecoderLayer`; load via `HFCheckpointLoader`. Under `XTUNER_HF_IMPL` assert the **layer output, loss, and `dL/dx` all bitwise** (forward AND backward at the layer level — a bitwise forward ≠ bitwise backward, see §12). Reference: `test_decoder_layer_bitwise_parity`. | +| Whole-model forward + backward parity (by scale) | ✅(4B)| scale | ✅(4B) | Run the **top-level model** end-to-end vs HF, eager both sides (`XTUNER_HF_IMPL` + `XTUNER_DETERMINISTIC` + HF `eager`). For a **VLM** that's the compose model on an **image** prompt (vision + projector + text — subsumes "VLM forward parity"); for a plain LLM, the text/LM model. With both env switches on plus the autotune pin loaded (§6.2), **forward AND backward are bitwise**: assert logits + loss + `dL/d(pixel_values)` / `dL/d(inputs_embeds)` all `== 0.0`. **Scale**: small = fwd+bwd; medium = forward only; large (can't e2e-forward) = skip → integration. Reference: `test_vl_forward_parity`. 
| +| Vision-tower bitwise parity | — | — | ✅ | **Compose/VLM only.** Compare HF's `visual` output (`pooler_output`, which is **post-merger**) bitwise (0.0), both eager. XTuner splits that boundary into `vision_tower` (patches → merged hidden) + `projector` (the merger MLP → text dim), so load **both** standalone to match it — don't build the whole VLM just for the vision path. Build them on the **real device, not `meta`**: the vision rotary is computed into a buffer at forward time, so `to_empty` would leave garbage (unlike the text rotary, which `build_rotary_embedding` builds real even under meta — §6.3). Vision attention is *non-causal* → `eager_attention(causal=False)`, which `XTUNER_HF_IMPL` selects. Reference: `test_vision_tower_bitwise_parity`. | +| FSDP forward parity | ✅ | ✅ | ✅ | Looser tolerance (≈ 3e-2 is the established budget). | +| `save_hf` round-trip (byte-equal tensors) | ✅ | ✅ | ✅ | See `tests/model/test_qwen3_moe.py`. | + +### 7.2 Parity bar by model scale + +Two independent axes: + +**Per-layer (op) parity — required at _all_ scales, bitwise.** Test one full and +one linear decoder layer separately; for each assert layer **output + loss + `dL/dx`** +bitwise vs HF under `XTUNER_HF_IMPL`. Keep it memory-light (XTuner `meta`-build + +materialize only the tested layer + `norm` + `lm_head`; standalone HF layer; load +via `HFCheckpointLoader`) so it runs even for models too large to forward +end-to-end. This is the bitwise guarantee for both forward and backward ops. + +**Whole-model (integration) parity — on the _top-level_ model, graded by what fits +one GPU.** Run the real model end-to-end: for a VLM that's the **compose model on an +image prompt** (one test covers vision + projector + text — don't add a separate +text-tower model test); for a plain LLM, the text/LM model. +- **Small** (e.g. dense 4B): both forward (logits/loss) AND backward + (`dL/d(pixel_values)` for a VLM, `dL/d(inputs_embeds)` for text) **bitwise**. 
+ Prerequisites: `XTUNER_HF_IMPL=true` so every op selector returns the + HF-canonical callable (§6.2) AND `XTUNER_DETERMINISTIC=true` so Triton + autotune is pinned (§6.2). With both on, even hybrid linear/full models match + HF byte-for-byte through the whole compose chain. Don't check weight + gradients — they accumulate across positions + tied `lm_head` and aren't the + right granularity for parity. +- **Medium**: forward only (logits/loss bitwise). +- **Large** (cannot forward end-to-end on the test GPUs): skip the whole-model + test; leave it to integration tests. The per-layer test still gives op-level + forward+backward bitwise. + +### 7.3 Required test idioms + +`tests/model/test_qwen3_5_dense.py` is the reference template; extend +`DeterministicDDPTestCase` (`xtuner._testing`), which provides the shared +parity scaffolding so the test file only holds model-specific logic: + +- `with self.hf_impl():` — sets/restores `XTUNER_HF_IMPL=true`; build the XTuner + model **inside** the block so its attention picks the eager op at construction. +- `self.materialize_xtuner_submodule(model, submodule, prefix, loader)` — for a + model built on `meta`, materialize *only* that submodule on CUDA and load its + weights (so a single-layer test fits any model size). +- `self.load_params_from_hf(module, loader, key_for=...)` / + `self.xtuner_ckpt_key(model, name)` — selective per-param load by checkpoint key + (`HFCheckpointLoader` reads only the needed shard); `xtuner_ckpt_key` applies + `to_hf_key_list` + the config's `hf_key_mapping`. + +```python +# Construct on meta, materialize on device — single-rank baseline, eager, no EP. +with torch.device("meta"): + cfg = get_model_config_from_hf(hf_model_path) + cfg.compile_cfg = False # baseline tests run eager + model = cfg.build()._to_device_dtype(dtype=torch.bfloat16, skip_buffers_dtype=True) + +# Optionally shard with FSDP (data-parallel only), then load.
+model.fully_shard(fsdp_config=FSDPConfig(cpu_offload=False)) +model.from_hf(hf_model_path) +``` + +Older `patch_hf_rms_norm` / `patch_hf_rope` helpers also live in +`xtuner._testing`. **Check they actually fit your model before using them** — +they were written for earlier models and match HF modules +by class-name substring and attribute name (e.g. `variance_epsilon`). On a newer +model they can silently mis-patch (wrong eps attribute, or replacing a +`zero_centered` RMSNorm with a default one). Often the native XTuner ops already +match HF bitwise (see §12), so no patching is needed; prefer `XTUNER_HF_IMPL` +over these helpers for parity. The §8 parallel-feature tests reuse this same +idiom with `cfg.dispatcher` / `cfg.ep_size` (and `FSDPConfig(ep_size=...)`) set. + +### 7.4 Checkpoint paths + +Read from env vars; **never** hard-code. Example: +`QWEN3_MOE_PATH = os.environ["QWEN3_MOE_PATH"]` at +`tests/model/test_qwen3_moe.py`. Document new env vars in the PR body. + +### 7.5 Drop-in training config + +Ship a runnable training config at `ci/config/.py` so a developer (and CI) can verify +the port end-to-end with one command: load the HF checkpoint, run a few real training steps, +watch the loss curve descend as expected. This is the smoke test that catches anything the +unit tests miss — the full `from_hf` → `model.forward` → loss → backward → optimizer step → +FSDP shard/reduce chain — and the file doubles as the example users copy when wiring the +model into their own training pipeline. Mirror `ci/config/qwen3_moe_30BA3.py` (MoE) or +`ci/config/qwen3_dense.py` (dense), keeping its structure: one `()` (from §3.2) +fed into a `TrainerConfig` alongside `optim_cfg` / `lr_cfg` / `fsdp_cfg` / `dataset_cfg` / +`dataloader_cfg` / `loss_cfg`. `load_from` and `tokenizer_path` read from an env var (§7.4 — +typically the same one as the parity test). 
Verify by running ~50 steps and confirming the +loss drops monotonically into a plausible range for that model size; record the trajectory +in the PR body alongside the §6 convergence trace. + +--- + +## 8. Training optimizations + +Only after the §6 baseline is committed and the §7 tests are green: verify each +switch runs correctly and keeps precision within budget **against that baseline**. + +1. EP and SP can be enabled normally. +2. With `ep_size > 1`, `intra_layer_micro_batch` enables correctly and stays + precise. +3. `torch.compile` runs correctly. +4. fp8 can be enabled. +5. `XTUNER_ACTIVATION_OFFLOAD` can be enabled. + +Each switch needs a test that compares its output to the §6 baseline — e.g. the +dispatcher × ep_size parity matrix (at minimum `{(None,1), ("all2all",4), +("all2all",8)}`, add `"deepep"` if supported), compiled-vs-eager, or fp8-vs-bf16 +within tolerance. + +These items are well-suited to the multi-agent workflow in §11: the main agent +analyzes and creates tasks (with test cases), sub-agents develop in parallel, and +the main agent merges. Report the precision error introduced by each switch. + +--- + +## 9. 
Reference cheat-sheet + +| Concept | Concrete location | +|-------------------------|--------------------------------------------------------------------------------| +| Dense base config | `xtuner/v1/model/base.py` (`TransformerConfig`) | +| MoE base config | `xtuner/v1/model/moe/moe.py` (`MoEConfig`) | +| Compose base config | `xtuner/v1/model/compose/base.py` (`BaseComposeConfig`) | +| `from_hf` (model) | `xtuner/v1/model/base.py` (`from_hf`) | +| `save_hf` (model) | `xtuner/v1/model/base.py` (`save_hf`) | +| `to_hf_key_list` proto | `xtuner/v1/model/base.py` | +| MoE layer split | `xtuner/v1/model/moe/moe.py` (`first_k_dense_replace`) | +| RoPE auto-parse | `xtuner/v1/module/rope/rope.py` (`RopeParametersConfig.from_hf_config`)| +| Dense reference | `xtuner/v1/model/dense/qwen3.py` | +| MoE reference | `xtuner/v1/model/moe/qwen3.py` | +| MoE w/ MLA + NoAux | `xtuner/v1/model/moe/deepseek_v3.py` | +| MoE w/ layout transform | `xtuner/v1/model/moe/gpt_oss.py` | +| Remote-code reference | `xtuner/v1/model/moe/qwen3.py` (`Qwen3MoEFoPEConfig`) | +| Compose reference (VLM) | `xtuner/v1/model/compose/qwen3_vl/` | +| Dense test reference | `tests/model/test_qwen3_dense.py` | +| MoE test reference | `tests/model/test_qwen3_moe.py` | +| Engine test reference | `tests/engine/test_moe_train_engine.py` | + +--- + +## 10. Commit discipline + +0. **Commit as you go — this is authorized, do not pause for a separate + go-ahead.** The user invoking this skill authorizes the commits this workflow + produces. Finish a logical step → lint/test → commit, then move on. Do not + leave the whole port uncommitted and end with "should I commit?"; that strands + the work. (Hard-to-reverse or outward actions like `git push` / opening a PR + still need explicit confirmation — committing locally does not.) +1. **Follow the stacked-PR convention.** Plan upfront how many commits the model + port needs — one per logical step (e.g. 
shared-path fixes; model class + + config + registration; the `XTUNER_HF_IMPL` switch; baseline tests; then one + per parallel/optimization feature). Order them by dependency (a shared-path fix + the baseline relies on goes first). Land every later change as + `git commit --fixup=<sha>` + `git rebase --autosquash` into the commit it + belongs to. **Never grow the history with endless patch commits just to keep + fixing things up.** +2. **Every commit must pass lint:** + ```bash + pre-commit run --files $(find xtuner/v1) + ``` + +--- + +## 11. Multi-agent workflow + +This applies **only to the §8 parallel/optimization features** — and only once +the §6 baseline and §7 baseline tests are done and committed (§10). The baseline +itself is built sequentially, not multi-agent; the §8 features are independent of +each other, which is what makes them safe to develop in parallel. + +1. The **main agent** splits the §8 features into tasks and, per §8, writes each + feature's comparison test against the committed baseline and commits the tests. +2. **Sub-agents** branch off the baseline/test commit and develop one feature + each (e.g. one switch / one tower) in parallel, making its test pass. +3. **The two sides adjust each other** — a shared-interface change is made once by + the main agent (it updates the test/baseline; sub-agents rebase), and a failing + test pushes the sub-agent to fix its feature, not to weaken the test. +4. The **main agent** merges, resolves conflicts, and reports the precision error + each feature introduces relative to the baseline. + +--- + +## 12. Parity debugging pitfalls + +Hard-won notes from aligning hybrid (GatedDeltaNet + gated-MHA) models bitwise. +The meta-rule: **measure on real GPU and bisect to localize the divergence before +concluding anything** — most "bugs" here are measurement artifacts, not model +errors. 
A clean way to bisect is to feed one side's intermediate tensors into +both modules and compare step by step (norm → q/k/v → qk_norm → rope → attention +core → o_proj). + +- **Single-layer parity: the two sides' inputs must match exactly — especially + the causal mask.** When you hand-build one layer's inputs, passing + `attention_mask=None` to an HF attention/layer makes HF apply **no** causal + mask while XTuner stays causal, producing a large *false* diff. Build a causal + mask equivalent to XTuner's `seq_ctx`, or drive the layer through the model so + the mask is constructed internally. + +- **`flash` ≠ `eager`, and HF eager upcasts softmax to fp32.** XTuner defaults to + flash attention; HF reference eager does softmax in fp32. They are numerically + close but not bitwise equal, so bitwise parity requires forcing eager on the + XTuner side (that is exactly what `XTUNER_HF_IMPL` does). XTuner's + `eager_attention` matches HF's `eager_attention_forward` bitwise. + +- **A bitwise forward does not imply a bitwise backward.** The eager softmax must + run in **fp32** (`softmax(..., dtype=torch.float32).to(x.dtype)`, like HF), not + bf16. With a bf16 softmax the forward logits can still round identically to HF + (bitwise), but the *backward* differs (~1e-5 on `dL/dx`) because the gradient + accumulates at bf16 precision. So always check backward bitwise separately + (the `loss.backward()` test), not just forward — and align op dtypes (softmax, + norms) in the backward-sensitive direction, fp32 where HF uses fp32. + +- **Whole-model backward is bitwise — but only with the full env stack.** + Per-layer bitwise backward does not on its own imply whole-model bitwise + backward; the missing pieces are `XTUNER_HF_IMPL=true` AND + `XTUNER_DETERMINISTIC=true` AND the conftest must `import xtuner.v1` so the + autotune-pin block in `xtuner/v1/__init__.py` runs before any test reaches + fla. 
Without those, two distinct ULP-level sources compound through the LM + chain rule and look indistinguishable from "fundamental bf16 noise": + + 1. **Triton autotune drift in fla.** fla's `@triton.autotune` picks + tiling / num_warps / reduction-order per kernel. Across runs the choice + can change, producing a ~1-ULP backward diff per linear-attention layer + that the chain rule amplifies — a 32-layer LM stack turned a per-layer + ~2e-3 abs diff at grad magnitude ~30 into a ~1.0 diff on the + massive-activation channels of `dL/d(pixel_values)`. The pin (configs + held to the first one, cache disabled) is in `xtuner/v1/__init__.py` + under an `XTUNER_DETERMINISTIC` gate. It must run **before any module + imports fla** — fla's `@triton.autotune` decorators evaluate at import + time, so patching after the fact has no effect on kernels already + constructed. That's why the patch is on the `xtuner.v1.*` import chain + rather than in the conftest: `multiprocessing.spawn` children of + `MultiProcessTestCase` re-import the test class top-level (which pulls + `xtuner.v1.model.*`) without running conftest, so the conftest alone + can't install the patch in those processes. + 2. **`torch.library.custom_op` wraps for the linear-attention path** + (`GatedDeltaNet`, `causal_conv1d`) produce a different backward autograd + graph from the canonical fla / causal_conv1d call HF uses, even though + the underlying kernel and forward output are identical. `XTUNER_HF_IMPL` + routes through `get_chunk_gated_delta_rule_fn` / `get_causal_conv1d_fn` + (§6.2) so the parity path matches HF byte-for-byte. + + Production train doesn't set either env var; the custom_op wraps + autotune + still operate. These switches are **parity-test only**. If the whole-model + backward is not bitwise but the per-layer backward is, suspect one of the + two before suspecting depth-accumulated bf16 noise. 
+ +- **`flash`(XTuner varlen) vs `flash_attention_2`(HF) is bitwise only for short, + single-tile sequences.** They are the same kernel, so at a short seqlen (e.g. + 32) the logits are bitwise (0.0); at longer sequences the tiling / + accumulation order differs and they drift ~1e-2. So a short-seq text test can + assert flash-vs-FA2 bitwise, but anything realistic-length (or a VLM whose + image tokens make the sequence long) needs the **eager** path for + seqlen-independent bitwise. Don't conclude "flash matches FA2" from one short + input. + +- **Shared kernels are already bitwise; focus on what isn't.** GatedDeltaNet / + linear attention go through the same `fla` / `causal_conv1d` kernels as HF, and + native rms_norm / rope match HF bitwise — these come out at 0.0 with no work. + Spend your effort on the path that actually differs (usually full attention). + +- **`output_hidden_states` off-by-one (full-forward route, < 40B only).** HF's + tuple is `[emb, L0_out, …, L{n-2}_out, post_norm]`: index `i+1` is layer `i`'s + pre-norm output only for `i ≤ n-2`; the **last** entry is post-norm, not the + last layer's pre-norm output. Compare the last layer via `model.norm(...)` or + you will see the final layer "explode". + +- **A big max-diff is not automatically a bug.** Late layers have *massive + activations* (a few channels in the tens/hundreds); a tiny relative error there + shows up as a large absolute max-diff. Look at the mean and the trend, not just + the max. Conversely, feeding `randn` inputs to a single attention layer is + ill-conditioned and misleads — use real (or model-produced) hidden states. + +- **A single-model failure is often a general-path bug.** E.g. `Dense.build_layers` + not forwarding `rms_norm_type` to `DenseDecoderLayer` broke any `zero_centered` + dense model, not just this one. Fix the shared invariant, not the symptom. 
+ +- **Loading a standalone sub-tower uses `strict=False`.** A text tower loaded + directly from a full VLM checkpoint will see vision (and deferred `mtp.*`) keys + as unexpected; load with `strict=False` so they are ignored. + +- **Test the whole VLM, not just the text tower.** Text-tower parity leaves the + vision tower and projector completely untested. A compose port needs a + full-model forward-parity test with real image inputs (§7.1). The text path of + the compose forward still runs a *dummy* vision forward to keep the tower in the + autograd graph, so vision bugs surface even on text prompts. + +- **Vision attention is non-causal — give `eager_attention` a `causal=False` + path.** XTuner's `eager_attention` defaulted to a causal block-diagonal mask; + that is wrong for a vision tower (bidirectional within each image), which is why + forcing vision to the causal eager op produced large diffs. The op now takes + `causal=True|False` (False → non-causal block-diagonal mask), the vision module + already calls it with `causal=False`, and the vision tower routes through + `get_attn_impl_fn` so `XTUNER_HF_IMPL` selects eager. With that, the *whole* VLM + is bitwise in one run, everything eager: `XTUNER_HF_IMPL` on the XTuner side + (text causal-eager + vision non-causal-eager) and HF loaded with a single + `attn_implementation="eager"`. Prefer eager for parity — it is + seqlen-independent, unlike flash/FA2 (see the flash seqlen bullet above). + +- **Vision pos-embed dtype when reusing the Qwen3-VL tower.** HF's + `fast_pos_embed_interpolate` returns fp32 (bilinear weights are fp32) and HF's + vision forward casts it (`pos_embeds.to(hidden_states.dtype)`); XTuner's vision + forward adds it without a cast. If you patch in the HF method for parity, cast + the result back to the param dtype, or you get an fp32/bf16 `LayerNorm` mismatch + (`expected scalar type Float but found BFloat16`). 
+ +- **Hybrid-model `torch.compile`: choose `fullgraph` per layer — verify, then fall + back; don't gate *whether* to compile.** When a decoder layer's forward is wrapped + into a higher-order op (HOP) — e.g. by an activation-checkpoint wrapper, which + `torch.compile` lowers to a HOP whose body must be functional so it can be + recomputed in backward — *and* that layer is a linear-attention layer + (GatedDeltaNet), the layer mutates an outer-scope variable inside the HOP + (GatedDeltaNet writes back into `seq_ctx`). Under `fullgraph=True` that raises + `HigherOrderOperator ... Mutating a variable not in the current scope (SideEffects)`. + This is **not** a blanket "linear attention ⇒ `fullgraph=False`" rule — whether it + triggers depends on whether the layer is actually HOP-wrapped and whether that + linear impl really writes outer state. So for such a HOP + linear combination, + **first test whether `fullgraph=True` runs**; keep it if it does (it compiles more + of the layer), and only **fall back to `fullgraph=False`** if it doesn't — + `fullgraph=False` survives by graph-breaking at the mutation and running that + side-effecting write in eager. Make this a **per-layer** choice in + `Dense.fully_shard` (decide each layer's `fullgraph`), not a decision about whether + a layer is compiled at all. + From 31819460ebeee65668c3f2af63176d39928a19f2 Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Fri, 29 May 2026 18:31:44 +0000 Subject: [PATCH 7/8] [Feature] Pin Triton autotune for fla determinism via XTUNER_DETERMINISTIC MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When XTUNER_DETERMINISTIC=true is set, monkey-patch triton.autotune to pin every autotune to its first config and disable the result cache. Mirrors PR #1850 — keep this as a single revertable commit so it can be replaced when that PR lands and we no longer carry the patch locally. 
Why xtuner/v1/__init__.py and not tests/conftest.py: fla's @triton.autotune decorators evaluate at import time, and any code path that loads a fla-backed model triggers `xtuner.v1.model.*` imports before fla is reached — including multiprocessing.spawn children of MultiProcessTestCase, which do not run pytest's conftest. The patch must be installed by something on the xtuner.v1.* import chain. Production train doesn't set XTUNER_DETERMINISTIC, so the patch is inert there. tests/conftest.py sets the env var (so subsequent xtuner.v1 imports trigger the patch) and force-imports xtuner.v1 in the pytest parent process; spawn children inherit the env and re-run the same __init__ block. --- tests/conftest.py | 13 ++++++++++++- xtuner/v1/__init__.py | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5410e9ccf..a538ce3f5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,7 +3,18 @@ import sys from pathlib import Path -from huggingface_hub import constants +# Activate the Triton autotune pin installed by `xtuner.v1.__init__` (gated by this env +# var). The pin must run before any module imports `fla`; see the patch's docstring in +# `xtuner/v1/__init__.py` for why it lives there rather than here in the conftest. +os.environ.setdefault("XTUNER_DETERMINISTIC", "true") + +# Trigger that patch in the pytest parent process. Each `MultiProcessTestCase` child is a +# fresh `multiprocessing.spawn` Python that re-imports the test class top-level, which pulls +# `xtuner.v1.*` and runs the same `xtuner/v1/__init__.py` block — so the patch is installed +# in every process as long as `XTUNER_DETERMINISTIC=true` was inherited via the env. 
+import xtuner.v1 # noqa: E402,F401 + +from huggingface_hub import constants # noqa: E402 _HF_DYNAMIC_MODULE_PREFIX = "transformers_modules" diff --git a/xtuner/v1/__init__.py b/xtuner/v1/__init__.py index 4982a2d30..26e69d263 100644 --- a/xtuner/v1/__init__.py +++ b/xtuner/v1/__init__.py @@ -1 +1,40 @@ -from . import patch # noqa: F401 +import os +from typing import Any, cast + + +def _patch_triton_autotune_for_determinism() -> None: + # Pin every Triton autotune to its first config and disable the result cache so fla can't + # pick a different tiling / num_warps / reduction order across runs and break bitwise + # gradient parity. Must run before any module imports `fla` — fla's `@triton.autotune(...)` + # decorators evaluate at import time, so patching `triton.autotune` afterwards has no + # effect on kernels already constructed. Lives here in `xtuner/v1/__init__.py` because + # `xtuner.v1.model.*` imports run before fla in every code path that loads a fla-backed + # model, including `multiprocessing.spawn` children that don't go through pytest's + # conftest (each child re-imports the test class top-level, which pulls `xtuner.v1.model` + # — and therefore this `__init__` — before anything that would touch fla). + # + # This mirrors InternLM/xtuner#1850. Production training does not set + # XTUNER_DETERMINISTIC, so the patch never runs there; remove this block once that PR + # lands and supersedes it. 
+ import triton + + original_autotune = triton.autotune + if getattr(original_autotune, "_xtuner_deterministic_patched", False): + return + + def deterministic_autotune(configs, *args, **kwargs): + if configs: + configs = configs[:1] + kwargs["cache_results"] = False + return original_autotune(configs, *args, **kwargs) + + patched = cast(Any, deterministic_autotune) + patched._xtuner_deterministic_patched = True + patched._xtuner_original_autotune = original_autotune + triton.autotune = deterministic_autotune + + +if os.getenv("XTUNER_DETERMINISTIC") == "true": + _patch_triton_autotune_for_determinism() + +from . import patch # noqa: E402,F401 From b09efc404a2d070c13775485a57ae923466bb76b Mon Sep 17 00:00:00 2001 From: HAOCHENYE <21724054@zju.edu.cn> Date: Sat, 30 May 2026 12:24:36 +0000 Subject: [PATCH 8/8] [CI] update ci model path --- ci/scripts/CI_ENV.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/scripts/CI_ENV.sh b/ci/scripts/CI_ENV.sh index 939df9ec5..7643f09c2 100644 --- a/ci/scripts/CI_ENV.sh +++ b/ci/scripts/CI_ENV.sh @@ -1,6 +1,7 @@ #!/bin/bash export QWEN3_VL_MOE_PATH=${CI_SHARE_MODEL}/Qwen3-VL-30B-A3B-Instruct_MOE export QWEN3_VL_DENSE_PATH=${CI_SHARE_MODEL}/Qwen3-VL-4B-Instruct +export QWEN3_5_DENSE_4B_PATH=${CI_SHARE_MODEL}/models--Qwen--Qwen3.5-4B export INTERN_VL_1B_PATH=${CI_SHARE_MODEL}/InternVL3_5-1B-HF export VIDEO_ROOT=${CI_SHARE_DATA}/images export QWEN3_4B_PATH=${CI_SHARE_MODEL}/Qwen3-4B-Instruct-2507