-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathconvert_pth.py
More file actions
36 lines (30 loc) · 979 Bytes
/
convert_pth.py
File metadata and controls
36 lines (30 loc) · 979 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import pickle as pkl
import sys
import torch
def rename_key(name):
    """Return the converted name for one key of the source state dict.

    The rewrite rules are applied in this order:

    * a key that does not contain ``"layer"`` is prefixed with ``"stem."``;
    * ``layer1`` .. ``layer4`` become ``res2`` .. ``res5``;
    * ``bn1`` .. ``bn3`` become ``conv1.norm`` .. ``conv3.norm``;
    * ``downsample.0`` becomes ``shortcut`` and ``downsample.1`` becomes
      ``shortcut.norm``.

    Args:
        name: Key of the source state dict.

    Returns:
        The renamed key.
    """
    if "layer" not in name:
        name = "stem." + name
    for t in [1, 2, 3, 4]:
        name = name.replace("layer{}".format(t), "res{}".format(t + 1))
    for t in [1, 2, 3]:
        name = name.replace("bn{}".format(t), "conv{}.norm".format(t))
    name = name.replace("downsample.0", "shortcut")
    name = name.replace("downsample.1", "shortcut.norm")
    return name


def main(argv=None):
    """Convert a ``.pth`` checkpoint into a pickled dict of numpy arrays.

    Reads the ``"state_dict"`` entry of the checkpoint, renames every key
    with :func:`rename_key`, converts every value to a numpy array, and
    pickles ``{"model": ..., "__author__": ..., "matching_heuristics": ...}``
    to the output path.
    NOTE(review): the output layout looks like what a Detectron2-style
    checkpoint loader expects -- confirm against the consumer.

    Args:
        argv: ``[input_path, output_path]``; defaults to ``sys.argv[1:]``.
            Extra trailing arguments are ignored.

    Raises:
        SystemExit: if fewer than two paths are given or the output path
            does not end with ``.pkl``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        sys.exit("usage: {} INPUT.pth OUTPUT.pkl".format(sys.argv[0]))
    src_path, dst_path = args[0], args[1]

    # Validate before doing any work, and without `assert`, which is
    # stripped when Python runs with -O.
    if not dst_path.endswith(".pkl"):
        sys.exit("output path must end with '.pkl', got: {}".format(dst_path))

    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(src_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    newmodel = {}
    for old_k, v in state_dict.items():
        k = rename_key(old_k)
        print(old_k, "->", k)
        # detach() so tensors that still require grad can be converted.
        newmodel[k] = v.detach().numpy()

    res = {
        "model": newmodel,
        "__author__": "OpenSelfSup",
        "matching_heuristics": True,
    }
    with open(dst_path, "wb") as f:
        pkl.dump(res, f)


if __name__ == "__main__":
    main()