-
Notifications
You must be signed in to change notification settings - Fork 36
Expand file tree
/
Copy pathbin_convert_pt.py
More file actions
54 lines (48 loc) · 2.16 KB
/
bin_convert_pt.py
File metadata and controls
54 lines (48 loc) · 2.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import glob
import os
import torch
# Merge the sharded ``*.bin`` checkpoints found in one directory into a single
# ``model.pt`` and split the merged weights into ``transformer.pt`` and
# ``portrait_encoder.pt`` (prefixes stripped), all written to that directory.


def resolve_model_dir(argv=None):
    """Return the checkpoint directory given on the command line.

    The directory may be passed positionally or via
    ``--pretrained_model_path``; the positional form wins when both are given.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The directory path exactly as supplied by the user.

    Raises:
        SystemExit: Raised by ``parser.error`` when no directory is supplied.
    """
    parser = argparse.ArgumentParser(
        description="Merge *.bin checkpoint shards and split them by sub-model."
    )
    parser.add_argument(
        "pretrained_model_path",
        nargs="?",
        default=None,
        type=str,
        help="directory containing the *.bin checkpoint shards",
    )
    # A separate dest keeps the optional form from clashing with the
    # positional's default handling inside argparse.
    parser.add_argument(
        "--pretrained_model_path",
        dest="pretrained_model_path_option",
        default=None,
        type=str,
        help="same as the positional argument",
    )
    args = parser.parse_args(argv)
    model_dir = args.pretrained_model_path
    if model_dir is None:
        model_dir = args.pretrained_model_path_option
    if model_dir is None:
        parser.error("pretrained_model_path is required")
    return model_dir


def merge_state_dicts(shards):
    """Combine several checkpoint shards into one dictionary.

    Args:
        shards: Iterable of ``(source, mapping)`` pairs, where ``source`` is a
            label used in error messages (e.g. the shard's file name) and
            ``mapping`` maps parameter names to values.

    Returns:
        A dict holding every key of every shard, in encounter order.

    Raises:
        ValueError: If the same key appears in more than one shard. Duplicates
            are treated as an error instead of being silently combined.
    """
    merged = {}
    for source, shard in shards:
        for key, value in shard.items():
            if key in merged:
                raise ValueError(
                    f"duplicate key {key!r} found while merging {source!r}"
                )
            merged[key] = value
    return merged


def split_by_prefix(state_dict, prefix):
    """Select the entries under ``prefix`` and strip ``prefix + '.'`` from keys.

    Only keys that start with the prefix followed by a dot are selected, so a
    key such as ``"transformer3d_extra.w"`` is not picked up by the prefix
    ``"transformer3d"``.

    Args:
        state_dict: Mapping of parameter names to values.
        prefix: Sub-model name without the trailing dot.

    Returns:
        A new dict with the matching entries and shortened keys.
    """
    dotted = prefix + "."
    return {
        key[len(dotted):]: value
        for key, value in state_dict.items()
        if key.startswith(dotted)
    }


def main(argv=None):
    """Run the conversion for the directory named on the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If the directory contains no ``*.bin`` files.
        ValueError: If two shards define the same key.
    """
    pretrained_model_path = resolve_model_dir(argv)
    total_pretrained_model_path = os.path.join(pretrained_model_path, "model.pt")
    saved_transformer3d_path = os.path.join(pretrained_model_path, "transformer.pt")
    saved_portrait_encoder_path = os.path.join(pretrained_model_path, "portrait_encoder.pt")

    # glob returns files in arbitrary order; sort for a reproducible merge.
    checkpoint_files = sorted(glob.glob(os.path.join(pretrained_model_path, "*.bin")))
    print(checkpoint_files)
    if not checkpoint_files:
        raise FileNotFoundError(
            f"no *.bin checkpoint shards found in {pretrained_model_path!r}"
        )

    # NOTE: torch.load unpickles the file contents; only run this script on
    # checkpoints from a trusted source.
    state_dict = merge_state_dicts(
        (checkpoint_file, torch.load(checkpoint_file, map_location='cpu'))
        for checkpoint_file in checkpoint_files
    )
    torch.save(state_dict, total_pretrained_model_path)

    # The merged dict is reused directly rather than reloading model.pt, which
    # would hold two full copies of the weights in memory at once.
    for key in state_dict:
        print(key)
    print("--------------------------------")

    transformer3d = split_by_prefix(state_dict, "transformer3d")
    portrait_encoder = split_by_prefix(state_dict, "portrait_encoder")
    torch.save(transformer3d, saved_transformer3d_path)
    torch.save(portrait_encoder, saved_portrait_encoder_path)

    print("------------------------------------------")
    for key in transformer3d:
        print(key)
    print("------------------------------------------")
    for key in portrait_encoder:
        print(key)


if __name__ == "__main__":
    main()
# Usage: python bin_convert_pt.py "/path/FlashPortrait/output_14B_dir/checkpoint-x-fp32-infer"