99# To run this, download the four files from http://yann.lecun.com/exdb/mnist/
1010# and gunzip them into a single path. Pass this path to the program with the
1111# --path option. You will also want to run with --dynet_autobatch 1.
12+ # To turn on GPU training, run with --dynet-gpus 1.
1213
1314parser = argparse .ArgumentParser ()
1415parser .add_argument ("--path" , default = "." ,
1819parser .add_argument ("--conv" , dest = "conv" , action = "store_true" )
1920parser .add_argument ("--dynet_autobatch" , default = 0 ,
2021 help = "Set to 1 to turn on autobatching." )
22+ parser .add_argument ("--dynet-gpus" , default = 0 ,
23+ help = "Set to 1 to train on GPU." )
2124
2225HIDDEN_DIM = 1024
2326DROPOUT_RATE = 0.4
@@ -118,13 +121,18 @@ def __call__(self, x, dropout=False):
118121 loss = dy .pickneglogsoftmax (logits , lbl )
119122 losses .append (loss )
120123 mbloss = dy .esum (losses ) / mbsize
124+ mbloss .backward ()
125+ sgd .update ()
126+
127+ # eloss is an exponentially smoothed loss.
121128 if eloss is None :
122129 eloss = mbloss .scalar_value ()
123130 else :
124131 eloss = mbloss .scalar_value () * alpha + eloss * (1.0 - alpha )
125- mbloss . backward ()
126- sgd . update ()
132+
133+ # Do dev evaluation here:
127134 if (i > 0 ) and (i % dev_report == 0 ):
135+ confusion = [[0 for _ in xrange (10 )] for _ in xrange (10 )]
128136 correct = 0
129137 dev_start = time .time ()
130138 for s in range (0 , len (testing ), args .minibatch_size ):
@@ -137,21 +145,26 @@ def __call__(self, x, dropout=False):
137145 x = dy .inputVector (img )
138146 logits = classify (x )
139147 scores .append ((lbl , logits ))
140- # we want to evaluate all the logits in a batch, this is a hack
141- # to do this.
142- dummy = dy .esum ([logits for _ , logits in scores ])
143- dummy .forward ()
148+
149+ # This evaluates all the logits in a batch if autobatching is on.
150+ dy .forward ([logits for _ , logits in scores ])
144151
145152 # now we can retrieve the batch-computed logits cheaply
146153 for lbl , logits in scores :
147154 prediction = np .argmax (logits .npvalue ())
148155 if lbl == prediction :
149156 correct += 1
157+ confusion [prediction ][lbl ] += 1
150158 dev_end = time .time ()
151159 acc = float (correct ) / len (testing )
152160 dev_time += dev_end - dev_start
153161 print ("Held out accuracy {} ({} instances/sec)" .format (
154162 acc , len (testing ) / (dev_end - dev_start )))
163+ print ' ' + '' .join (('T' + str (x )).ljust (6 ) for x in xrange (10 ))
164+ for p , row in enumerate (confusion ):
165+ s = 'P' + str (p ) + ' '
166+ s += '' .join (str (col ).ljust (6 ) for col in row )
167+ print (s )
155168
156169 if (i > 0 ) and (i % report == 0 ):
157170 print ("moving avg loss: {}" .format (eloss ))
0 commit comments