Skip to content

Commit 70d72f8

Browse files
committed
add qwen3 model
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent c589cbb commit 70d72f8

16 files changed

Lines changed: 4124 additions & 0 deletions
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
extend = "../.ruff.toml"

bionemo-recipes/models/qwen3/collator.py

Lines changed: 1036 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Conversion utilities between HuggingFace Qwen3 and TransformerEngine formats."""
17+
18+
import inspect
19+
20+
import torch
21+
from transformers import Qwen3Config, Qwen3ForCausalLM
22+
23+
import state
24+
from modeling_qwen3_te import NVQwen3Config, NVQwen3ForCausalLM
25+
26+
27+
# Direct (rename-only) weight mapping from HF Qwen3 state-dict keys to the TE
# model's keys. "*" is a layer-index wildcard expanded by state.apply_transforms.
# Weights that need reshaping/fusing (qkv, fc1) are handled by state_transform
# entries instead and are deliberately absent here.
mapping = {
    "model.embed_tokens.weight": "model.embed_tokens.weight",
    # Input layernorm is folded into TE's fused layernorm_qkv module.
    "model.layers.*.input_layernorm.weight": "model.layers.*.self_attention.layernorm_qkv.layer_norm_weight",
    "model.layers.*.self_attn.o_proj.weight": "model.layers.*.self_attention.proj.weight",
    "model.layers.*.self_attn.q_norm.weight": "model.layers.*.self_attention.q_norm.weight",
    "model.layers.*.self_attn.k_norm.weight": "model.layers.*.self_attention.k_norm.weight",
    # Post-attention layernorm is folded into TE's fused layernorm_mlp module.
    "model.layers.*.post_attention_layernorm.weight": "model.layers.*.layernorm_mlp.layer_norm_weight",
    "model.layers.*.mlp.down_proj.weight": "model.layers.*.layernorm_mlp.fc2_weight",
    "model.norm.weight": "model.norm.weight",
    "lm_head.weight": "lm_head.weight",
}

# Reverse mapping from TE to HF format by reversing the original mapping.
# Safe because the forward mapping is one-to-one.
reverse_mapping = {v: k for k, v in mapping.items()}
41+
42+
43+
def _merge_qkv(ctx: state.TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    """Merge separate q, k, v projection weights into one interleave-concatenated qkv tensor.

    For each KV group, TE's fused layernorm_qkv expects that group's query heads
    followed by its single key head and single value head. Uses config.head_dim
    instead of hidden_size // num_attention_heads, which is necessary for Qwen3
    where head_dim is independently configured.
    """
    cfg = ctx.target.config

    n_heads = cfg.num_attention_heads
    n_groups = cfg.num_key_value_heads
    q_per_group = n_heads // n_groups
    hidden_size = cfg.hidden_size
    dim = cfg.head_dim

    # Split the leading (heads * head_dim) axis into explicit (heads, head_dim).
    trailing = q.size()[1:]
    q = q.view(n_groups * q_per_group, dim, *trailing)
    k = k.view(n_groups, dim, *trailing)
    v = v.view(n_groups, dim, *trailing)

    # Interleave per KV group: [q_heads of group g, k_head g, v_head g] for each g.
    pieces = []
    for g in range(n_groups):
        pieces.append(q[g * q_per_group : (g + 1) * q_per_group, :, :])
        pieces.append(k[g : g + 1, :, :])
        pieces.append(v[g : g + 1, :, :])
    qkv = torch.cat(pieces)

    assert qkv.ndim == 3, qkv.shape
    assert qkv.shape[0] == (q_per_group + 2) * n_groups, qkv.shape
    assert qkv.shape[1] == dim, qkv.shape
    assert qkv.shape[2] == trailing[0], qkv.shape

    # Collapse back to the 2-D weight layout TE stores on disk.
    return qkv.reshape([dim * (n_heads + 2 * n_groups), hidden_size])
79+
80+
81+
def _split_qkv(ctx: state.TransformCTX, linear_qkv: torch.Tensor):
    """Split an interleave-concatenated qkv weight back into separate q, k, v tensors.

    Inverse of the merge performed for TE's fused layernorm_qkv. Uses
    config.head_dim instead of hidden_size // num_attention_heads, which is
    necessary for Qwen3 where head_dim is independently configured.
    """
    cfg = ctx.target.config

    n_heads = cfg.num_attention_heads
    n_groups = cfg.num_key_value_heads
    q_per_group = n_heads // n_groups
    dim = cfg.head_dim
    total_rows = n_heads + 2 * n_groups

    # View as (rows-of-heads, head_dim, hidden); each KV group occupies a
    # contiguous run of (q_per_group + 2) rows: queries, then k, then v.
    qkv = linear_qkv.reshape([total_rows, dim, -1])
    hidden_size = qkv.size(-1)

    stride = q_per_group + 2
    q_rows = torch.cat(
        [torch.arange(stride * g, stride * g + q_per_group) for g in range(n_groups)]
    )
    k_rows = torch.arange(q_per_group, total_rows, stride)
    v_rows = torch.arange(q_per_group + 1, total_rows, stride)

    q_proj = qkv[q_rows].reshape(-1, hidden_size).cpu()
    k_proj = qkv[k_rows].reshape(-1, hidden_size).cpu()
    v_proj = qkv[v_rows].reshape(-1, hidden_size).cpu()

    return q_proj, k_proj, v_proj
111+
112+
113+
def convert_qwen3_hf_to_te(model_hf: Qwen3ForCausalLM, **config_kwargs) -> NVQwen3ForCausalLM:
    """Convert a Hugging Face model to a Transformer Engine model.

    Args:
        model_hf (nn.Module): The Hugging Face model.
        **config_kwargs: Additional configuration kwargs to be passed to NVQwen3Config.
            Entries here override values taken from the source model's config.

    Returns:
        nn.Module: The Transformer Engine model.
    """
    # Merge the dicts before unpacking: passing a key in config_kwargs that also
    # exists in the HF config would otherwise raise a TypeError ("got multiple
    # values for keyword argument"), defeating the override use case.
    te_config = NVQwen3Config(**{**model_hf.config.to_dict(), **config_kwargs})
    # Instantiate on the meta device: weights are filled in by apply_transforms,
    # so there is no need to materialize (or randomly initialize) them here.
    with torch.device("meta"):
        model_te = NVQwen3ForCausalLM(te_config)

    # With tied word embeddings there is no independent lm_head weight to copy.
    if model_hf.config.tie_word_embeddings:
        state_dict_ignored_entries = ["lm_head.weight"]
    else:
        state_dict_ignored_entries = []

    output_model = state.apply_transforms(
        model_hf,
        model_te,
        mapping,
        [
            # HF stores q/k/v projections separately; TE fuses them into a
            # single interleaved layernorm_qkv weight.
            state.state_transform(
                source_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                target_key="model.layers.*.self_attention.layernorm_qkv.weight",
                fn=_merge_qkv,
            ),
            # HF stores gate/up projections separately; TE fuses them into fc1.
            state.state_transform(
                source_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                target_key="model.layers.*.layernorm_mlp.fc1_weight",
                fn=state.TransformFns.merge_fc1,
            ),
        ],
        state_dict_ignored_entries=state_dict_ignored_entries,
    )

    # Rotary inverse frequencies are non-persistent computed buffers, so they
    # are not covered by the state-dict transforms above; copy them explicitly.
    output_model.model.rotary_emb.inv_freq = model_hf.model.rotary_emb.inv_freq.clone()

    return output_model
161+
162+
163+
def convert_qwen3_te_to_hf(model_te: NVQwen3ForCausalLM, **config_kwargs) -> Qwen3ForCausalLM:
    """Convert a Transformer Engine model to a Hugging Face model.

    Args:
        model_te (nn.Module): The Transformer Engine model.
        **config_kwargs: Additional configuration kwargs to be passed to Qwen3Config.
            Entries here override values taken from the source model's config.

    Returns:
        nn.Module: The Hugging Face model.
    """
    # Filter out keys from model_te.config that are not valid Qwen3Config
    # attributes, so TE-specific options don't leak into the HF config.
    te_config_dict = model_te.config.to_dict()
    valid_keys = set(inspect.signature(Qwen3Config.__init__).parameters)
    filtered_config = {k: v for k, v in te_config_dict.items() if k in valid_keys}
    # Merge the dicts before unpacking so config_kwargs may override filtered
    # values without raising a duplicate-keyword TypeError (mirrors
    # convert_qwen3_hf_to_te).
    hf_config = Qwen3Config(**{**filtered_config, **config_kwargs})

    # Instantiate on the meta device: weights are filled in by apply_transforms.
    with torch.device("meta"):
        model_hf = Qwen3ForCausalLM(hf_config)

    output_model = state.apply_transforms(
        model_te,
        model_hf,
        reverse_mapping,
        [
            # Un-fuse TE's interleaved qkv weight back into separate q/k/v.
            state.state_transform(
                source_key="model.layers.*.self_attention.layernorm_qkv.weight",
                target_key=(
                    "model.layers.*.self_attn.q_proj.weight",
                    "model.layers.*.self_attn.k_proj.weight",
                    "model.layers.*.self_attn.v_proj.weight",
                ),
                fn=_split_qkv,
            ),
            # Un-fuse TE's fc1 weight back into separate gate/up projections.
            state.state_transform(
                source_key="model.layers.*.layernorm_mlp.fc1_weight",
                target_key=(
                    "model.layers.*.mlp.gate_proj.weight",
                    "model.layers.*.mlp.up_proj.weight",
                ),
                fn=state.TransformFns.split_fc1,
            ),
        ],
        # Tied weights (e.g. lm_head when tie_word_embeddings is set) are
        # re-created by tie_weights() below rather than copied.
        state_dict_ignored_entries=model_hf._tied_weights_keys,
    )

    # Rotary inverse frequencies are computed buffers not covered by the
    # state-dict transforms; copy them explicitly.
    output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
    output_model.tie_weights()

    return output_model
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: LicenseRef-Apache2
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Create a Qwen3 checkpoint for export.
17+
18+
This script saves a randomly initialized Qwen3 model with TransformerEngine layers.
19+
"""
20+
21+
import json
22+
import shutil
23+
from pathlib import Path
24+
25+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
26+
27+
import convert
28+
from modeling_qwen3_te import AUTO_MAP
29+
30+
31+
def export_hf_checkpoint(tag: str, export_path: Path):
    """Export a Hugging Face checkpoint to a Transformer Engine checkpoint.

    Args:
        tag: The tag of the checkpoint to export.
        export_path: The parent path to export the checkpoint to.
    """
    # Build a randomly initialized model from the config only; from_config does
    # not download pretrained weights. (Fixes the original's misleading reuse of
    # the name `model_hf` for the AutoConfig result.)
    config_hf = AutoConfig.from_pretrained(tag)
    model_hf = AutoModelForCausalLM.from_config(config_hf)

    model_te = convert.convert_qwen3_hf_to_te(model_hf)
    model_te.save_pretrained(export_path)

    tokenizer = AutoTokenizer.from_pretrained(tag)
    tokenizer.save_pretrained(export_path)

    # Patch the saved config so AutoModel can resolve the custom TE classes
    # shipped alongside the checkpoint.
    config_path = export_path / "config.json"
    with open(config_path, "r") as f:
        config = json.load(f)

    config["auto_map"] = AUTO_MAP

    with open(config_path, "w") as f:
        json.dump(config, f, indent=2, sort_keys=True)

    # NOTE(review): resolved relative to the current working directory —
    # assumes the script is run from the directory containing this file.
    shutil.copy("modeling_qwen3_te.py", export_path / "modeling_qwen3_te.py")
57+
58+
59+
if __name__ == "__main__":
    # Export the smallest Qwen3 variant into ./checkpoint_export by default.
    export_hf_checkpoint("Qwen/Qwen3-0.6B", Path("checkpoint_export"))

0 commit comments

Comments
 (0)