Commit ab84c96

Support loading models trained with different model_parallel_world_size. (#16)
* temp save
* quick fix of demo memory issue
* Refactor tensor creation dtype / device control. This commit makes two changes during model creation:
  1. Decouple promote_trainable_params_to_fp32 from model __init__, so that inference-only mode can skip the fp32 casts and save memory (#4).
  2. Use a context manager to manage the default tensor type change. Previously, the default tensor type was reset to torch.FloatTensor after creating the vision model, which is technically incorrect: it should be restored to the previous default tensor type instead. We implement our own context manager because the official ones seem incomplete as of PyTorch 2.0.1: no dtype manager is provided, and set_default_device has no effect on the torch.Tensor constructor calls used in fairscale.
* Change CLIP dtype management in llama.py. It is probably safer to keep CLIP at its original precision (e.g., fp16) regardless of the autocast setting: some casts (e.g., fp16 to bf16) may be lossy and can potentially harm the pre-trained model. The change is limited to llama.py for now, since a lot of copy-pasted code may be refactored later (#3).
* Respect args.precision when saving checkpoints.
* Support checkpoint merge. Merging is supported in util/tensor_parallel.py and requires that checkpoint_mp_world_size % mp_world_size == 0. Support for split (i.e., when mp_world_size % checkpoint_mp_world_size == 0) and redistribution (for general combinations of mp_world_size and checkpoint_mp_world_size) will be added in the future. The multi_turn demo now uses the new loading function with merge support.
* Move printing of trainable params into util/misc.py.
* Move training model creation back to CPU.

Closes #15, #13
1 parent 89425cd commit ab84c96

Showing 11 changed files with 347 additions and 58 deletions.
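The two new modules the message describes, util/tensor_type.py and util/tensor_parallel.py, are imported throughout the diffs below but are not rendered in this view. As a reading aid, here is a minimal sketch of what a default-tensor-type context manager needs to do; names and internals are illustrative assumptions, not the file's verbatim contents:

import torch

class default_tensor_type:
    """Sketch: set the default tensor dtype/device on entry and restore the
    previous default on exit (rather than hard-resetting to torch.FloatTensor,
    as the old code did)."""

    _tensor_types = {
        (torch.float32, "cpu"): torch.FloatTensor,
        (torch.float16, "cpu"): torch.HalfTensor,
        (torch.bfloat16, "cpu"): torch.BFloat16Tensor,
        (torch.float32, "cuda"): torch.cuda.FloatTensor,
        (torch.float16, "cuda"): torch.cuda.HalfTensor,
        (torch.bfloat16, "cuda"): torch.cuda.BFloat16Tensor,
    }

    def __init__(self, dtype: torch.dtype = torch.float32, device: str = "cpu") -> None:
        self.new_type = self._tensor_types[(dtype, device)]

    def __enter__(self) -> None:
        # torch.tensor([]).type() reports the active default tensor type as a
        # string that set_default_tensor_type also accepts.
        self.old_type = torch.tensor([]).type()
        torch.set_default_tensor_type(self.new_type)

    def __exit__(self, *exc_info) -> None:
        torch.set_default_tensor_type(self.old_type)

Saving and restoring the previous default, instead of assuming it was torch.FloatTensor, is what makes the manager nest correctly; that is the bug the commit message calls out.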


accessory/demos/multi_turn.py

Lines changed: 38 additions & 23 deletions
@@ -14,7 +14,9 @@
 
 import gradio as gr
 
-from util.misc import setup_for_distributed, load_pretrained
+from util.misc import setup_for_distributed
+from util.tensor_parallel import load_tensor_parallel_model
+from util.tensor_type import default_tensor_type
 from model.meta import MetaModel
 from data.conversation.lib import conv_templates, SeparatorStyle
@@ -50,14 +52,18 @@ def model_worker(
     # set the print behavior.
     setup_for_distributed(rank == 0)
 
-    torch.set_default_tensor_type(torch.cuda.HalfTensor)
-    model = MetaModel(
-        args.llama_type, args.llama_config, args.tokenizer_path,
-        with_visual=False, max_seq_len=args.model_max_seq_len,
-    )
-    torch.set_default_tensor_type(torch.FloatTensor)
+    target_dtype = {
+        "bf16": torch.bfloat16,
+        "fp16": torch.float16,
+    }[args.dtype]
+    with default_tensor_type(dtype=target_dtype, device="cuda"):
+        model = MetaModel(
+            args.llama_type, args.llama_config, args.tokenizer_path,
+            with_visual=False, max_seq_len=args.model_max_seq_len,
+        )
+    model.eval()
     print(f"Loading pretrained weights from {args.pretrained_path}")
-    load_pretrained(args.pretrained_path, args.pretrained_type, model)
+    load_tensor_parallel_model(model, args.pretrained_path, args.pretrained_type)
     print(f"Model = {str(model)}")
 
     barrier.wait()
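load_tensor_parallel_model comes from the new util/tensor_parallel.py, which is not rendered in this view. Per the commit message, merging applies when checkpoint_mp_world_size % mp_world_size == 0; a hypothetical sketch of the core merge step (names and dimension conventions are assumptions, not the module's actual API):

import torch

def shards_for_rank(rank: int, mp_world_size: int, ckpt_mp_world_size: int) -> range:
    # In the merge case, each current rank owns a contiguous group of
    # ckpt_mp_world_size // mp_world_size checkpoint shards.
    ratio = ckpt_mp_world_size // mp_world_size
    return range(rank * ratio, (rank + 1) * ratio)

def merge_param_shards(shards: list, mp_dim: int) -> torch.Tensor:
    # Concatenate one parameter's consecutive checkpoint shards along its
    # model-parallel dimension (0 for column-parallel weights, 1 for
    # row-parallel weights in fairscale's layer conventions).
    return torch.cat(shards, dim=mp_dim)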
@@ -67,22 +73,29 @@ def model_worker(
         for user, bot in chatbot:
             conv.append_message(conv.roles[0], user)
             conv.append_message(conv.roles[1], bot)
-        # print(conv.get_prompt())
 
-        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-            for stream_response in model.stream_generate(
-                conv.get_prompt(), None,
-                max_gen_len, temperature, top_p
-            ):
-                end_pos = stream_response['text'].find(conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2)
-                if end_pos != -1:
-                    stream_response['text'] = stream_response['text'][:end_pos].rstrip()+"\n"
-                    stream_response['end_of_content'] = True
-                if response_queue is not None:
-                    response_queue.put(stream_response)
-
-                if stream_response['end_of_content']:
-                    break
+        for stream_response in model.stream_generate(
+            conv.get_prompt(), None,
+            max_gen_len, temperature, top_p
+        ):
+            conv_sep = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2
+            end_pos = stream_response["text"].find(conv_sep)
+            if end_pos != -1:
+                stream_response["text"] = stream_response['text'][:end_pos].rstrip() + "\n"
+                stream_response["end_of_content"] = True
+
+            # keep a few characters if not end_of_content to avoid sending part of conv_sep
+            # before all of it is generated.
+            if not stream_response["end_of_content"]:
+                if len(stream_response["text"]) < len(conv_sep):
+                    continue
+                stream_response["text"] = stream_response["text"][:-len(conv_sep)]
+
+            if response_queue is not None:
+                response_queue.put(stream_response)
+
+            if stream_response["end_of_content"]:
+                break
 
 
 def gradio_worker(
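The hold-back branch added above keeps the demo from streaming a partially generated stop separator to the client. A toy illustration of the effect (not part of the commit; the conv_sep value is assumed):

conv_sep = "###"                 # assumed separator for illustration
partial = "The answer is 42.##"  # stream caught mid-separator
# find() has not matched yet, so without the hold-back the trailing "##"
# would be shown to the user; withholding the last len(conv_sep) characters
# keeps the visible text clean until the separator either completes (and is
# trimmed by the find() branch) or turns out to be ordinary text.
print(partial[:-len(conv_sep)])  # -> The answer is 42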
@@ -178,6 +191,8 @@ def undo(chatbot):
                         help="A port used by the PyTorch distributed module to initialize.")
     parser.add_argument("--master_addr", type=str, default="127.0.0.1",
                         help="An address used by the PyTorch distributed module to initialize.")
+    parser.add_argument("--dtype", type=str, choices=["fp16", "bf16"], default="bf16",
+                        help="The dtype used for model weights and inference.")
     args = parser.parse_args()
 
     # check and setup gpu_ids to use

accessory/main_finetune.py

Lines changed: 11 additions & 7 deletions
@@ -32,6 +32,7 @@
 
 import util.misc as misc
 from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
 from model.meta import MetaModel
 from engine_finetune import train_one_epoch
 from torch.utils.data import Dataset
@@ -150,8 +151,16 @@ def main(args):
     dp_group = fs_init.get_data_parallel_group()
 
     # define the model
-    model = MetaModel(args.llama_type, args.llama_config,
-                      args.tokenizer_path, with_visual=not args.no_visual)
+    mixed_precision_dtype = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "tf32": torch.float32,
+    }[args.precision]
+    with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"):
+        model = MetaModel(args.llama_type, args.llama_config,
+                          args.tokenizer_path, with_visual=not args.no_visual)
+    promote_trainable_params_to_fp32(model)
+    misc.print_trainable_params(model)
     print(f"load pretrained from {args.pretrained_path}")
     misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
     print("Unwrapped Model = %s" % str(model))
@@ -160,11 +169,6 @@ def main(args):
     if args.resume:
         misc.resume_stage1(args, model_without_FSDP=model)
 
-    mixed_precision_dtype = {
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "tf32": torch.float32,
-    }[args.precision]
     TransformerBlock = type(model.llma.layers[0])
     # ignored_named_parameters = {name: param for name, param in model.named_parameters() if not param.requires_grad}
     # print(ignored_named_parameters.keys())
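The promote_trainable_params_to_fp32 call above replaces the per-model value.data = value.data.float() casts deleted from the LLM classes further down. Its body lives in util/tensor_type.py, which is not shown here; a sketch consistent with the removed logic (an assumption, not the file's verbatim contents):

import torch

def promote_trainable_params_to_fp32(model: torch.nn.Module) -> None:
    # The model is first built under a low-precision default dtype; only the
    # parameters that will actually be trained are promoted to fp32 for stable
    # optimizer updates. Frozen parameters keep their compact dtype, which is
    # what lets inference-only mode skip the fp32 copies (#4).
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.float()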

accessory/main_pretrain.py

Lines changed: 12 additions & 7 deletions
@@ -32,6 +32,7 @@
 
 import util.misc as misc
 from util.misc import NativeScalerWithGradNormCount as NativeScaler
+from util.tensor_type import default_tensor_type, promote_trainable_params_to_fp32
 from model.meta import MetaModel
 from engine_pretrain import train_one_epoch, val_one_epoch
 from torch.utils.data import Dataset
@@ -147,8 +148,16 @@ def main(args):
     dp_group = fs_init.get_data_parallel_group()
 
     # define the model
-    model = MetaModel(args.llama_type, args.llama_config,
-                      args.tokenizer_path, with_visual=False)
+    mixed_precision_dtype = {
+        "fp16": torch.float16,
+        "bf16": torch.bfloat16,
+        "tf32": torch.float32,
+    }[args.precision]
+    with default_tensor_type(dtype=mixed_precision_dtype, device="cpu"):
+        model = MetaModel(args.llama_type, args.llama_config,
+                          args.tokenizer_path, with_visual=False)
+    promote_trainable_params_to_fp32(model)
+    misc.print_trainable_params(model)
     if args.pretrained_path:
         print(f"load pretrained from {args.pretrained_path}")
         misc.load_pretrained(args.pretrained_path, args.pretrained_type, model)
@@ -158,11 +167,7 @@ def main(args):
     if args.resume:
         misc.resume_stage1(args, model_without_FSDP=model)
 
-    mixed_precision_dtype = {
-        "fp16": torch.float16,
-        "bf16": torch.bfloat16,
-        "tf32": torch.float32,
-    }[args.precision]
+
     TransformerBlock = type(model.llma.layers[0])
 
     model = FSDP(
accessory/model/LLM/llama.py

Lines changed: 7 additions & 7 deletions
@@ -20,6 +20,7 @@
 from apex.normalization import FusedRMSNorm as RMSNorm
 import open_clip
 
+from util.tensor_type import default_tensor_type
 import configs.global_configs
 if configs.global_configs.USE_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
@@ -308,9 +309,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.cache_image_words = 0 # for inference
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -334,9 +334,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
@@ -366,8 +364,10 @@ def clip_encode_image(self, x):
 
 
     def encode_image(self, image):
-        # return self.patch_embed(image)
-        image_tokens = self.clip_encode_image(image)
+        with torch.cuda.amp.autocast(enabled=False):
+            image = image.half()
+            image_tokens = self.clip_encode_image(image)
+        image = image.to(self.clip_proj.weight.dtype)
         image_tokens = self.clip_proj_norm(self.clip_proj(image_tokens))
         return image_tokens
 
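The encode_image change above pins CLIP at fp16 by disabling any ambient autocast around it, then casting back for the surrounding computation. A self-contained toy version of the same pattern (assumes a CUDA device; clip_like and proj are stand-ins, not the repo's modules):

import torch

clip_like = torch.nn.Linear(8, 8).cuda().half()  # stand-in for the frozen fp16 CLIP
proj = torch.nn.Linear(8, 8).cuda()              # stand-in for the trainable projection

x = torch.randn(2, 8, device="cuda")
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    # Run the frozen module outside autocast so its fp16 weights are used
    # as-is instead of being recast (possibly lossily) to bf16.
    with torch.cuda.amp.autocast(enabled=False):
        tokens = clip_like(x.half())
    # Cast the result to the projection's dtype before rejoining the
    # autocast region.
    tokens = tokens.to(proj.weight.dtype)
    out = proj(tokens)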
accessory/model/LLM/llama_adapter.py

Lines changed: 3 additions & 5 deletions
@@ -24,6 +24,7 @@
 import configs.global_configs
 if configs.global_configs.USE_FLASH_ATTENTION:
     from flash_attn import flash_attn_func
+from util.tensor_type import default_tensor_type
 
 default_linear_init = functools.partial(nn.init.kaiming_uniform_, a=math.sqrt(5))
 
@@ -349,9 +350,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.image_words = 0
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -401,9 +401,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/model/LLM/llama_peft.py

Lines changed: 3 additions & 5 deletions
@@ -17,6 +17,7 @@
     ColumnParallelLinear
 )
 from ..peft import LoraColumnParallelLinear, LoraRowParallelLinear
+from util.tensor_type import default_tensor_type
 
 from apex.normalization import FusedRMSNorm as RMSNorm
 import open_clip
@@ -323,9 +324,8 @@ def __init__(self, params: ModelArgs, with_visual=False):
         self.cache_image_words = 0 # for inference
         if with_visual:
             print("build llama model with clip")
-            torch.set_default_tensor_type(torch.cuda.HalfTensor)
-            self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
-            torch.set_default_tensor_type(torch.FloatTensor)
+            with default_tensor_type(dtype=torch.half):
+                self.clip, _, _ = open_clip.create_model_and_transforms('ViT-L-14', pretrained='openai')
             for name, param in self.clip.named_parameters():
                 param.requires_grad = False
             in_dim = self.clip.visual.proj.shape[1]
@@ -351,9 +351,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
        for key, value in self.named_parameters():
            value.requires_grad = False
-            value.data = value.data.half()
        for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
            value.requires_grad = True
 
 
accessory/model/LLM/llama_qformerv2.py

Lines changed: 0 additions & 2 deletions
@@ -337,9 +337,7 @@ def get_trainable_params(self):
     def set_default_trainability(self):
         for key, value in self.named_parameters():
             value.requires_grad = False
-            value.data = value.data.half()
         for key, value in self.get_trainable_params().items():
-            value.data = value.data.float()
             value.requires_grad = True
 
 
accessory/model/meta.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def __init__(
         for name, param in self.named_parameters():
             is_model_parallel = getattr(param, "is_model_parallel", False)
             if param.requires_grad:
-                print(f"Trainable param: {name}, local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}")
                 if is_model_parallel:
                     param_count_all += param.numel() * fs_init.get_model_parallel_world_size()
                 else:

accessory/util/misc.py

Lines changed: 12 additions & 1 deletion
@@ -351,8 +351,13 @@ def _save_model():
         model_trainable_params = model.get_trainable_params()
         model_trainable_params = ['.'.join([_ for _ in key.split('.') if not _.startswith('_')])
                                   for key in model_trainable_params.keys()]
+        save_dtype = {
+            "fp16": torch.float16,
+            "bf16": torch.bfloat16,
+            "tf32": torch.float,
+        }[args.precision]
         consolidated_model_state_dict = {
-            "model": {key: val.half() for key, val in model.state_dict().items() if key in model_trainable_params},
+            "model": {key: val.to(save_dtype) for key, val in model.state_dict().items() if key in model_trainable_params},
         }
         save_path = os.path.join(
             save_dir,
@@ -608,3 +613,9 @@ def mark_mp_params(model: torch.nn.Module):
         if isinstance(m, ParallelEmbedding):
             m.weight.is_model_parallel = True
 
+
+def print_trainable_params(model: torch.nn.Module) -> None:
+    for name, param in model.named_parameters():
+        is_model_parallel = getattr(param, "is_model_parallel", False)
+        print(f"Trainable param: {name}, local_size: {param.shape}, model_parallel: {is_model_parallel}, dtype: {param.dtype}")
+