@@ -74,16 +74,18 @@ def _split_qkv_bias(ctx: state.TransformCTX, qkv_bias: torch.Tensor):
     qkv_bias = qkv_bias.reshape(qkv_total_dim, head_size)
     q_slice = torch.cat(
         [
-            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
+            torch.arange(
+                (heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group, device=qkv_bias.device
+            )
             for i in range(num_query_groups)
         ]
     )
-    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
-    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+    k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2), device=qkv_bias.device)
+    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2), device=qkv_bias.device)

-    q_bias = qkv_bias[q_slice].reshape(-1).cpu()
-    k_bias = qkv_bias[k_slice].reshape(-1).cpu()
-    v_bias = qkv_bias[v_slice].reshape(-1).cpu()
+    q_bias = qkv_bias[q_slice].reshape(-1)
+    k_bias = qkv_bias[k_slice].reshape(-1)
+    v_bias = qkv_bias[v_slice].reshape(-1)

     return q_bias, k_bias, v_bias
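The hunk above builds the q/k/v index tensors with device=qkv_bias.device, so the advanced indexing never crosses devices, and it stops forcing the split biases back to CPU. A minimal sketch of the same interleaved-GQA slicing, using hypothetical sizes (2 query groups, 4 query heads per group, head size 8) rather than values from this repo:

    import torch

    # Hypothetical GQA layout: each group packs [q x heads_per_group, k, v] rows.
    num_query_groups, heads_per_group, head_size = 2, 4, 8
    qkv_total_dim = num_query_groups * (heads_per_group + 2)
    qkv_bias = torch.randn(qkv_total_dim * head_size).reshape(qkv_total_dim, head_size)

    # Index tensors live on the same device as the bias, so indexing stays device-local.
    q_slice = torch.cat(
        [
            torch.arange(
                (heads_per_group + 2) * i,
                (heads_per_group + 2) * i + heads_per_group,
                device=qkv_bias.device,
            )
            for i in range(num_query_groups)
        ]
    )
    k_slice = torch.arange(heads_per_group, qkv_total_dim, heads_per_group + 2, device=qkv_bias.device)
    v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, heads_per_group + 2, device=qkv_bias.device)

    q_bias = qkv_bias[q_slice].reshape(-1)  # (num_query_groups * heads_per_group * head_size,) = (64,)
    k_bias = qkv_bias[k_slice].reshape(-1)  # (num_query_groups * head_size,) = (16,)
    v_bias = qkv_bias[v_slice].reshape(-1)  # (16,)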
@@ -206,6 +208,11 @@ def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwen2ForCausalLM:
     with torch.device("meta"):
         model_hf = Qwen2ForCausalLM(hf_config)

+    if model_hf.config.tie_word_embeddings:
+        state_dict_ignored_entries = model_hf._tied_weights_keys
+    else:
+        state_dict_ignored_entries = []
+
     output_model = state.apply_transforms(
         model_te,
         model_hf,
@@ -241,10 +248,12 @@ def convert_qwen2_te_to_hf(model_te: NVQwen2ForCausalLM, **config_kwargs) -> Qwen2ForCausalLM:
                 fn=state.TransformFns.split_fc1,
             ),
         ],
-        state_dict_ignored_entries=model_hf._tied_weights_keys,
+        state_dict_ignored_entries=state_dict_ignored_entries,
     )

     output_model.model.rotary_emb.inv_freq = model_te.model.rotary_emb.inv_freq.clone()
-    output_model.tie_weights()
+
+    if model_hf.config.tie_word_embeddings:
+        output_model.tie_weights()

     return output_model
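The remaining hunks make weight tying conditional: the tied key is only dropped from the state-dict copy (via _tied_weights_keys) and re-tied afterwards when the config actually requests tying, so untied checkpoints keep their separately trained lm_head. A minimal sketch of the guard, assuming transformers' Qwen2Config/Qwen2ForCausalLM and a hypothetical weight-copy step:

    import torch
    from transformers import Qwen2Config, Qwen2ForCausalLM

    config = Qwen2Config(tie_word_embeddings=True)  # hypothetical toy config
    with torch.device("meta"):
        model_hf = Qwen2ForCausalLM(config)

    # Skip the tied key (e.g. "lm_head.weight") only when tying is enabled;
    # an untied model must have lm_head copied like any other parameter.
    if model_hf.config.tie_word_embeddings:
        state_dict_ignored_entries = model_hf._tied_weights_keys
    else:
        state_dict_ignored_entries = []

    # ... copy all weights except state_dict_ignored_entries ...

    # Re-tie only on request; calling tie_weights() unconditionally would
    # overwrite a distinct lm_head with the input embedding matrix.
    if model_hf.config.tie_word_embeddings:
        model_hf.tie_weights()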