-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_convex_hull.py
More file actions
73 lines (56 loc) · 2.15 KB
/
predict_convex_hull.py
File metadata and controls
73 lines (56 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import math
import pickle
import random
from argparse import ArgumentParser
import haiku as hk
import jax
import jax.numpy as jnp
from data import generate_data_record, plot_points_and_hull
from run_training import _create_network
@hk.transform_with_state
def generate(seq, args):
    """Greedily decode a convex hull for ``seq`` with the pointer network.

    The points are encoded with the network's LSTM. The decoder then keeps
    running the same LSTM core, starting from the encoder's final state. At
    each step it scores every encoder slot with a scaled dot product and
    picks the highest-scoring one. An extra slot is appended after the last
    point (index ``len(seq)``). Choosing that slot ends decoding. At most
    ``len(seq) + 1`` steps are taken. Already-chosen indices are not masked
    out, so an index may be picked more than once.

    Side effects:
        - pickles the encoder attention values to ``/tmp/encoded_value.pk``;
        - pickles the stacked decoder queries to ``/tmp/decoder_query.pk``;
        - plots the points and predicted hull to ``imgs/prediction.png``.

    Args:
        seq: sequence of points. Each point is unpacked into a row after a
            leading ``0.`` flag. The ``[start]`` token below is 3 wide, so
            points are presumably 2-D (x, y) — TODO confirm against ``data``.
        args: parsed CLI namespace, forwarded to ``_create_network``.

    Returns:
        None; results are only written to disk.
    """
    def dump(obj, path):
        # Copy the array back to host memory before pickling it.
        with open(path, 'wb') as fh:
            pickle.dump(jax.device_get(obj), fh)

    num_points = len(seq)
    net = _create_network(args)
    start_state = net.lstm.initial_state(1)

    # Prefix every point with a 0. flag. The [start] token sets it to 1.
    points = jnp.asarray([(0., *point) for point in seq])[None]
    # Move to time-major layout (num_points, 1, features) for dynamic_unroll.
    points = jnp.swapaxes(points, 0, 1)

    enc_outputs, state = hk.dynamic_unroll(net.lstm.core, points, start_state)
    # Extra slot holding the initial hidden state. Pointing at it
    # (index == num_points) is the stop signal.
    enc_outputs = jnp.concatenate(
        [enc_outputs, start_state.hidden[None]], axis=0)

    keys = net.enc_att_fc(enc_outputs)
    dump(keys, '/tmp/encoded_value.pk')

    hull = []
    queries = []
    step_input = jnp.asarray([[1., 0., 0.]])  # [start] token
    for _ in range(num_points + 1):
        hidden, state = net.lstm.core(step_input, state)
        query = net.dec_att_fc(hidden)
        queries.append(query)
        # Scaled dot-product attention. An additive alternative would go
        # through net.energy_fc(tanh(keys + query)) instead.
        scores = keys * query[None]
        scores = jnp.sum(scores, axis=-1) / math.sqrt(scores.shape[-1])
        choice = jnp.argmax(scores, axis=0).item()
        if choice == num_points:
            break
        hull.append(choice)
        # Feed the chosen point (with its 0. flag) back in as the next input.
        step_input = points[choice]

    dump(jnp.concatenate(queries, axis=0), '/tmp/decoder_query.pk')
    plot_points_and_hull(seq, hull, 'imgs/prediction.png')
if __name__ == '__main__':
    # CLI entry point: load a trained checkpoint, sample a random point set,
    # save it to /tmp/points.pk, and run greedy hull decoding on it.
    parser = ArgumentParser()
    # Required: the checkpoint is opened unconditionally below. The previous
    # default of None only deferred the failure to a TypeError from open().
    parser.add_argument('-c', '--checkpoint-filepath', required=True,
                        type=str, help='path to the pickled checkpoint')
    parser.add_argument('-d', '--rnn-hidden-size', default=256, type=int,
                        help='RNN hidden size passed to the network builder')
    parser.add_argument('-s', '--random-seed', default=1111, type=int,
                        help='seed for sampling the input points')
    parser.add_argument('-l', '--num-vertex', default=20, type=int,
                        help='number of points to sample')
    args = parser.parse_args()
    print(args)

    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # checkpoints from a trusted source.
    with open(args.checkpoint_filepath, 'rb') as f:
        state = pickle.load(f)

    # Sample a point set with a seeded RNG. The ground-truth hull is not
    # needed for prediction, so discard it.
    rng = random.Random(args.random_seed)
    seq, _ = generate_data_record(rng, args.num_vertex)
    with open('/tmp/points.pk', 'wb') as f:
        pickle.dump(seq, f)

    # `state` presumably carries params / aux (haiku state) / rng from
    # training — TODO confirm against run_training.
    generate.apply(state.params, state.aux, state.rng, seq, args)