-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathutils.py
More file actions
306 lines (258 loc) · 10.9 KB
/
utils.py
File metadata and controls
306 lines (258 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import os
import torch
import random
import numpy as np
from transformers import set_seed, AutoTokenizer
import json
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator
import torch.nn as nn
def print_rank_0(msg, rank=None):
    """Print ``msg`` only on the process treated as rank 0.

    Args:
        msg: Object handed to ``print``.
        rank: Optional explicit rank. When it is given and ``<= 0`` the
            message is printed directly; in every other case the decision is
            delegated to ``is_rank_0()``.
    """
    explicit_rank_zero = rank is not None and rank <= 0
    # ``or`` short-circuits, so ``is_rank_0`` is only consulted when the
    # explicit rank did not already settle the question.
    if explicit_rank_zero or is_rank_0():
        print(msg)
def is_rank_0():
    """Tell whether the current process is rank 0.

    Returns:
        ``True`` when ``torch.distributed`` has not been initialized, or when
        it has and this process's rank is 0; ``False`` otherwise.
    """
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
def to_device(batch, device):
    """Return a new dict with every movable value of ``batch`` on ``device``.

    Values that expose a compatible ``.to(device)`` method are moved; values
    that do not (ints, strings, lists, ...) are copied over unchanged.

    Args:
        batch: Mapping from keys to values.
        device: Target passed verbatim to each value's ``.to``.

    Returns:
        A new ``dict`` with the same keys as ``batch``.

    Raises:
        Any exception other than ``AttributeError``/``TypeError`` raised by a
        value's ``.to`` call (for example an invalid-device ``RuntimeError``)
        is propagated instead of being silently swallowed.
    """
    output = {}
    for k, v in batch.items():
        try:
            output[k] = v.to(device)
        except (AttributeError, TypeError):
            # Not a tensor-like value (no usable ``.to``): keep it as is.
            # A bare ``except`` here would also hide KeyboardInterrupt and
            # genuine device errors.
            output[k] = v
    return output
class MovingAverage:
    """Cumulative arithmetic mean of all values passed to ``update``."""

    def __init__(self):
        # Number of samples seen, their running sum, and the current mean.
        self.count, self.total, self.mean = 0, 0, 0

    def update(self, num):
        """Fold ``num`` into the running statistics and return the new mean."""
        self.total += num
        self.count += 1
        self.mean = self.total / self.count
        return self.mean
class ExponentialMovingAverage:
    """Exponentially weighted average of the values passed to ``update``."""

    def __init__(self, alpha=0.9):
        # ``alpha`` is the weight kept from the previous average.
        self.alpha = alpha
        # ``None`` until the first call to ``update``.
        self.ema = None

    def update(self, num):
        """Blend ``num`` into the average and return the updated value.

        On the first call the previous average is taken to be ``num`` itself.
        """
        if self.ema is None:
            baseline = num
        else:
            baseline = self.ema
        self.ema = self.alpha * baseline + (1.0 - self.alpha) * num
        return self.ema

    def get(self):
        """Return the current average, or ``0.`` if ``update`` was never called."""
        if self.ema is None:
            return 0.
        return self.ema
def get_tokenizer(model_name_or_path, fast_tokenizer=True):
    """Load a tokenizer and make sure it has a pad token and pads on the right.

    Args:
        model_name_or_path: Identifier or path handed to
            ``AutoTokenizer.from_pretrained``.
        fast_tokenizer: Whether to request the fast tokenizer
            implementation. It is forwarded as ``use_fast``, the keyword
            ``AutoTokenizer.from_pretrained`` documents for this purpose.

    Returns:
        The loaded tokenizer with ``padding_side`` set to ``'right'``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path,
                                              use_fast=fast_tokenizer)
    if "llama" in model_name_or_path:
        # For names containing "llama" a dedicated '[PAD]' token is added
        # (rather than reusing EOS) when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    else:
        tokenizer.pad_token = tokenizer.eos_token
    # make sure tokenizer is right pad in our logic
    tokenizer.padding_side = 'right'
    return tokenizer
def load_hf_tokenizer(model_name_or_path,
                      fast_tokenizer=True,
                      add_special_tokens=None):
    """Load a tokenizer via ``get_tokenizer`` and optionally register extra special tokens.

    Local directories and hub identifiers are handled by the same
    ``get_tokenizer`` call, so no filesystem check is needed here.

    Args:
        model_name_or_path: Local path or model identifier.
        fast_tokenizer: Forwarded to ``get_tokenizer``.
        add_special_tokens: ``None``, a single token string, or a collection
            of token strings to register as ``additional_special_tokens``.

    Returns:
        The configured tokenizer.
    """
    tokenizer = get_tokenizer(model_name_or_path,
                              fast_tokenizer=fast_tokenizer)

    if add_special_tokens is not None:
        # Accept a lone string as shorthand for a one-element list.
        if isinstance(add_special_tokens, str):
            extra_tokens = [add_special_tokens]
        else:
            extra_tokens = add_special_tokens
        tokenizer.add_special_tokens(
            {'additional_special_tokens': extra_tokens})

    return tokenizer
def save_hf_format(model, tokenizer, args, sub_folder=""):
    """Write weights, config and vocabulary in a layout usable by ``from_pretrained``.

    Args:
        model: Model to save; if it has a ``module`` attribute, that inner
            object is saved instead.
        tokenizer: Object whose ``save_vocabulary(output_dir)`` is called.
        args: Object providing ``output_dir``.
        sub_folder: Optional sub-directory below ``args.output_dir``.
    """
    CONFIG_NAME = "config.json"
    WEIGHTS_NAME = "pytorch_model.bin"

    # Unwrap a wrapper that exposes the real model as ``.module``.
    model_to_save = getattr(model, 'module', model)

    output_dir = os.path.join(args.output_dir, sub_folder)
    os.makedirs(output_dir, exist_ok=True)

    save_dict = model_to_save.state_dict()
    # Drop every entry whose key mentions "lora"; entries are removed in place
    # so the saved object is still the original state-dict mapping.
    lora_keys = [key for key in save_dict if "lora" in key]
    for key in lora_keys:
        del save_dict[key]

    torch.save(save_dict, os.path.join(output_dir, WEIGHTS_NAME))
    model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
    tokenizer.save_vocabulary(output_dir)
def set_random_seed(seed):
    """Seed every random number generator used here; do nothing when ``seed`` is ``None``.

    Seeds, in order: ``transformers.set_seed``, ``random``, ``numpy.random``,
    ``torch`` and the DeepSpeed accelerator (all devices).
    """
    if seed is None:
        return
    set_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    get_accelerator().manual_seed_all(seed)
def get_all_reduce_mean(tensor):
    """Average ``tensor`` across all processes of the default process group.

    ``tensor`` is first sum-reduced in place by ``all_reduce``; the returned
    value is a new tensor holding that sum divided by the world size.
    """
    torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.SUM)
    return tensor / torch.distributed.get_world_size()
# This function is a modified version of code available in the from_pretrained API of HuggingFace Transformers
# The code is copied and modified from: https://github.com/huggingface/transformers/blob/5ee9693a1c77c617ebc43ef20194b6d3b674318e/src/transformers/modeling_utils.py#L498
# This function helps load a HF format checkpoint into a DeepSpeed wrapped model that has been sharded using ZeRO Stage 3
def load_state_dict_into_model(model_to_load=None,
                               state_dict=None,
                               start_prefix="",
                               zero_stage=0):
    """Recursively load ``state_dict`` into ``model_to_load`` and collect error messages.

    Args:
        model_to_load: Root ``nn.Module`` that receives the weights.
        state_dict: Mapping of parameter names to tensors. A shallow copy is
            used internally, so the caller's mapping is not modified.
        start_prefix: Key prefix that corresponds to the root module.
        zero_stage: When ``3``, each module's own parameters are gathered with
            ``deepspeed.zero.GatheredParameters`` and loaded on rank 0 only.

    Returns:
        The list of error messages appended by ``_load_from_state_dict``.
    """
    # Work on a copy so ``_load_from_state_dict`` can modify it, carrying the
    # optional ``_metadata`` attribute across.
    metadata = getattr(state_dict, "_metadata", None)
    working_state = state_dict.copy()
    if metadata is not None:
        working_state._metadata = metadata

    error_msgs = []

    # ``_load_from_state_dict`` does not descend into child modules, so the
    # walk over the module tree is done here.
    def load(module: nn.Module, shard, prefix=""):
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        load_args = (shard, prefix, local_metadata, True, [], [], error_msgs)

        # Only touch this module if some key in the shard lives under it.
        has_matching_keys = any(key.startswith(prefix) for key in shard)
        if has_matching_keys and zero_stage != 3:
            module._load_from_state_dict(*load_args)
        elif has_matching_keys:
            # Sharded checkpoints hold only part of the full state dict, so
            # gather just the direct parameters present in this shard.
            own_params = dict(
                module.named_parameters(prefix=prefix[:-1], recurse=False))
            params_to_gather = [
                own_params[key] for key in shard.keys() if key in own_params
            ]
            if len(params_to_gather) > 0:
                # ZeRO-3 keeps placeholders in the model; the context manager
                # gathers the real parameters, lets rank 0 load into them,
                # and partitions them again on exit.
                with deepspeed.zero.GatheredParameters(params_to_gather,
                                                       modifier_rank=0):
                    if torch.distributed.get_rank() == 0:
                        module._load_from_state_dict(*load_args)

        for child_name, child in module._modules.items():
            if child is not None:
                load(child, shard, prefix + child_name + ".")

    load(model_to_load, working_state, prefix=start_prefix)
    # Drop the local copy so it can be garbage-collected early; the caller's
    # mapping is a different object and is unaffected.
    del working_state

    return error_msgs
def get_optimizer_grouped_parameters(
    model,
    weight_decay,
    lora_lr=5e-4,
    no_decay_name_list=[
        "bias", "layer_norm.weight", "layernorm.weight", "norm.weight",
        "ln_f.weight"
    ],
    lora_name_list=["lora_right_weight", "lora_left_weight"],
):
    """Split trainable parameters into optimizer groups.

    Parameters with ``requires_grad`` set are bucketed by substring matches
    against their lower-cased name:

    1. regular parameters, using ``weight_decay``;
    2. parameters matching ``lora_name_list`` (and not ``no_decay_name_list``),
       using ``weight_decay`` and learning rate ``lora_lr``;
    3. parameters matching ``no_decay_name_list``, using weight decay ``0.0``.

    Args:
        model: Object providing ``named_parameters()``.
        weight_decay: Weight decay for groups 1 and 2.
        lora_lr: Learning rate for group 2.
        no_decay_name_list: Name fragments that exempt a parameter from decay.
        lora_name_list: Name fragments that identify LoRA parameters.

    Returns:
        The groups listed above, in that order, with empty groups omitted.
    """
    decay_params = []
    lora_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        lowered = name.lower()
        if any(fragment in lowered for fragment in no_decay_name_list):
            # The no-decay match wins even when the name also looks like LoRA.
            no_decay_params.append(param)
        elif any(fragment in lowered for fragment in lora_name_list):
            lora_params.append(param)
        else:
            decay_params.append(param)

    candidate_groups = [
        {
            "params": decay_params,
            "weight_decay": weight_decay,
        },
        {
            "params": lora_params,
            "weight_decay": weight_decay,
            "lr": lora_lr,
        },
        {
            "params": no_decay_params,
            "weight_decay": 0.0,
        },
    ]
    return [group for group in candidate_groups if group["params"]]
def _z3_params_to_fetch(param_list):
    """Select the entries of ``param_list`` that need gathering.

    An entry is selected when it has a ``ds_id`` attribute and its
    ``ds_status`` equals ``ZeroParamStatus.NOT_AVAILABLE``. Input order is
    preserved.
    """
    pending = []
    for candidate in param_list:
        if not hasattr(candidate, 'ds_id'):
            continue
        if candidate.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            pending.append(candidate)
    return pending
def moving_average(model, model_ema, beta=0.992, device=None, zero_stage=0):
    """Update ``model_ema``'s parameters in place toward ``model``'s.

    Each EMA parameter becomes ``torch.lerp(param, param_ema, beta)``, i.e.
    ``beta`` is the weight kept from the previous EMA value.

    Args:
        model: Source model.
        model_ema: Model holding the averaged weights; modified in place.
        beta: Interpolation weight passed to ``torch.lerp``.
        device: If not ``None``, source data is moved there before blending.
        zero_stage: When ``3``, parameters reported by ``_z3_params_to_fetch``
            are gathered around the update.
    """
    gather_stage = zero_stage == 3
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(),
                                    model_ema.parameters()):
            # TODO: use prefiltering for efficiency
            if gather_stage:
                params_to_fetch = _z3_params_to_fetch([param, param_ema])
            else:
                params_to_fetch = []
            # NOTE(review): no ``modifier_rank`` is passed here; confirm
            # against the DeepSpeed docs that in-place updates to gathered
            # ZeRO-3 parameters persist after the context exits.
            with deepspeed.zero.GatheredParameters(
                    params_to_fetch, enabled=bool(params_to_fetch)):
                source = param.data
                if device is not None:
                    source = source.to(device)
                blended = torch.lerp(source, param_ema.data, beta)
                param_ema.data.copy_(blended)
def save_zero_three_model(model_ema, global_rank, save_dir, zero_stage=0):
    """Save model weights to ``save_dir/pytorch_model.bin`` from rank 0.

    Args:
        model_ema: Model to save; if it has a ``module`` attribute, that inner
            object is saved instead.
        global_rank: Rank of the calling process; only rank 0 writes the file.
        save_dir: Output directory, created by every caller if missing.
        zero_stage: When ``3``, parameters are collected one by one (gathering
            those that carry a ``ds_id``) and entries whose name contains
            "lora" are left out. Otherwise the plain ``state_dict()`` is saved.
    """
    zero_stage_3 = (zero_stage == 3)
    os.makedirs(save_dir, exist_ok=True)
    WEIGHTS_NAME = "pytorch_model.bin"
    output_model_file = os.path.join(save_dir, WEIGHTS_NAME)

    model_to_save = getattr(model_ema, 'module', model_ema)
    is_writer = global_rank == 0

    if not zero_stage_3:
        if is_writer:
            torch.save(model_to_save.state_dict(), output_model_file)
        return

    output_state_dict = {}
    for name, param in model_to_save.named_parameters():
        if hasattr(param, 'ds_id'):
            # Every rank enters the gather context, not only the writer.
            with deepspeed.zero.GatheredParameters(
                    _z3_params_to_fetch([param]), enabled=zero_stage_3):
                cpu_value = param.data.cpu()
        else:
            cpu_value = param.cpu()
        if is_writer and "lora" not in name:
            output_state_dict[name] = cpu_value
    if is_writer:
        torch.save(output_state_dict, output_model_file)
    del output_state_dict