forked from google/ffn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_inf_stencil_for_ckpts.py
More file actions
89 lines (72 loc) · 3 KB
/
generate_inf_stencil_for_ckpts.py
File metadata and controls
89 lines (72 loc) · 3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""Helper script for generating InferenceRequests and storage folders.

You have several checkpoints, and maybe you want to make both a
PolicyPeaks and a PolicyInvertOrigins inference request for each one.
To use this script, write basic template inference requests (here, one
for peaks and one for inversion) and pass them in along with the
checkpoints. For each checkpoint, this script writes the checkpoint path
into the `model_checkpoint_path` field of each request, determines an
output folder for each inference (here, peaks and inversion), and
creates those folders.

Example:
    python generate_inf_stencil_for_ckpts.py peaks.pbtxt invert.pbtxt \
        --ckpts 100000 200000 --ckpt-dir train/ --output-dir inference/
"""
import argparse
import glob
import json
import os
from os.path import join

from google.protobuf import text_format

from ffn.inference import inference_pb2
# -- args
# All three options are required: every later step joins paths onto
# --ckpt-dir / --output-dir and iterates --ckpts, so a missing value would
# otherwise surface as an opaque TypeError from os.path.join(None, ...).
ap = argparse.ArgumentParser(
    description="Fill checkpoint paths and output folders into template "
    "PolicyPeaks / PolicyInvertOrigins InferenceRequests."
)
ap.add_argument(
    "infreqs",
    nargs="+",
    help="Template InferenceRequest text-proto files "
    "(PolicyPeaks and/or PolicyInvertOrigins).",
)
ap.add_argument(
    "--ckpts",
    nargs="+",
    required=True,
    help="Checkpoint step numbers, e.g. 100000 250000.",
)
ap.add_argument(
    "--ckpt-dir",
    required=True,
    help="Directory containing model.ckpt-<step>.meta files.",
)
ap.add_argument(
    "--output-dir",
    required=True,
    help="Directory in which per-checkpoint folders and .pbtxt files "
    "are written.",
)
args = ap.parse_args()
# -- load the inference requests
# Seed policies the write loop knows how to fill in. This script could be
# extended to other types, but these are the ones in use.
SUPPORTED_POLICIES = ("PolicyPeaks", "PolicyInvertOrigins")
infreqs = []
seen_policies = set()
for infreq_fn in args.infreqs:
    # Parse the text-format template proto.
    infreq = inference_pb2.InferenceRequest()
    with open(infreq_fn) as infreq_f:
        text_format.Parse(infreq_f.read(), infreq)
    # Validate explicitly instead of with `assert`, which `python -O` strips.
    if infreq.seed_policy not in SUPPORTED_POLICIES:
        ap.error(
            f"{infreq_fn}: seed_policy is {infreq.seed_policy!r}; "
            f"expected one of {', '.join(SUPPORTED_POLICIES)}"
        )
    # Each policy maps to a single output .pbtxt and folder per checkpoint,
    # so a second template with the same policy would silently overwrite
    # the first one's output.
    if infreq.seed_policy in seen_policies:
        ap.error(f"{infreq_fn}: duplicate template for {infreq.seed_policy}")
    seen_policies.add(infreq.seed_policy)
    infreqs.append(infreq)
# -- figure out checkpoint paths
# A checkpoint is referred to by the path prefix `model.ckpt-<step>`, i.e. the
# `.meta` file path with its suffix removed. Build that prefix directly from
# each requested step rather than substring-matching globbed paths: a
# substring test lets `100` match `model.ckpt-1000.meta`, or match every file
# when the step string happens to occur in the directory name.
if args.ckpt_dir is None or not args.ckpts:
    ap.error("--ckpt-dir and --ckpts are required")
# Drop duplicate steps (keeping first-seen order for deterministic ties) and
# sort numerically, so `ckpts` and `ckpt_paths` line up index-for-index.
try:
    ckpts = sorted(dict.fromkeys(args.ckpts), key=int)
except ValueError:
    ap.error(f"--ckpts must be integer step numbers, got {args.ckpts}")
ckpt_paths = [join(args.ckpt_dir, f"model.ckpt-{ckpt}") for ckpt in ckpts]
missing = [p + ".meta" for p in ckpt_paths if not os.path.isfile(p + ".meta")]
if missing:
    ap.error("checkpoint .meta file(s) not found: " + ", ".join(missing))
# -- write InferenceRequests and make folders
if args.output_dir is None:
    ap.error("--output-dir is required")
for ckpt, ckpt_path in zip(ckpts, ckpt_paths):
    # Per-checkpoint output locations. The A/B tags presumably order the
    # peaks run before the invert run, which reads the peaks output.
    peaks_dir = join(args.output_dir, f"{ckpt}_A_peaks")
    invert_dir = join(args.output_dir, f"{ckpt}_B_invert")
    peaks_infreq_fn = join(args.output_dir, f"{ckpt}_A_peaks.pbtxt")
    invert_infreq_fn = join(args.output_dir, f"{ckpt}_B_invert.pbtxt")
    # Both folders are created even if only one template was given, since
    # the invert request always points at peaks_dir.
    os.makedirs(peaks_dir, exist_ok=True)
    os.makedirs(invert_dir, exist_ok=True)
    for infreq in infreqs:
        # The template protos are reused across checkpoints; every field set
        # here is overwritten on each iteration.
        infreq.model_checkpoint_path = ckpt_path
        if infreq.seed_policy == "PolicyPeaks":
            infreq.segmentation_output_dir = peaks_dir
            out_fn = peaks_infreq_fn
        elif infreq.seed_policy == "PolicyInvertOrigins":
            infreq.segmentation_output_dir = invert_dir
            # Point the invert request at this checkpoint's peaks output.
            # json.dumps escapes quotes/backslashes in the path; plain string
            # formatting produced invalid JSON for such paths.
            infreq.seed_policy_args = json.dumps({"segmentation_dir": peaks_dir})
            out_fn = invert_infreq_fn
        else:
            # Unreachable after load-time validation; fail loudly (unlike
            # `assert`, this is not stripped by `python -O`).
            raise ValueError(f"unsupported seed_policy {infreq.seed_policy!r}")
        with open(out_fn, "w") as out:
            out.write(text_format.MessageToString(infreq))