Commit 15d8949

add qwen2.5

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

1 parent 353652b commit 15d8949
18 files changed: 888 additions & 73 deletions

Lines changed: 250 additions & 0 deletions
@@ -0,0 +1,250 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Conversion utilities between HuggingFace Qwen2 and TransformerEngine formats."""

import inspect

import torch
from transformers import Qwen2Config, Qwen2ForCausalLM

import state
from modeling_qwen2_te import NVQwen2Config, NVQwen2ForCausalLM


mapping = {
    "model.embed_tokens.weight": "model.embed_tokens.weight",
    "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
    "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight",
    "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight",
    "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight",
    "model.norm.weight": "model.norm.weight",
    "lm_head.weight": "lm_head.weight",
}

# Reverse mapping from TE to HF format by reversing the original mapping
reverse_mapping = {v: k for k, v in mapping.items()}
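# Note: the "*" wildcard in these keys stands in for the layer index, so e.g.
# "model.layers.0.input_layernorm.weight" maps to
# "model.layers.0.self_attention.layernorm_qkv.layer_norm_weight".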


def _merge_qkv_bias(ctx: state.TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Merge separate q, k, v biases into interleave-concatenated qkv bias."""
    target_config = ctx.target.config

    head_num = target_config.num_attention_heads
    num_query_groups = target_config.num_key_value_heads
    heads_per_group = head_num // num_query_groups
    head_size = target_config.hidden_size // head_num

    q = q.view(head_num, head_size)
    k = k.view(num_query_groups, head_size)
    v = v.view(num_query_groups, head_size)

    qkv_bias_l = []
    for i in range(num_query_groups):
        qkv_bias_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :])
        qkv_bias_l.append(k[i : i + 1, :])
        qkv_bias_l.append(v[i : i + 1, :])
    qkv_bias = torch.cat(qkv_bias_l)

    return qkv_bias.reshape(-1)
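# Example of the fused layout produced by _merge_qkv_bias: with 4 query heads
# and 2 KV groups (so heads_per_group = 2), the rows are ordered
# [q0, q1, k0, v0, q2, q3, k1, v1] before the final flatten, i.e. an
# interleaved per-group QKV packing rather than plain [Q | K | V] concatenation.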


def _split_qkv_bias(ctx: state.TransformCTX, qkv_bias: torch.Tensor):
    """Split interleave-concatenated qkv bias into separate q, k, v biases."""
    target_config = ctx.target.config

    head_num = target_config.num_attention_heads
    num_query_groups = target_config.num_key_value_heads
    heads_per_group = head_num // num_query_groups
    head_size = target_config.hidden_size // head_num
    qkv_total_dim = head_num + 2 * num_query_groups

    qkv_bias = qkv_bias.reshape(qkv_total_dim, head_size)
    q_slice = torch.cat(
        [
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

    q_bias = qkv_bias[q_slice].reshape(-1).cpu()
    k_bias = qkv_bias[k_slice].reshape(-1).cpu()
    v_bias = qkv_bias[v_slice].reshape(-1).cpu()

    return q_bias, k_bias, v_bias
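# The slices above invert _merge_qkv_bias: each group occupies a block of
# heads_per_group + 2 consecutive rows, so group i's q rows start at
# (heads_per_group + 2) * i, its k row sits at offset heads_per_group within
# that block, and its v row one past that.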


def _zero_bias_from_weight(ctx: state.TransformCTX, weight: torch.Tensor):
    """Create a zero bias with dimension matching the weight's first axis."""
    return torch.zeros(weight.shape[0], device=weight.device, dtype=weight.dtype)


def _zero_fc1_bias(ctx: state.TransformCTX, gate: torch.Tensor, up: torch.Tensor):
    """Create a zero fc1 bias for the merged gate+up projection."""
    return torch.zeros(gate.shape[0] + up.shape[0], device=gate.device, dtype=gate.dtype)


def convert_qwen2_hf_to_te(model_hf: Qwen2ForCausalLM, **config_kwargs) -> NVQwen2ForCausalLM:
    """Convert a Hugging Face Qwen2 model to a Transformer Engine model.

    Args:
        model_hf (nn.Module): The Hugging Face model.
        **config_kwargs: Additional configuration kwargs to be passed to NVQwen2Config.

    Returns:
        nn.Module: The Transformer Engine model.
    """
    config_dict = model_hf.config.to_dict()
    # Ensure layer_types is consistent with num_hidden_layers (from_pretrained can leave stale layer_types)
    if len(config_dict.get("layer_types", [])) != config_dict.get("num_hidden_layers", 0):
        config_dict["layer_types"] = config_dict["layer_types"][: config_dict["num_hidden_layers"]]
    te_config = NVQwen2Config(**config_dict, **config_kwargs)
    with torch.device("meta"):
        model_te = NVQwen2ForCausalLM(te_config)

    if model_hf.config.tie_word_embeddings:
        state_dict_ignored_entries = ["lm_head.weight"]
    else:
        state_dict_ignored_entries = []

    output_model = state.apply_transforms(
        model_hf,
        model_te,
        mapping,
        [
            # Merge Q/K/V weights into fused QKV
            state.state_transform(
                source_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                target_key="model.layers.*.self_attention.layernorm_qkv.weight",
                fn=state.TransformFns.merge_qkv,
            ),
            # Merge Q/K/V biases into fused QKV bias
            state.state_transform(
                source_key=(
                    "model.layers.*.self_attn.q_proj.bias",
                    "model.layers.*.self_attn.k_proj.bias",
                    "model.layers.*.self_attn.v_proj.bias",
                ),
                target_key="model.layers.*.self_attention.layernorm_qkv.bias",
                fn=_merge_qkv_bias,
            ),
            # Merge gate/up projections into fc1
            state.state_transform(
                source_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                target_key="model.layers.*.layernorm_mlp.fc1_weight",
                fn=state.TransformFns.merge_fc1,
            ),
            # TE bias=True creates biases for all linear layers, but Qwen2 only has bias on QKV.
            # Initialize the extra TE biases (output projection, MLP) to zero.
            state.state_transform(
                source_key="model.layers.*.self_attn.o_proj.weight",
                target_key="model.layers.*.self_attention.proj.bias",
                fn=_zero_bias_from_weight,
            ),
            state.state_transform(
                source_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                target_key="model.layers.*.layernorm_mlp.fc1_bias",
                fn=_zero_fc1_bias,
            ),
            state.state_transform(
                source_key="model.layers.*.mlp.down_proj.weight",
                target_key="model.layers.*.layernorm_mlp.fc2_bias",
                fn=_zero_bias_from_weight,
            ),
        ],
        state_dict_ignored_entries=state_dict_ignored_entries,
    )

    output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone()

    return output_model
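# Hypothetical usage (the checkpoint tag is illustrative; weights must be
# materialized, e.g. via from_pretrained, before conversion, since the TE
# model is instantiated on the meta device):
#   model_hf = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
#   model_te = convert_qwen2_hf_to_te(model_hf)
#   model_te.save_pretrained("qwen2.5-0.5b-te")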


def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwen2ForCausalLM:
    """Convert a Transformer Engine Qwen2 model to a Hugging Face model.

    Args:
        model_te (nn.Module): The Transformer Engine model.
        **config_kwargs: Additional configuration kwargs to be passed to Qwen2Config.

    Returns:
        nn.Module: The Hugging Face model.
    """
    # Filter out keys from model_te.config that are not valid Qwen2Config attributes
    te_config_dict = model_te.config.to_dict()
    valid_keys = set(inspect.signature(Qwen2Config.__init__).parameters)
    filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
    # Ensure layer_types is consistent with num_hidden_layers
    if len(filtered_config.get("layer_types", [])) != filtered_config.get("num_hidden_layers", 0):
        filtered_config["layer_types"] = filtered_config["layer_types"][: filtered_config["num_hidden_layers"]]
    hf_config = Qwen2Config(**filtered_config, **config_kwargs)

    with torch.device("meta"):
        model_hf = Qwen2ForCausalLM(hf_config)

    output_model = state.apply_transforms(
        model_te,
        model_hf,
        reverse_mapping,
        [
            # Split fused QKV weight into separate Q/K/V
            state.state_transform(
                source_key="model.layers.*.self_attention.layernorm_qkv.weight",
                target_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                fn=state.TransformFns.split_qkv,
            ),
            # Split fused QKV bias into separate Q/K/V biases
            state.state_transform(
                source_key="model.layers.*.self_attention.layernorm_qkv.bias",
                target_key=(
                    "model.layers.*.self_attn.q_proj.bias",
                    "model.layers.*.self_attn.k_proj.bias",
                    "model.layers.*.self_attn.v_proj.bias",
                ),
                fn=_split_qkv_bias,
            ),
            # Split fc1 into gate/up projections
            state.state_transform(
                source_key="model.layers.*.layernorm_mlp.fc1_weight",
                target_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                fn=state.TransformFns.split_fc1,
            ),
        ],
        state_dict_ignored_entries=model_hf._tied_weights_keys,
    )

    output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
    output_model.tie_weights()

    return output_model
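# Round-trip sanity check (a sketch, assuming both conversions preserve
# parameter values exactly):
#   model_hf2 = convert_qwen2_te_to_hf(convert_qwen2_hf_to_te(model_hf))
#   for name, param in model_hf.state_dict().items():
#       torch.testing.assert_close(model_hf2.state_dict()[name], param)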
File renamed without changes.
Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@

 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

-import convert
+import convert_qwen3
 from modeling_qwen3_te import AUTO_MAP


@@ -38,7 +38,7 @@ def export_hf_checkpoint(tag: str, export_path: Path):
     model_hf = AutoConfig.from_pretrained(tag)
     model_hf = AutoModelForCausalLM.from_config(model_hf)

-    model_te = convert.convert_qwen3_hf_to_te(model_hf)
+    model_te = convert_qwen3.convert_qwen3_hf_to_te(model_hf)
     model_te.save_pretrained(export_path)

     tokenizer = AutoTokenizer.from_pretrained(tag)
