# -*- coding: utf-8 -*-
"""Qwen3 Model Merging Experiments.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1PcBodLy5BmW58v24GaeMAdC0lyLe3hvo
# Qwen3-0.6B Model Merging
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import math
from collections import OrderedDict
from tqdm import tqdm
import os
"""Note - I found that the Token/Vocab of both the below models and the base models are same, making it simpler to merge the model. I will be use saving/loading base model's tokenizer"""
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}...")
# Load both models
model_med = AutoModelForCausalLM.from_pretrained(
    "suayptalha/Qwen3-0.6B-Medical-Expert",
    torch_dtype=torch.float16,
    trust_remote_code=True
)
model_code = AutoModelForCausalLM.from_pretrained(
    "suayptalha/Qwen3-0.6B-Code-Expert",
    torch_dtype=torch.float16,
    trust_remote_code=True
)
model_med.to(device)
model_code.to(device)
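# Pre-merge check (an illustrative sketch, not in the original script): element-wise merging
# assumes both experts expose identical state_dict keys and tensor shapes, so verify that here.
_sd_med, _sd_code = model_med.state_dict(), model_code.state_dict()
assert _sd_med.keys() == _sd_code.keys(), "State dict keys differ between the two experts"
assert all(_sd_med[k].shape == _sd_code[k].shape for k in _sd_med), "Tensor shapes differ between the two experts"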
# Helper functions to clean up and save merged models
import gc

def cleanup_model(model, verbose=True):
    # Clear memory after merging
    device = next(model.parameters()).device
    if verbose:
        print(f"Cleaning up model from {device}...")
    # Delete model
    del model
    # Run garbage collection
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        if verbose:
            print("GPU memory freed.")
    else:
        if verbose:
            print("CPU memory cleaned up.")
def save_merged_qwen_model(merged_state_dict, save_dir, device="cuda"):
    print("Loading base Qwen3-0.6B model...")
    base_model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen3-0.6B",
        torch_dtype=torch.float16
    ).to(device)
    print("Loading merged weights into base model...")
    base_model.load_state_dict(merged_state_dict)
    print(f"Saving merged model to: {save_dir}...")
    os.makedirs(save_dir, exist_ok=True)
    base_model.save_pretrained(save_dir)
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B", trust_remote_code=True)
    tokenizer.save_pretrained(save_dir)
    # Clear memory
    cleanup_model(base_model)
    print(f"Merged model saved at: {save_dir}")
"""## Simple Linear Merge"""
# Perform model merging
alpha = 0.5  # Give equal importance to the weights of both models
med_sd = model_med.state_dict()
code_sd = model_code.state_dict()
merged_state_dict = {}
print("Merging weights...")
# Iterate over the state dicts and merge weights
for key in tqdm(med_sd.keys(), desc="Merging model weights : Simple Linear"):
    merged_state_dict[key] = alpha * med_sd[key] + (1 - alpha) * code_sd[key]
# Call helper function to save the model
save_merged_qwen_model(merged_state_dict=merged_state_dict,
                       save_dir="merged-qwen3-0.6B-medcode-simple-merge",
                       device=device)
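# Quick smoke test (a minimal sketch, not part of the original flow): reload the saved
# simple-merge checkpoint and generate one answer to confirm the merged weights load and run.
# _smoke_test_merged is a hypothetical helper; its save_dir mirrors the call above.
def _smoke_test_merged(save_dir="merged-qwen3-0.6B-medcode-simple-merge",
                       prompt="Name a common symptom of diabetes."):
    tok = AutoTokenizer.from_pretrained(save_dir, trust_remote_code=True)
    mdl = AutoModelForCausalLM.from_pretrained(save_dir, torch_dtype=torch.float16).to(device)
    inputs = tok(prompt, return_tensors="pt").to(device)
    out = mdl.generate(**inputs, max_new_tokens=64)
    print(tok.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True))
    cleanup_model(mdl)

# _smoke_test_merged()  # uncomment to try the merged checkpoint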
"""# Layer Wise Merge - Using different alpha values
Intuition: Why layer-wise merging works
In transformer-based LLMs:
* Lower layers (early layers) mostly capture generic, low-level
features like token embeddings, syntactic patterns, etc.
* Middle layers start mixing in more task-relevant abstractions.
* Higher layers (later layers) encode more domain-specific or task-specific knowledge (e.g., answering a medical query vs. writing code).
So when merging a general-domain expert (e.g., code) with a domain-specific expert (e.g., medical):
* You don’t want to overwrite early layers too aggressively — since these are more generalizable across domains.
* You do want to prioritize the domain expert's higher layers, since those likely contain more task-specific knowledge.
Why sigmoid instead of linear?
Let’s compare:
* Linear Merge
A gradual, constant transition from alpha_min to alpha_max.
* Good, but too uniform — assumes equal importance per layer, which isn't always the case.
Sigmoid Merge
* Transitions slowly at first, then rapidly in the middle, and flattens out again at the end.
* Matches how information specialization increases non-linearly across layers.
* Early layers: Keep mostly from model_code (low alpha)
* Middle layers: Smooth shift in dominance
* Later layers: Strongly favor model_med (high alpha)
This avoids sharp transitions or uniform interpolation and instead creates a smooth knowledge handover from code to medical expertise.
Benefits of sigmoid curve:
* Mimics how knowledge hierarchy is built in transformers
* Helps avoid catastrophic interference from one model overwriting another
* Encourages better generalization in merged models
* Often empirically gives better performance on multi-domain or domain-adapted tasks
"""
# Recall what a sigmoid curve looks like; the same curve is followed here.
def merge_models_layerwise_smooth(model_med, model_code, curve_type="sigmoid",
                                  alpha_min=0.2, alpha_max=0.9):
    """
    Merges model_med and model_code using a smooth, layer-wise alpha interpolation.

    Params:
    - model_med: torch.nn.Module with medical weights
    - model_code: torch.nn.Module with code weights
    - curve_type: 'linear' or 'sigmoid'
    - alpha_min: alpha for early layers (favor code)
    - alpha_max: alpha for later layers (favor medical)

    Returns:
    - merged_state_dict: new merged OrderedDict
    """
    med_sd = model_med.state_dict()
    code_sd = model_code.state_dict()
    merged_sd = OrderedDict()

    # Detect number of transformer blocks
    layer_pattern = re.compile(r'model\.layers\.(\d+)\.')
    layer_ids = sorted({int(layer_pattern.search(k).group(1)) for k in med_sd if layer_pattern.search(k)})
    num_layers = max(layer_ids) + 1

    def get_alpha_for_layer(layer_id):
        if curve_type == "linear":
            return alpha_min + (alpha_max - alpha_min) * (layer_id / (num_layers - 1))
        elif curve_type == "sigmoid":
            x = layer_id / (num_layers - 1) * 12 - 6  # scale to [-6, 6]
            sigmoid = 1 / (1 + math.exp(-x))
            return alpha_min + sigmoid * (alpha_max - alpha_min)
        else:
            raise ValueError("Unsupported curve_type. Choose 'linear' or 'sigmoid'.")

    for key in tqdm(med_sd.keys(), desc="Merging model weights : Sigmoid"):
        match = layer_pattern.search(key)
        if match:
            layer_id = int(match.group(1))
            alpha = get_alpha_for_layer(layer_id)
        elif "embed_tokens" in key:
            alpha = alpha_min
        elif "lm_head" in key or "norm" in key:
            alpha = alpha_max
        else:
            alpha = 0.5  # Fallback value
        merged_sd[key] = alpha * med_sd[key] + (1 - alpha) * code_sd[key]
    return merged_sd
merged_sd = merge_models_layerwise_smooth(
    model_med, model_code,
    curve_type="sigmoid",
    alpha_min=0.2,
    alpha_max=0.9
)
# Call helper function to save the model
save_merged_qwen_model(merged_state_dict=merged_sd,
                       save_dir="merged-qwen3-0.6B-medcode-weighted-merge",
                       device=device)
"""# Greedy Merging backed by evaluation"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from difflib import SequenceMatcher
from copy import deepcopy
from tqdm import tqdm
MODEL_A_ID = "suayptalha/Qwen3-0.6B-Code-Expert"
MODEL_B_ID = "suayptalha/Qwen3-0.6B-Medical-Expert"
BASE_MODEL_ID = "Qwen/Qwen3-0.6B"
MERGED_OUTPUT_PATH = "greedy-merged-with-base-qwen3-0.6B"
MAX_NEW_TOKENS = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {DEVICE}...")
# LOAD MODELS & TOKENIZER
print("Loading models and tokenizer...")
model_a = AutoModelForCausalLM.from_pretrained(MODEL_A_ID, torch_dtype=torch.float16).to(DEVICE)
model_b = AutoModelForCausalLM.from_pretrained(MODEL_B_ID, torch_dtype=torch.float16).to(DEVICE)
model_base = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float16).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
# COPY MODEL FOR MERGING
merged_model = deepcopy(model_a)
n_layers = len(model_a.model.layers)
# Eval questions (5 medical + 5 code)
qa_pairs = [
    ("What is the normal range of hemoglobin in adult males?", "13.8 to 17.2 grams per deciliter."),
    ("Name a common symptom of diabetes.", "Frequent urination."),
    ("What is the function of platelets in blood?", "They help in blood clotting."),
    ("What does BMI stand for in medical terms?", "Body Mass Index."),
    ("Which virus causes chickenpox?", "Varicella-zoster virus."),
    ("How do you reverse a linked list in Python?", "Use a loop or recursion to reverse the pointers."),
    ("What is the difference between a list and a tuple in Python?", "Lists are mutable, tuples are immutable."),
    ("What is the purpose of the 'self' keyword in Python classes?", "'self' refers to the instance of the class."),
    ("How does a for loop differ from a while loop?", "A for loop iterates over a sequence; a while loop runs based on a condition."),
    ("What does 'None' represent in Python?", "It represents the absence of a value or a null value.")
]
def generate_answer(model, prompt):
    messages = [
        {
            "role": "system",
            "content": (
                "You are a helpful and concise assistant. Answer the user's query clearly, accurately, "
                "and as briefly as possible while ensuring the response remains complete and informative. "
                "Avoid unnecessary elaboration or repetition."
            )
        },
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=MAX_NEW_TOKENS
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    try:
        # rindex of token id 151668 (</think>) so any thinking content is stripped from the answer
        index = len(output_ids) - output_ids[::-1].index(151668)
    except ValueError:
        index = 0
    content = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
    return content
def score_answer(generated, reference):
    # Fuzzy string similarity (difflib ratio) between generated and reference answers
    return SequenceMatcher(None, generated.strip().lower(), reference.strip().lower()).ratio()

def evaluate_model(model, qa_pairs):
    model.eval()
    total_score = 0
    with torch.no_grad():
        for q, ref in qa_pairs:
            generated = generate_answer(model, q)
            score = score_answer(generated, ref)
            total_score += score
    return total_score / len(qa_pairs)
# Greedy layer selection
# To skip picking layers from the base model, drop 'Base' and model_base from the candidate lists in the inner loop
print(f"\nStarting Greedy Merge with base model across {n_layers} layers...\n")
for i in tqdm(range(n_layers)):
    best_score = -1
    best_model_name = None
    best_layer = None
    for model_name, candidate_model in zip(['A', 'B', 'Base'], [model_a, model_b, model_base]):
        merged_model.model.layers[i] = deepcopy(candidate_model.model.layers[i])
        score = evaluate_model(merged_model, qa_pairs)
        if score > best_score:
            best_score = score
            best_model_name = model_name
            best_layer = deepcopy(candidate_model.model.layers[i])
    merged_model.model.layers[i] = best_layer
    print(f"Layer {i:02d}: Selected from model {best_model_name} (score: {best_score:.3f})")
# SAVE
print("Saving merged model...")
merged_model.save_pretrained(MERGED_OUTPUT_PATH)
tokenizer.save_pretrained(MERGED_OUTPUT_PATH)
print(f"Merged model saved to: {MERGED_OUTPUT_PATH}")