Commit c91136b
Parent: b597c6b

addressing coderabbit review

Signed-off-by: Peter St. John <pstjohn@nvidia.com>

8 files changed: 41 additions & 27 deletions

bionemo-recipes/models/qwen/convert_qwen2.py

Lines changed: 17 additions & 8 deletions
@@ -74,16 +74,18 @@ def _split_qkv_bias(ctx: state.TransformCTX, qkv_bias: torch.Tensor):
     qkv_bias = qkv_bias.reshape(qkv_total_dim, head_size)
     q_slice = torch.cat(
         [
-            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
+            torch.arange(
+                (heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group, device=qkv_bias.device
+            )
             for i in range(num_query_groups)
         ]
     )
-    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
-    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2), device=qkv_bias.device)
+    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2), device=qkv_bias.device)
 
-    q_bias = qkv_bias[q_slice].reshape(-1).cpu()
-    k_bias = qkv_bias[k_slice].reshape(-1).cpu()
-    v_bias = qkv_bias[v_slice].reshape(-1).cpu()
+    q_bias = qkv_bias[q_slice].reshape(-1)
+    k_bias = qkv_bias[k_slice].reshape(-1)
+    v_bias = qkv_bias[v_slice].reshape(-1)
 
     return q_bias, k_bias, v_bias
 
@@ -206,6 +208,11 @@ def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwe
     with torch.device("meta"):
         model_hf = Qwen2ForCausalLM(hf_config)
 
+    if model_hf.config.tie_word_embeddings:
+        state_dict_ignored_entries = model_hf._tied_weights_keys
+    else:
+        state_dict_ignored_entries = []
+
     output_model = state.apply_transforms(
         model_te,
         model_hf,
@@ -241,10 +248,12 @@ def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwe
                 fn=state.TransformFns.split_fc1,
             ),
         ],
-        state_dict_ignored_entries=model_hf._tied_weights_keys,
+        state_dict_ignored_entries=state_dict_ignored_entries,
     )
 
     output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
-    output_model.tie_weights()
+
+    if model_hf.config.tie_word_embeddings:
+        output_model.tie_weights()
 
     return output_model
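The first hunk builds the slice indices on the same device as qkv_bias and drops the .cpu() copies, so the split biases stay wherever the source tensor lives. A minimal standalone sketch of the pattern, with hypothetical GQA dimensions:

    import torch

    # Hypothetical GQA layout: 2 query groups, 4 query heads per group,
    # plus one K row and one V row per group -> (4 + 2) rows per group.
    heads_per_group, num_query_groups, head_size = 4, 2, 8
    qkv_total_dim = (heads_per_group + 2) * num_query_groups
    device = "cuda" if torch.cuda.is_available() else "cpu"
    qkv_bias = torch.randn(qkv_total_dim, head_size, device=device)

    # Creating the index tensor on qkv_bias.device keeps the index and the
    # data on the same device, avoiding implicit host-to-device transfers
    # (or errors, depending on the PyTorch version) during indexing.
    k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2, device=qkv_bias.device)
    k_bias = qkv_bias[k_slice].reshape(-1)  # no .cpu(): result stays on the source device
    assert k_bias.device == qkv_bias.device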

bionemo-recipes/models/qwen/export.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ def export_hf_checkpoint(tag: str, export_path: Path):
     with open(export_path / "config.json", "w") as f:
         json.dump(config, f, indent=2, sort_keys=True)
 
-    shutil.copy("modeling_qwen3_te.py", export_path / "modeling_qwen3_te.py")
+    shutil.copy(Path(__file__).parent / "modeling_qwen3_te.py", export_path / "modeling_qwen3_te.py")
 
 
 if __name__ == "__main__":
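Anchoring the source path to Path(__file__).parent makes the copy independent of the caller's working directory. A hedged sketch of the difference; the temporary destination directory is an assumption for illustration:

    import shutil
    import tempfile
    from pathlib import Path

    # Hypothetical destination; the point is only how the source path is built.
    export_path = Path(tempfile.mkdtemp())

    # A bare relative path like shutil.copy("modeling_qwen3_te.py", ...) is
    # resolved against the current working directory, so it breaks when the
    # script is launched from anywhere but the module's own directory.
    src = Path(__file__).parent / "modeling_qwen3_te.py"  # anchored to this module
    if src.exists():  # hedge: the file only exists inside the qwen model directory
        shutil.copy(src, export_path / src.name)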

bionemo-recipes/models/qwen/modeling_qwen2_te.py

Lines changed: 1 addition & 1 deletion
@@ -212,7 +212,7 @@ def forward(
         # attention backend, but it should be faster for the flash attention backend.
         assert attention_mask is not None, "Attention mask is required when packing BSHD inputs."
         batch_size = hidden_states.size(0)
-        padded_seq_len = input_ids.size(1)
+        padded_seq_len = input_ids.size(1) if input_ids is not None else hidden_states.size(1)
         hidden_states, indices, cu_seqlens, max_seqlen, _ = _unpad_input(hidden_states, attention_mask)
         kwargs["cu_seq_lens_q"] = kwargs["cu_seq_lens_k"] = cu_seqlens
         kwargs["max_length_q"] = kwargs["max_length_k"] = max_seqlen
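The fallback matters when a caller passes inputs_embeds rather than input_ids; the padded sequence length must then come from the hidden states. A minimal sketch with hypothetical shapes:

    import torch

    # Hypothetical shapes for a call that supplies inputs_embeds, not input_ids.
    batch_size, seq_len, hidden_size = 2, 16, 64
    input_ids = None
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)

    # Without the guard, input_ids.size(1) raises AttributeError on None;
    # the hidden states carry the same padded sequence length in dim 1.
    padded_seq_len = input_ids.size(1) if input_ids is not None else hidden_states.size(1)
    assert padded_seq_len == seq_len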

bionemo-recipes/models/qwen/modeling_qwen3_te.py

Lines changed: 4 additions & 1 deletion
@@ -139,7 +139,10 @@ def _init_method(x):
             qk_norm_eps=config.rms_norm_eps,
             qk_norm_before_rope=True,
             window_size=(config.sliding_window, config.sliding_window)
-            if config.layer_types[layer_idx] == "sliding_attention"
+            if config.layer_types is not None
+            and len(config.layer_types) > layer_idx
+            and config.layer_types[layer_idx] == "sliding_attention"
+            and config.sliding_window is not None
             else None,
             layer_number=layer_idx + 1,
             params_dtype=config.dtype,
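The added checks make the sliding-window decision safe for configs where layer_types or sliding_window is missing or too short. A standalone sketch of the same guard, using a hypothetical minimal config object:

    from types import SimpleNamespace

    def window_for_layer(config, layer_idx):
        # Mirror the guarded expression: enable a sliding window only when the
        # layer type list exists, covers this layer, marks it as sliding, and
        # a window size is actually configured.
        use_window = (
            config.layer_types is not None
            and len(config.layer_types) > layer_idx
            and config.layer_types[layer_idx] == "sliding_attention"
            and config.sliding_window is not None
        )
        return (config.sliding_window, config.sliding_window) if use_window else None

    cfg = SimpleNamespace(layer_types=["full_attention", "sliding_attention"], sliding_window=128)
    assert window_for_layer(cfg, 0) is None
    assert window_for_layer(cfg, 1) == (128, 128)
    assert window_for_layer(SimpleNamespace(layer_types=None, sliding_window=128), 0) is None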

bionemo-recipes/models/qwen/tests/common/test_modeling_common.py

Lines changed: 7 additions & 9 deletions
@@ -20,7 +20,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Dict, List, Literal, Type
+from typing import Any, Callable, Dict, List, Literal, Type
 
 import pytest
 import torch
@@ -987,8 +987,8 @@ def test_meta_fp8_init(self, fp8_recipe):
         self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
 
     # ==================== Generation Tests (Autoregressive Models Only) ====================
-
-    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+    @abstractmethod
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any:
         """Create inference params for KV-cache generation tests.
 
         Autoregressive model tests must override this method to provide
@@ -1003,9 +1003,7 @@ def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_be
         Returns:
             HFInferenceParams instance with allocated memory.
         """
-        raise NotImplementedError(
-            "Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
-        )
+        pass
 
     def test_generate_without_cache(self):
         """Test basic generation without KV-cache (BSHD, use_cache=False)."""
@@ -1040,7 +1038,7 @@ def test_generate_with_cache(self):
         inputs = tokenizer(prompt, return_tensors="pt")
         inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
-        past_key_values = self._create_inference_params(config, batch_size=1)
+        past_key_values = self.create_inference_params(config, batch_size=1)
 
         with torch.no_grad():
             output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
@@ -1064,7 +1062,7 @@ def test_generate_with_cache_batched(self):
         inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
         inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
-        past_key_values = self._create_inference_params(config, batch_size=2)
+        past_key_values = self.create_inference_params(config, batch_size=2)
 
         with torch.no_grad():
             output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
@@ -1090,7 +1088,7 @@ def test_generate_with_cache_beam_search(self):
         inputs = {k: v.to("cuda") for k, v in inputs.items()}
 
         num_beams = 2
-        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+        past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams)
 
         with torch.no_grad():
             output_ids = model.generate(
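Replacing the NotImplementedError stub with @abstractmethod moves the failure from call time to instantiation time: a concrete test class that forgets the override cannot even be constructed. A sketch of the pattern, with hypothetical class names:

    from abc import ABC, abstractmethod
    from typing import Any

    class GenerationTestBase(ABC):  # hypothetical stand-in for the shared test base
        @abstractmethod
        def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any:
            """Subclasses supply model-specific inference params."""

    class Qwen2LikeTests(GenerationTestBase):  # hypothetical subclass
        def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
            return {"batch_size": batch_size, "max_seq_len": max_seq_len}  # placeholder params

    Qwen2LikeTests()  # OK: the abstract method is implemented
    try:
        GenerationTestBase()  # fails at construction, not at first call
    except TypeError as err:
        print(err)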

bionemo-recipes/models/qwen/tests/test_modeling_qwen2_te.py

Lines changed: 4 additions & 2 deletions
@@ -67,7 +67,9 @@ def get_upstream_model_revision(self) -> str:
 
     def get_tokenizer(self) -> PreTrainedTokenizer:
         """Return the Qwen2 tokenizer."""
-        tokenizer = AutoTokenizer.from_pretrained(self.get_upstream_model_id())
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.get_upstream_model_id(), revision=self.get_upstream_model_revision()
+        )
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
         return tokenizer
@@ -152,7 +154,7 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma
 
     # ==================== Qwen2-Specific Overrides ====================
 
-    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
         """Create HFInferenceParams for the given config.
 
         Uses hidden_size // num_attention_heads for head_dim since Qwen2 does not
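Passing revision= pins the tokenizer to a known commit of the upstream repo instead of silently tracking main. A hedged sketch; the model id and revision below are placeholders for the values returned by the test's get_upstream_model_id() and get_upstream_model_revision() overrides:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen2.5-0.5B",  # assumption: any HF repo id works here
        revision="main",      # pin to a tag or commit SHA for reproducibility
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token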

bionemo-recipes/models/qwen/tests/test_modeling_qwen3_te.py

Lines changed: 4 additions & 2 deletions
@@ -67,7 +67,9 @@ def get_upstream_model_revision(self) -> str:
 
     def get_tokenizer(self) -> PreTrainedTokenizer:
         """Return the Qwen3 tokenizer."""
-        tokenizer = AutoTokenizer.from_pretrained(self.get_upstream_model_id())
+        tokenizer = AutoTokenizer.from_pretrained(
+            self.get_upstream_model_id(), revision=self.get_upstream_model_revision()
+        )
         if tokenizer.pad_token is None:
             tokenizer.pad_token = tokenizer.eos_token
         return tokenizer
@@ -153,7 +155,7 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_forma
 
     # ==================== Qwen3-Specific Overrides ====================
 
-    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
         """Create HFInferenceParams for the given config.
 
         Uses config.head_dim (not hidden_size // num_attention_heads) since Qwen3
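The Qwen3 override differs from the Qwen2 one only in where head_dim comes from, as the docstrings note: Qwen3 configs carry an explicit head_dim that need not equal hidden_size // num_attention_heads. A sketch of the distinction with hypothetical config values:

    from types import SimpleNamespace

    # Hypothetical configs illustrating the docstrings' point.
    qwen2_like = SimpleNamespace(hidden_size=1024, num_attention_heads=16)
    qwen3_like = SimpleNamespace(hidden_size=1024, num_attention_heads=16, head_dim=128)

    # Qwen2-style: head_dim must be derived from the hidden size.
    derived = qwen2_like.hidden_size // qwen2_like.num_attention_heads  # 64

    # Qwen3-style: head_dim is explicit and may differ from the derived value.
    explicit = qwen3_like.head_dim  # 128, not 1024 // 16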

ci/scripts/check_copied_files.py

Lines changed: 3 additions & 3 deletions
@@ -38,7 +38,7 @@
     "bionemo-recipes/models/esm2/collator.py": [
         "bionemo-recipes/models/llama3/collator.py",
         "bionemo-recipes/models/mixtral/collator.py",
-        "bionemo-recipes/models/qwen3/collator.py",
+        "bionemo-recipes/models/qwen/collator.py",
         "bionemo-recipes/recipes/esm2_native_te/collator.py",
         "bionemo-recipes/recipes/llama3_native_te/collator.py",
         "bionemo-recipes/recipes/esm2_peft_te/collator.py",
@@ -47,7 +47,7 @@
         "bionemo-recipes/models/amplify/src/amplify/state.py",
         "bionemo-recipes/models/llama3/state.py",
         "bionemo-recipes/models/mixtral/state.py",
-        "bionemo-recipes/models/qwen3/state.py",
+        "bionemo-recipes/models/qwen/state.py",
     ],
     "bionemo-recipes/models/llama3/modeling_llama_te.py": [
         "bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py",
@@ -62,7 +62,7 @@
     "bionemo-recipes/models/esm2/tests/common": [
         "bionemo-recipes/models/llama3/tests/common",
         "bionemo-recipes/models/mixtral/tests/common",
-        "bionemo-recipes/models/qwen3/tests/common",
+        "bionemo-recipes/models/qwen/tests/common",
     ],
 }
 
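The rename updates the copy map from the old qwen3/ directory to the consolidated qwen/ directory. For context, a hedged sketch of how a checker of this shape might verify the map; the comparison logic below is an assumption for illustration, not the actual script:

    import filecmp
    import sys

    # Hypothetical subset of the mapping from the diff: source file -> copies.
    COPIED_FILES = {
        "bionemo-recipes/models/esm2/collator.py": [
            "bionemo-recipes/models/qwen/collator.py",
        ],
    }

    def main() -> int:
        failures = []
        for source, copies in COPIED_FILES.items():
            for copy in copies:
                # shallow=False compares file contents, not just os.stat() signatures.
                if not filecmp.cmp(source, copy, shallow=False):
                    failures.append(f"{copy} is out of sync with {source}")
        for failure in failures:
            print(failure, file=sys.stderr)
        return 1 if failures else 0

    if __name__ == "__main__":
        sys.exit(main())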
