Skip to content

Commit 7851e0e

Browse files
author
John Canny
committed
created forward and backward methods for Net class
1 parent 96e1cf5 commit 7851e0e

3 files changed

Lines changed: 74 additions & 64 deletions

File tree

.classpath

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,34 @@
1-
<?xml version="1.0" encoding="UTF-8"?>
2-
<classpath>
3-
<classpathentry kind="src" path="src/main/scala">
4-
<attributes>
5-
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="C:/code/BIDMach/lib"/>
6-
</attributes>
7-
</classpathentry>
8-
<classpathentry kind="src" path="src/main/java">
9-
<attributes>
10-
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="c:/code/CUDA7/bin"/>
11-
</attributes>
12-
</classpathentry>
13-
<classpathentry kind="con" path="org.scala-ide.sdt.launching.SCALA_CONTAINER"/>
14-
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
15-
<classpathentry kind="lib" path="lib/lz4-1.3.jar"/>
16-
<classpathentry kind="lib" path="lib/jfreechart-1.0.19.jar"/>
17-
<classpathentry kind="lib" path="lib/jhdf5-3.2.1.jar"/>
18-
<classpathentry kind="lib" path="lib/jline-2.10.jar"/>
19-
<classpathentry kind="lib" path="lib/jocl-2.0.0.jar"/>
20-
<classpathentry kind="lib" path="lib/json-io-4.5.0.jar"/>
21-
<classpathentry kind="lib" path="lib/ptplot-1.0.jar"/>
22-
<classpathentry kind="lib" path="lib/ptplotapplication-1.0.jar"/>
23-
<classpathentry kind="lib" path="lib/scala-arm_2.11-1.4.jar"/>
24-
<classpathentry kind="lib" path="lib/scala-compiler-2.11.0-M8.jar"/>
25-
<classpathentry kind="lib" path="lib/jcublas-0.8.0.jar"/>
26-
<classpathentry kind="lib" path="lib/jcuda-0.8.0.jar"/>
27-
<classpathentry kind="lib" path="lib/jcudnn-0.8.0.jar"/>
28-
<classpathentry kind="lib" path="lib/jcufft-0.8.0.jar"/>
29-
<classpathentry kind="lib" path="lib/jcurand-0.8.0.jar"/>
30-
<classpathentry kind="lib" path="lib/jcusparse-0.8.0.jar"/>
31-
<classpathentry kind="lib" path="lib/protobuf-java-3.1.0.jar"/>
32-
<classpathentry kind="lib" path="lib/BIDMat-2.0.1-cuda8.0beta.jar"/>
33-
<classpathentry kind="output" path="bin"/>
34-
</classpath>
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<classpath>
3+
<classpathentry kind="src" path="src/main/scala">
4+
<attributes>
5+
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="C:/code/BIDMach/lib"/>
6+
</attributes>
7+
</classpathentry>
8+
<classpathentry kind="src" path="src/main/java">
9+
<attributes>
10+
<attribute name="org.eclipse.jdt.launching.CLASSPATH_ATTR_LIBRARY_PATH_ENTRY" value="c:/code/CUDA7/bin"/>
11+
</attributes>
12+
</classpathentry>
13+
<classpathentry kind="con" path="org.scala-ide.sdt.launching.SCALA_CONTAINER"/>
14+
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
15+
<classpathentry kind="lib" path="lib/lz4-1.3.jar"/>
16+
<classpathentry kind="lib" path="lib/jfreechart-1.0.19.jar"/>
17+
<classpathentry kind="lib" path="lib/jhdf5-3.2.1.jar"/>
18+
<classpathentry kind="lib" path="lib/jline-2.10.jar"/>
19+
<classpathentry kind="lib" path="lib/jocl-2.0.0.jar"/>
20+
<classpathentry kind="lib" path="lib/json-io-4.5.0.jar"/>
21+
<classpathentry kind="lib" path="lib/ptplot-1.0.jar"/>
22+
<classpathentry kind="lib" path="lib/ptplotapplication-1.0.jar"/>
23+
<classpathentry kind="lib" path="lib/scala-arm_2.11-1.4.jar"/>
24+
<classpathentry kind="lib" path="lib/scala-compiler-2.11.0-M8.jar"/>
25+
<classpathentry kind="lib" path="lib/jcublas-0.8.0.jar"/>
26+
<classpathentry kind="lib" path="lib/jcuda-0.8.0.jar"/>
27+
<classpathentry kind="lib" path="lib/jcudnn-0.8.0.jar"/>
28+
<classpathentry kind="lib" path="lib/jcufft-0.8.0.jar"/>
29+
<classpathentry kind="lib" path="lib/jcurand-0.8.0.jar"/>
30+
<classpathentry kind="lib" path="lib/jcusparse-0.8.0.jar"/>
31+
<classpathentry kind="lib" path="lib/protobuf-java-3.1.0.jar"/>
32+
<classpathentry kind="lib" path="lib/BIDMat-2.0.2-cuda8.0beta.jar"/>
33+
<classpathentry kind="output" path="bin"/>
34+
</classpath>

scripts/networks/testCIFAR10a.ssc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ val (nn,opts) = Net.learner(trainfname,labelsfname);
1111
val convt = jcuda.jcudnn.cudnnConvolutionMode.CUDNN_CROSS_CORRELATION
1212

1313

14-
opts.batchSize= 64
14+
opts.batchSize= 32
1515
opts.npasses = 10
1616
opts.lrate = 1e-3f
1717
opts.lrate = 1e-4f

src/main/scala/BIDMach/networks/Net.scala

Lines changed: 39 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -132,41 +132,51 @@ class Net(override val opts:Net.Opts = new Net.Options) extends Model(opts) {
132132
}
133133
}
134134

135+
def forward:Int = {
136+
if (mask.asInstanceOf[AnyRef] != null) {
137+
modelmats(0) ~ modelmats(0) mask;
138+
}
139+
var i = 0;
140+
while (i < layers.length) {
141+
if (opts.debug > 0) {
142+
println("dobatch forward %d %s" format (i, layers(i).getClass))
143+
}
144+
layers(i).forward;
145+
i += 1;
146+
}
147+
i;
148+
}
149+
150+
def backward(nl:Int, ipass:Int, pos:Long) = {
151+
var i = nl;
152+
var j = 0;
153+
while (j < output_layers.length) {
154+
output_layers(j).deriv.set(1);
155+
j += 1;
156+
}
157+
if (opts.aopts == null) {
158+
for (j <- 0 until updatemats.length) updatemats(j).clear;
159+
}
160+
while (i > 1) {
161+
i -= 1;
162+
if (opts.debug > 0) {
163+
println("dobatch backward %d %s" format (i, layers(i).getClass))
164+
}
165+
layers(i).backward(ipass, pos);
166+
}
167+
if (mask.asInstanceOf[AnyRef] != null) {
168+
updatemats(0) ~ updatemats(0) mask;
169+
}
170+
}
171+
135172

136173
def dobatch(gmats:Array[Mat], ipass:Int, pos:Long):Unit = {
137174
if (batchSize < 0) batchSize = gmats(0).ncols;
138175
if (batchSize == gmats(0).ncols) { // discard odd-sized minibatches
139176
assignInputs(gmats, ipass, pos);
140177
assignTargets(gmats, ipass, pos);
141-
if (mask.asInstanceOf[AnyRef] != null) {
142-
modelmats(0) ~ modelmats(0) mask;
143-
}
144-
var i = 0;
145-
while (i < layers.length) {
146-
if (opts.debug > 0) {
147-
println("dobatch forward %d %s" format (i, layers(i).getClass))
148-
}
149-
layers(i).forward;
150-
i += 1;
151-
}
152-
var j = 0;
153-
while (j < output_layers.length) {
154-
output_layers(j).deriv.set(1);
155-
j += 1;
156-
}
157-
if (opts.aopts == null) {
158-
for (j <- 0 until updatemats.length) updatemats(j).clear;
159-
}
160-
while (i > 1) {
161-
i -= 1;
162-
if (opts.debug > 0) {
163-
println("dobatch backward %d %s" format (i, layers(i).getClass))
164-
}
165-
layers(i).backward(ipass, pos);
166-
}
167-
if (mask.asInstanceOf[AnyRef] != null) {
168-
updatemats(0) ~ updatemats(0) mask;
169-
}
178+
val nl = forward;
179+
backward(nl, ipass, pos);
170180
}
171181
}
172182

0 commit comments

Comments
 (0)