Skip to content

Commit fc5ffc2

Browse files
committed
better autobatching style
1 parent c57f9da commit fc5ffc2

1 file changed

Lines changed: 19 additions & 6 deletions

File tree

examples/mnist/mnist-autobatch.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# To run this, download the four files from http://yann.lecun.com/exdb/mnist/
1010
# and gunzip them into a single path. Pass this path to the program with the
1111
# --path option. You will also want to run with --dynet_autobatch 1.
12+
# To turn on GPU training, run with --dynet-gpus 1.
1213

1314
parser = argparse.ArgumentParser()
1415
parser.add_argument("--path", default=".",
@@ -18,6 +19,8 @@
1819
parser.add_argument("--conv", dest="conv", action="store_true")
1920
parser.add_argument("--dynet_autobatch", default=0,
2021
help="Set to 1 to turn on autobatching.")
22+
parser.add_argument("--dynet-gpus", default=0,
23+
help="Set to 1 to train on GPU.")
2124

2225
HIDDEN_DIM = 1024
2326
DROPOUT_RATE = 0.4
@@ -118,13 +121,18 @@ def __call__(self, x, dropout=False):
118121
loss = dy.pickneglogsoftmax(logits, lbl)
119122
losses.append(loss)
120123
mbloss = dy.esum(losses) / mbsize
124+
mbloss.backward()
125+
sgd.update()
126+
127+
# eloss is an exponentially smoothed loss.
121128
if eloss is None:
122129
eloss = mbloss.scalar_value()
123130
else:
124131
eloss = mbloss.scalar_value() * alpha + eloss * (1.0 - alpha)
125-
mbloss.backward()
126-
sgd.update()
132+
133+
# Do dev evaluation here:
127134
if (i > 0) and (i % dev_report == 0):
135+
confusion = [[0 for _ in xrange(10)] for _ in xrange(10)]
128136
correct = 0
129137
dev_start = time.time()
130138
for s in range(0, len(testing), args.minibatch_size):
@@ -137,21 +145,26 @@ def __call__(self, x, dropout=False):
137145
x = dy.inputVector(img)
138146
logits = classify(x)
139147
scores.append((lbl, logits))
140-
# we want to evaluate all the logits in a batch, this is a hack
141-
# to do this.
142-
dummy = dy.esum([logits for _, logits in scores])
143-
dummy.forward()
148+
149+
# This evaluates all the logits in a batch if autobatching is on.
150+
dy.forward([logits for _, logits in scores])
144151

145152
# now we can retrieve the batch-computed logits cheaply
146153
for lbl, logits in scores:
147154
prediction = np.argmax(logits.npvalue())
148155
if lbl == prediction:
149156
correct += 1
157+
confusion[prediction][lbl] += 1
150158
dev_end = time.time()
151159
acc = float(correct) / len(testing)
152160
dev_time += dev_end - dev_start
153161
print("Held out accuracy {} ({} instances/sec)".format(
154162
acc, len(testing) / (dev_end - dev_start)))
163+
print ' ' + ''.join(('T'+str(x)).ljust(6) for x in xrange(10))
164+
for p, row in enumerate(confusion):
165+
s = 'P' + str(p) + ' '
166+
s += ''.join(str(col).ljust(6) for col in row)
167+
print(s)
155168

156169
if (i > 0) and (i % report == 0):
157170
print("moving avg loss: {}".format(eloss))

0 commit comments

Comments (0)