-
Notifications
You must be signed in to change notification settings - Fork 495
Expand file tree
/
Copy pathlora.py
More file actions
246 lines (188 loc) · 8.46 KB
/
lora.py
File metadata and controls
246 lines (188 loc) · 8.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import math
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
import PIL
import torch
import torch.nn.functional as F
import torch.nn as nn
class LoraInjectedLinear(nn.Module):
    """
    ``nn.Linear`` wrapper that adds a trainable low-rank (LoRA) branch.

    The layer computes ``linear(x) + lora_up(lora_down(x)) * scale``. Because
    ``lora_up`` is zero-initialised, a freshly built layer produces exactly
    the output of ``linear`` alone.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether the wrapped ``linear`` layer has a bias term.
        r: Rank of the low-rank branch.

    Raises:
        ValueError: If ``r`` exceeds ``min(in_features, out_features)``.
    """

    def __init__(self, in_features, out_features, bias=False, r=4):
        super().__init__()
        max_rank = min(in_features, out_features)
        if r > max_rank:
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {max_rank}"
            )

        # Submodules are created in this exact order so that parameter
        # registration (state_dict key order) and RNG consumption by the
        # default initialisers stay the same.
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        self.lora_up = nn.Linear(r, out_features, bias=False)

        # Multiplier applied to the LoRA branch in forward().
        self.scale = 1.0

        down_std = 1 / r**2
        nn.init.normal_(self.lora_down.weight, std=down_std)
        # Zero "up" projection => the LoRA branch contributes nothing at start.
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        base_out = self.linear(input)
        lora_out = self.lora_up(self.lora_down(input))
        return base_out + lora_out * self.scale
def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: List[str] = ["CrossAttention", "Attention"],
    r: int = 4,
):
    """
    Inject LoRA into ``model`` and return the LoRA parameter groups.

    Every module whose class name is exactly ``"Linear"`` and that lives
    anywhere beneath a module whose class name is in ``target_replace_module``
    is replaced, inside its real parent, by a ``LoraInjectedLinear``. The new
    layer reuses the original weight (and bias) Parameter objects and adds a
    rank-``r`` update.

    Args:
        model: Model to modify in place.
        target_replace_module: Class names of the modules whose ``Linear``
            descendants should be wrapped.
        r: LoRA rank.

    Returns:
        ``(require_grad_params, names)``.

        - ``require_grad_params`` holds two ``parameters()`` generators per
          injected layer, ``lora_up`` then ``lora_down``. These are
          generators, so each can be consumed only once.
        - ``names`` holds the (possibly dotted) path of each replaced layer
          relative to its target module.
    """
    require_grad_params = []
    names = []

    for _module in model.modules():
        if _module.__class__.__name__ not in target_replace_module:
            continue

        # Snapshot the traversal first: we swap children below, and mutating
        # the _modules dicts that named_modules() is still walking can raise
        # "dictionary changed size during iteration".
        for name, _child_module in list(_module.named_modules()):
            if _child_module.__class__.__name__ != "Linear":
                continue

            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                bias is not None,
                r,
            )
            # Share the original Parameters instead of copying them.
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias

            # `name` is dotted for nested children (e.g. "to_out.0"). Replace
            # the layer inside its actual parent so forward() really uses it.
            # get_submodule("") returns _module itself.
            parent_name, _, child_name = name.rpartition(".")
            _module.get_submodule(parent_name)._modules[child_name] = _tmp

            require_grad_params.append(_tmp.lora_up.parameters())
            require_grad_params.append(_tmp.lora_down.parameters())
            _tmp.lora_up.weight.requires_grad = True
            _tmp.lora_down.weight.requires_grad = True
            names.append(name)

    return require_grad_params, names
def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
    """
    Yield ``(lora_up, lora_down)`` submodule pairs for every injected LoRA layer.

    Only modules whose class name is ``"LoraInjectedLinear"`` and that sit
    beneath a module whose class name is in ``target_replace_module`` are
    reported. Pairs are yielded in ``model.modules()`` traversal order.

    Args:
        model: Model to scan.
        target_replace_module: Class names of the ancestor modules to search.

    Yields:
        ``(lora_up, lora_down)`` tuples of the injected layers' submodules.

    Raises:
        ValueError: Raised after the generator is fully consumed if no
            injected layer was found.
    """
    no_injection = True

    for _module in model.modules():
        if _module.__class__.__name__ in target_replace_module:
            for _child_module in _module.modules():
                if _child_module.__class__.__name__ == "LoraInjectedLinear":
                    no_injection = False
                    yield (_child_module.lora_up, _child_module.lora_down)

    if no_injection:
        raise ValueError("No lora injected.")
def save_lora_weight(
    model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
):
    """
    Save every injected LoRA up/down weight of ``model`` with ``torch.save``.

    The file holds a flat list ``[up_0, down_0, up_1, down_1, ...]`` in the
    order produced by ``extract_lora_ups_down``.

    Weights are detached and moved to CPU, the same treatment
    ``save_lora_as_json`` applies. This way the file carries no autograd state
    and can be loaded with ``torch.load`` on a machine that lacks the training
    device. Loaders move the tensors to the target device themselves.

    Args:
        model: Model containing injected LoRA layers.
        path: Destination file.
        target_replace_module: Class names of the ancestor modules to search.

    Raises:
        ValueError: Propagated from ``extract_lora_ups_down`` when no injected
            layer is found.
    """
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.append(_up.weight.detach().cpu())
        weights.append(_down.weight.detach().cpu())

    torch.save(weights, path)
def save_lora_as_json(
    model, path="./lora.json", target_replace_module=["CrossAttention", "Attention"]
):
    """
    Save every injected LoRA up/down weight of ``model`` as nested JSON lists.

    The JSON document is a flat list ``[up_0, down_0, up_1, down_1, ...]``.
    Each entry is the weight matrix converted to nested Python float lists.

    ``Tensor.tolist()`` is used directly instead of going through NumPy. It
    yields the same values, and it also works for dtypes NumPy cannot
    represent (e.g. ``torch.bfloat16``).

    Args:
        model: Model containing injected LoRA layers.
        path: Destination JSON file.
        target_replace_module: Class names of the ancestor modules to search.
            This was previously hard-coded; the default keeps the old
            behaviour.

    Raises:
        ValueError: Propagated from ``extract_lora_ups_down`` when no injected
            layer is found.
    """
    import json

    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.append(_up.weight.detach().cpu().tolist())
        weights.append(_down.weight.detach().cpu().tolist())

    with open(path, "w", encoding="utf-8") as f:
        json.dump(weights, f)
def weight_apply_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], alpha=1.0
):
    """
    Merge LoRA factors directly into the weights of target ``Linear`` layers.

    For every module whose class name is exactly ``"Linear"`` beneath a module
    whose class name is in ``target_replace_module``, the weight becomes
    ``W + alpha * (U @ D)`` (cast to W's dtype) and is stored as a new
    ``nn.Parameter``.

    Args:
        model: Model to modify in place.
        loras: Flat list ``[U_0, D_0, U_1, D_1, ...]``. Two tensors are popped
            from the front per layer, so the caller's list is consumed.
        target_replace_module: Class names of the ancestor modules to search.
        alpha: Multiplier for the merged low-rank update.
    """
    targets = (
        m for m in model.modules() if m.__class__.__name__ in target_replace_module
    )
    for target in targets:
        for sub in target.modules():
            if sub.__class__.__name__ != "Linear":
                continue

            base = sub.weight
            device = base.device
            up = loras.pop(0).detach().to(device)
            down = loras.pop(0).detach().to(device)

            delta = (up @ down).type(base.dtype)
            sub.weight = nn.Parameter(base + alpha * delta)
def monkeypatch_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
    """
    Replace target ``Linear`` layers with ``LoraInjectedLinear`` loaded from ``loras``.

    Every module whose class name is exactly ``"Linear"`` beneath a module
    whose class name is in ``target_replace_module`` is swapped, inside its
    real parent, for a ``LoraInjectedLinear``. The swap reuses the original
    weight/bias Parameters. The new layer's LoRA factors come from ``loras``,
    cast to the base weight's dtype and moved to its device.

    Args:
        model: Model to modify in place.
        loras: Flat list ``[up_0, down_0, up_1, down_1, ...]``. Two tensors are
            popped from the front per replaced layer, so the caller's list is
            consumed.
        target_replace_module: Class names of the ancestor modules to search.
        r: Rank used to build each ``LoraInjectedLinear``. Presumably this
            should match the shapes stored in ``loras``; the shapes are not
            checked here.
    """
    for _module in model.modules():
        if _module.__class__.__name__ not in target_replace_module:
            continue

        # Snapshot the traversal first: we swap children below, and mutating
        # the _modules dicts that named_modules() is still walking can raise
        # "dictionary changed size during iteration".
        for name, _child_module in list(_module.named_modules()):
            if _child_module.__class__.__name__ != "Linear":
                continue

            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                bias is not None,
                r=r,
            )
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias

            # `name` is dotted for nested children (e.g. "to_out.0"). Swap the
            # layer inside its actual parent so forward() really uses it.
            # get_submodule("") returns _module itself.
            parent_name, _, child_name = name.rpartition(".")
            _module.get_submodule(parent_name)._modules[child_name] = _tmp

            up_weight = loras.pop(0)
            down_weight = loras.pop(0)
            _tmp.lora_up.weight = nn.Parameter(up_weight.type(weight.dtype))
            _tmp.lora_down.weight = nn.Parameter(down_weight.type(weight.dtype))
            _tmp.to(weight.device)
def monkeypatch_replace_lora(
    model, loras, target_replace_module=["CrossAttention", "Attention"], r: int = 4
):
    """
    Swap existing ``LoraInjectedLinear`` layers for fresh ones loaded from ``loras``.

    The wrapped base ``linear`` weight/bias Parameters are reused, and the old
    LoRA factors are discarded. Each replacement is a newly constructed
    ``LoraInjectedLinear``, so its ``scale`` starts at its constructor
    default. Its factors come from ``loras``, cast to the base weight's dtype
    and moved to its device.

    Args:
        model: Model to modify in place.
        loras: Flat list ``[up_0, down_0, up_1, down_1, ...]``. Two tensors are
            popped from the front per replaced layer, so the caller's list is
            consumed.
        target_replace_module: Class names of the ancestor modules to search.
        r: Rank of the new LoRA layers. It may differ from the rank of the
            layers being replaced.
    """
    for _module in model.modules():
        if _module.__class__.__name__ not in target_replace_module:
            continue

        # Snapshot the traversal first: we swap children below, and mutating
        # the _modules dicts that named_modules() is still walking can raise
        # "dictionary changed size during iteration".
        for name, _child_module in list(_module.named_modules()):
            if _child_module.__class__.__name__ != "LoraInjectedLinear":
                continue

            base_linear = _child_module.linear
            weight = base_linear.weight
            bias = base_linear.bias
            _tmp = LoraInjectedLinear(
                base_linear.in_features,
                base_linear.out_features,
                bias is not None,
                r=r,
            )
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias

            # `name` is dotted for nested children (e.g. "to_out.0"). Swap the
            # layer inside its actual parent so forward() really uses it.
            # get_submodule("") returns _module itself.
            parent_name, _, child_name = name.rpartition(".")
            _module.get_submodule(parent_name)._modules[child_name] = _tmp

            up_weight = loras.pop(0)
            down_weight = loras.pop(0)
            _tmp.lora_up.weight = nn.Parameter(up_weight.type(weight.dtype))
            _tmp.lora_down.weight = nn.Parameter(down_weight.type(weight.dtype))
            _tmp.to(weight.device)
def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=["CrossAttention", "Attention"],
    alpha: float = 1.0,
    beta: float = 1.0,
):
    """
    Blend new LoRA factors into already-injected ``LoraInjectedLinear`` layers.

    For each injected layer beneath a module whose class name is in
    ``target_replace_module``, the factors are updated as follows:

    - ``lora_up.weight`` becomes ``new_up * alpha + old_up * beta``.
    - ``lora_down.weight`` becomes ``new_down * alpha + old_down * beta``.

    All tensors are cast to the base weight's dtype and moved to its device.

    Because the blend is applied to ``up`` and ``down`` separately, the
    effective update ``up @ down`` is generally NOT ``alpha * new_delta +
    beta * old_delta``.

    Args:
        model: Model containing injected LoRA layers; modified in place.
        loras: Flat list ``[up_0, down_0, up_1, down_1, ...]``. Two tensors are
            popped from the front per layer, so the caller's list is consumed.
        target_replace_module: Class names of the ancestor modules to search.
        alpha: Weight of the incoming factors.
        beta: Weight of the existing factors.
    """
    for _module in model.modules():
        if _module.__class__.__name__ not in target_replace_module:
            continue

        # Use the yielded module itself: for nested layers the dotted name
        # (e.g. "to_out.0") is not a key of _module._modules, so looking it
        # up there would raise KeyError.
        for _child_module in _module.modules():
            if _child_module.__class__.__name__ != "LoraInjectedLinear":
                continue

            weight = _child_module.linear.weight
            up_weight = loras.pop(0)
            down_weight = loras.pop(0)

            _child_module.lora_up.weight = nn.Parameter(
                up_weight.type(weight.dtype).to(weight.device) * alpha
                + _child_module.lora_up.weight.to(weight.device) * beta
            )
            _child_module.lora_down.weight = nn.Parameter(
                down_weight.type(weight.dtype).to(weight.device) * alpha
                + _child_module.lora_down.weight.to(weight.device) * beta
            )
            _child_module.to(weight.device)
def tune_lora_scale(model, alpha: float = 1.0):
    """
    Set the LoRA branch multiplier on every injected layer in ``model``.

    Every submodule whose class name is ``"LoraInjectedLinear"`` gets its
    ``scale`` attribute set to ``alpha``. ``LoraInjectedLinear.forward``
    multiplies its low-rank branch by this value, so ``alpha=0`` disables the
    LoRA contribution and ``1.0`` restores the default.
    """
    lora_layers = [
        layer
        for layer in model.modules()
        if type(layer).__name__ == "LoraInjectedLinear"
    ]
    for layer in lora_layers:
        layer.scale = alpha