-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathprefix.py
More file actions
196 lines (166 loc) · 8.44 KB
/
prefix.py
File metadata and controls
196 lines (166 loc) · 8.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import logging
# NOTE(review): basicConfig at import time configures the ROOT logger as a side
# effect of importing this module; libraries usually leave that to the application.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Redundant after basicConfig(level=logging.INFO) above, but harmless.
logger.setLevel(logging.INFO)
import torch
from torch import nn
def find_module(root_module: nn.Module, key: str):
    """Locate a submodule by its dotted path.

    Walks ``key`` (e.g. ``"layers.0.attn"``) attribute-by-attribute down from
    ``root_module`` and returns the triple
    ``(parent_module, leaf_attribute_name, submodule)`` so the caller can both
    inspect and reassign the submodule.

    Adapted from OpenDelta (https://github.com/thunlp/OpenDelta).
    """
    *parent_names, leaf_name = key.split(".")
    parent = root_module
    for name in parent_names:
        parent = getattr(parent, name)
    return parent, leaf_name, getattr(parent, leaf_name)
def attn_forward_hook(self, *args, **kwargs):
    """
    Replacement forward for an attention module that injects prefix tokens.

    On the first step (no KV cache present), the module's learned prefix
    keys/values are expanded across the batch and handed to the original
    forward as ``past_key_value``; the additive attention mask is padded with
    zeros over the prefix positions so every query may attend to them.
    All other calls pass straight through to ``self.original_forward``.

    Fix vs. original: the mask padding used ``-torch.zeros(...)`` — negating a
    zeros tensor is a no-op (yields -0.0, numerically identical); the stray
    negation has been removed.
    """
    def _expand_bsz(x, bsz):
        # (num_prefix, hidden) -> (num_heads, num_prefix, hidden/num_heads)
        x = x.reshape(x.size(0), self.num_heads, -1).transpose(0, 1)
        # -> (bsz, num_heads, num_prefix, hidden/num_heads); expand shares storage.
        x = x.unsqueeze(0).expand(bsz, *x.shape)
        return x

    if "hidden_states" in kwargs:
        hidden_states = kwargs["hidden_states"]
    else:
        hidden_states = args[0]
    bsz = hidden_states.size(0)

    if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None:
        # No cache yet -> this is the first forward; inject the prefix here.
        if self.reparam:
            # Reparameterization: prefix K/V are produced by MLPs over shared embeddings.
            prefix_keys = self.prefix_mlp_keys(self.prefix_input_embeds)
            prefix_values = self.prefix_mlp_values(self.prefix_input_embeds)
        else:
            prefix_keys, prefix_values = self.prefix_keys, self.prefix_values
        kwargs['past_key_value'] = (_expand_bsz(prefix_keys, bsz), _expand_bsz(prefix_values, bsz))

        # Extend the additive attention mask with zeros (= "attend") for the
        # prefix positions, whether the mask arrived by keyword or positionally.
        if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None:
            am = kwargs['attention_mask']
            kwargs['attention_mask'] = torch.cat([torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1)
        elif len(args) > 1:  # attention mask is passed via positional argument
            am = args[1]
            am = torch.cat([torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1)
            args = (args[0], am) + args[2:]
    return self.original_forward(*args, **kwargs)
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
    """
    Replacement for the model's ``prepare_inputs_for_generation`` so that the
    prefix key/values are accounted for during generation.

    When a KV cache is present, only the last input token is fed, and the
    attention mask is widened with ones to cover the extra prefix positions
    that live in the cache but have no corresponding input tokens.

    Fix vs. original: removed the dead local ``original_input_len``, which was
    computed and never used.
    """
    if past_key_values:
        # With a cache, only the newest token needs to be processed.
        input_ids = input_ids[:, -1:]

    # If `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    if past_key_values is not None:
        # Check if we should add extra to attention mask: the cached sequence
        # length exceeds the mask length minus the one token being fed now
        # exactly when prefix entries are present in the cache.
        if past_key_values[0][0].size(2) != attention_mask.size(1) - 1:
            num_prefix = past_key_values[0][0].size(2) - (attention_mask.size(1) - 1)
            attention_mask = torch.cat([torch.ones((attention_mask.size(0), num_prefix), dtype=attention_mask.dtype, device=attention_mask.device), attention_mask], dim=-1)

    model_inputs.update(
        {
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
        }
    )
    return model_inputs
class PrefixTuning:
    # Installs prefix tuning on a HuggingFace model in place: hooks every
    # attention module's forward, attaches trainable prefix parameters, freezes
    # all non-prefix weights, and swaps in a cache-aware
    # prepare_inputs_for_generation. Only OPT and RoBERTa layouts are supported.

    def __init__(self, model, num_prefix, reparam=True, embed_dim=512, mid_dim=512, float16=False, init_by_real_act=False):
        """
        Inputs:
        num_prefix: number of prefix tokens
        reparam: use reparameterization trick (not used in MeZO)
        embed_dim, mid_dim: hyperparameters for reparameterization trick (not used in MeZO)
        float16: whether the model parameters are float16
        init_by_real_act: init prefix tokens by real activations
        """
        self.model = model
        self.num_prefix = num_prefix
        self.hidden_dim = model.config.hidden_size
        self.float16 = float16

        # Reparameterization
        self.reparam = reparam
        self.embed_dim = embed_dim
        self.mid_dim = mid_dim
        input_embeds = None  # For reparameterization; shared across layers after layer 0

        # Module-name patterns differ per architecture.
        if model.config.model_type == "opt":
            attention_name = "attn"
            first_layer_name = "layers.0"
            layer_name = "layers."
        elif model.config.model_type == "roberta":
            attention_name = "attention"
            first_layer_name = "layer.0"
            layer_name = "layer."
        else:
            raise NotImplementedError

        if init_by_real_act:
            # Initialize prefix with real words' activations
            assert not reparam
            # Randomly sample input tokens
            # NOTE(review): hard-codes .cuda(); fails on CPU-only setups — confirm intended.
            input_tokens = torch.randint(low=0, high=model.config.vocab_size, size=(1, num_prefix), dtype=torch.long).cuda()
            if model.config.model_type == "opt":
                with torch.no_grad():
                    # Get the real activations
                    real_key_values = model(input_ids=input_tokens, use_cache=True).past_key_values
            else:
                raise NotImplementedError

        # Insert prefix
        for key, _ in model.named_modules():
            # Match modules whose dotted name ends with the attention name.
            if key[-len(attention_name):] == attention_name:
                layer_id = int(key.split(layer_name)[1].split(".")[0])
                logger.info(f"Inject prefix to: {key}")
                _, _, attn = find_module(model, key)

                # Replace the old forward functions
                # __get__ binds the hook as a method so `self` is the attn module.
                attn.original_forward = attn.forward
                attn.forward = attn_forward_hook.__get__(attn, type(attn))
                if not hasattr(attn, "num_heads"):
                    attn.num_heads = model.config.num_attention_heads
                first = first_layer_name in key
                self.add_prefix(attn, first=first, input_embeds=input_embeds)
                if first and self.reparam:
                    # Capture layer-0 embeddings so later layers can share them.
                    input_embeds = attn.prefix_input_embeds

                if init_by_real_act:
                    logger.info(f"Reinitialize with actual activation: {key} (layer {layer_id})")
                    # (1, heads, num_prefix, head_dim) -> (num_prefix, hidden)
                    keys = real_key_values[layer_id][0].squeeze(0).transpose(0, 1).reshape(num_prefix, -1)
                    values = real_key_values[layer_id][1].squeeze(0).transpose(0, 1).reshape(num_prefix, -1)
                    attn.prefix_keys.data = keys.to(attn.prefix_keys.data.device)
                    attn.prefix_values.data = values.to(attn.prefix_values.data.device)

        # Freeze non-prefix parameters
        for n, p in model.named_parameters():
            if "prefix" not in n:
                p.requires_grad = False

        # Replace the old prepare_inputs_for_generation function
        model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(model, type(model))

    def add_prefix(self, module, first, input_embeds=None):
        """Attach prefix parameters (or reparameterization MLPs) to one attention module."""
        # NOTE(review): assumes the attention module exposes `k_proj` (true for
        # OPT); RoBERTa attention wrappers may not — confirm for roberta path.
        device = module.k_proj.weight.data.device
        module.num_prefix = self.num_prefix
        module.reparam = self.reparam

        if self.reparam:
            if first:
                # For the first layer we inject the embeddings
                logger.info("For prefix+reparameterization, inject the embeddings in the first layer.")
                module.prefix_input_embeds = nn.Parameter(torch.randn(self.num_prefix, self.embed_dim, device=device, dtype=self.model.dtype), requires_grad=True)
            else:
                assert input_embeds is not None
                module.prefix_input_embeds = input_embeds
            # Per-layer MLPs map the shared embeddings to prefix keys/values.
            module.prefix_mlp_keys = nn.Sequential(
                nn.Linear(self.embed_dim, self.mid_dim),
                nn.Tanh(),
                nn.Linear(self.mid_dim, self.hidden_dim)
            ).to(device)
            module.prefix_mlp_values = nn.Sequential(
                nn.Linear(self.embed_dim, self.mid_dim),
                nn.Tanh(),
                nn.Linear(self.mid_dim, self.hidden_dim)
            ).to(device)
            if self.float16:
                module.prefix_mlp_keys = module.prefix_mlp_keys.half()
                module.prefix_mlp_values = module.prefix_mlp_values.half()
        else:
            # Direct parameterization: learn the prefix keys/values themselves.
            module.prefix_keys = nn.Parameter(torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), requires_grad=True)
            module.prefix_values = nn.Parameter(torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), requires_grad=True)