Skip to content

Commit a183b98

Browse files
committed
First draft of ElasticAverageCollideBinder
1 parent 5f84579 commit a183b98

1 file changed

Lines changed: 273 additions & 0 deletions

File tree

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
package BIDMach.allreduce.binder
2+
3+
import java.util.ArrayDeque
4+
import java.util.concurrent.atomic.AtomicInteger
5+
import java.util.logging.Logger
6+
import scala.util.Random
7+
8+
import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource}
9+
//import BIDMach.models.Model
10+
import BIDMach.updaters.Grad
11+
import BIDMat.{Mat, FMat, GMat}
12+
13+
14+
/**
  * Linearize input model mats, and elastic-average update to the same model.
  * Perform momentum exchange among several nodes in a cluster, preserving total energy of the nodes.
  *
  * All nodes must construct this binder with the same `rndseed` so that the
  * random collision directions agree across the cluster.
  *
  * @param updater       Grad updater holding the model and its momentum mats.
  * @param alphaFromIter elastic-average mixing weight as a function of the allreduce iteration.
  * @param hardness      collision hardness; 0 = no momentum exchange, 1 = fully elastic collision.
  * @param rndseed       shared seed for the random collision directions.
  * @param inode         index of this node in [0, nnodes).
  * @param nnodes        number of nodes participating in the allreduce.
  * @param logger        destination for periodic throughput logging.
  */
// FIXME: should get rndseed, node num and # nodes from worker
class ElasticAverageCollideBinder(updater: Grad, alphaFromIter: Int => Float, hardness: Float, rndseed: Long, inode: Int,
                                  nnodes: Int, logger: Logger) extends AllreduceBinder {

  val model = updater.model

  // Keeping track of elastic updates
  var tic = System.currentTimeMillis()
  val reduceCount = new AtomicInteger()

  val random = new Random(rndseed)

  // Per node: raw random direction mats, their circular differences
  // raw(i) - raw((i+1) % nnodes), and the squared norms of those differences.
  // TODO: make these GMats when applicable
  val rawRandVecs = new Array[Array[FMat]](nnodes)
  val randVecs = new Array[Array[FMat]](nnodes)
  val randVecSqNorms = new Array[Array[Float]](nnodes)
  // Rotation offset into the circular buffers above.
  // NOTE(review): mutated by rotateRndVecs without synchronization — confirm
  // dataSource/dataSink are never invoked concurrently with each other.
  var rvOffset = 0
  // Scratch 1x1 mat used to inject scalars into mat expressions.
  // TODO: think about GMats too
  val aelem = FMat(1, 1)

  /** Fill `out` with i.i.d. standard Gaussian samples. */
  // TODO: make this more efficient by making use of functionality in SciFunctions etc.
  def genRandomVector(out: FMat) = {
    var i = 0
    val len = out.length
    while (i < len) {
      out.data(i) = random.nextGaussian().toFloat
      i += 1 // BUGFIX: increment was missing, so this loop never terminated
    }
  }

  /** Dot product of the flattened contents of two mats. */
  def dotprod(a: Mat, b: Mat): Float = {
    aelem ~ a.contents dot b.contents
    aelem.dv.toFloat;
  }

  /**
    * Lazily allocate the random vectors on first use: one raw Gaussian mat per
    * (node, momentum mat), plus the circular differences and their squared norms.
    */
  // TODO: is synchronization necessary to get updater momentum lengths
  def initRandVecs = {
    if (rawRandVecs(0) eq null) {
      // First pass: fill every node's raw vectors.
      // BUGFIX: the inner loop index previously shadowed the outer `i` and
      // always wrote rawRandVecs(0), leaving nodes 1..nnodes-1 null.
      for (i <- 0 until nnodes) {
        rawRandVecs(i) = new Array(updater.momentum.length)
        for ((pm, j) <- updater.momentum.iterator.zipWithIndex) {
          val fmat = FMat.make(pm.dims)
          genRandomVector(fmat.contents())
          pm match {
            case _: GMat => rawRandVecs(i)(j) = GMat(fmat)
            case _: FMat => rawRandVecs(i)(j) = fmat
          }
        }
      }
      // Second pass, once ALL raw vectors exist (the old code computed the
      // differences inside the first pass, dereferencing rawRandVecs(i + 1)
      // before it was initialized): circular differences and their norms.
      for (i <- 0 until nnodes) {
        randVecs(i) = new Array(updater.momentum.length)
        randVecSqNorms(i) = new Array(updater.momentum.length)
        for (j <- 0 until updater.momentum.length) {
          randVecs(i)(j) = rawRandVecs(i)(j) - rawRandVecs((i + 1) % nnodes)(j)
          randVecSqNorms(i)(j) = dotprod(randVecs(i)(j), randVecs(i)(j))
        }
      }
    }
  }

  /**
    * Resample the raw random vector at the current rotation offset, recompute
    * the two circular differences that depend on it (at prevOffset and
    * rvOffset) along with their squared norms, then advance the offset.
    */
  def rotateRndVecs = {
    val prevOffset = (rvOffset + nnodes - 1) % nnodes

    for (randMat <- rawRandVecs(rvOffset)) {
      randMat match {
        case gmat: GMat =>
          val fmat = FMat.make(randMat.dims)
          genRandomVector(fmat)
          gmat <-- fmat
        case fmat: FMat => genRandomVector(fmat)
      }
    }

    for (offset <- Array(prevOffset, rvOffset)) {
      val nextOffset = (offset + 1) % nnodes
      for (j <- 0 until randVecs(offset).length) {
        // BUGFIX: recompute from the raw vectors. The old code did
        // v1 ~ v1 - v2 on the (stale) difference vectors, which does not
        // equal raw(offset) - raw(nextOffset).
        randVecs(offset)(j) ~ rawRandVecs(offset)(j) - rawRandVecs(nextOffset)(j)
        randVecSqNorms(offset)(j) = dotprod(randVecs(offset)(j), randVecs(offset)(j))
      }
    }

    rvOffset += 1
    if (rvOffset == nnodes) rvOffset = 0
  }

  /**
    * Total number of floats exchanged per allreduce: every momentum mat, one
    * squared norm plus one momentum-random dot product per momentum mat, and
    * every model mat.
    */
  override lazy val totalDataSize: Int = {
    var ret = 0
    updater.momentum.synchronized {
      // Momentum mats
      for (p <- updater.momentum) ret += p.length
      // Squared magnitudes of momentum mats
      ret += updater.momentum.length
      // Dot product of momentum mats and random mats
      ret += updater.momentum.length
    }
    // Model mats
    model.modelmats.synchronized {
      for (mat <- model.modelmats) ret += mat.length
    }
    ret
  }

  /**
    * Pack the outgoing buffer, back to front: model mats, then per-momentum
    * dot products with this node's random vector, then per-momentum squared
    * norms, then the momentum mats themselves. dataSink unpacks in the same
    * order, so the two must stay in lockstep.
    */
  override def dataSource: DataSource = inputRequest => {
    initRandVecs

    val ret: Array[Float] = new Array[Float](totalDataSize)
    var current = totalDataSize
    val myRandVecs = randVecs((rvOffset + inode) % nnodes)

    // TODO: do we need to lock on the model and updater mats

    // backward traversing model mats, assuming forward traversal by the training model
    for (mm <- model.modelmats.reverseIterator) {
      current -= mm.length
      mm match {
        case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageCollideBinder dataSource")
        case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
      }
    }

    // dot product of momentum and random vectors
    // backward traversing update mats, assuming forward traversal by updater
    for ((pm, r) <- updater.momentum.reverseIterator zip myRandVecs.reverseIterator) {
      current -= 1
      ret(current) = dotprod(pm, r)
    }

    // squared norm of momentums
    for (pm <- updater.momentum.reverseIterator) {
      current -= 1
      ret(current) = dotprod(pm, pm)
    }

    // backward traversing update mats, assuming forward traversal by updater
    for (pm <- updater.momentum.reverseIterator) {
      current -= pm.length
      pm match {
        case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageCollideBinder dataSource")
        case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
      }
    }

    assert(current == 0, "current should be zero after iteration")

    AllReduceInput(ret)
  }

  /**
    * Unpack the reduced buffer (mirroring dataSource's layout), elastic-average
    * the model mats toward the cluster mean, then apply the energy-preserving
    * momentum collision along the shared random directions.
    */
  override def dataSink: DataSink = reducedOutput => {

    reduceCount.synchronized {
      val currentCount: Int = reduceCount.getAndIncrement()
      val updateCounts = 10
      if (currentCount % updateCounts == 0) {
        val toc = System.currentTimeMillis()
        if (currentCount > 0) {
          logger.info(f"elastic_updates/s=${updateCounts/((toc - tic) / 1.0e3)}%2.2f, total_updates=$currentCount")
        }
        tic = toc
      }
    }
    val reducedData = reducedOutput.data

    assert(reducedData.length == totalDataSize, "Reduced output should be same length as input")

    // backward traversing model mats, assuming forward traversal by the training model
    var current = totalDataSize
    val alpha = alphaFromIter(reducedOutput.iteration)

    for (mm <- model.modelmats.reverseIterator) {
      current -= mm.length
      mm.synchronized {
        // mm := (1 - alpha) * mm + alpha * (reducedSum / nnodes)
        mm match {
          case gmat: GMat =>
            val gReduced = GMat.make(gmat.dims)
            GMat.CPUtoGPUarraycopy(reducedData, current, gReduced.pdata, 0, gmat.length, "ElasticAverageCollideBinder dataSink")
            gReduced ~ gReduced / aelem.set(nnodes)
            gmat ~ gmat * aelem.set(1 - alpha)
            gReduced ~ gReduced * aelem.set(alpha)
            gmat ~ gReduced + gmat
            gReduced.free()
          case fmat: FMat =>
            val fReduced = FMat.make(fmat.dims)
            System.arraycopy(reducedData, current, fReduced.contents().data, 0, fmat.length)
            fReduced ~ fReduced / aelem.set(nnodes)
            fmat ~ fmat * aelem.set(1 - alpha)
            fReduced ~ fReduced * aelem.set(alpha)
            fmat ~ fReduced + fmat
        }
      }
    }

    // BUGFIX: the sections below previously indexed updater.modelmats, but the
    // buffer was packed from updater.momentum (see dataSource/totalDataSize);
    // counts and dims must come from the momentum mats or `current` drifts.

    // Per-mat sums over nodes of dot(momentum, randVec).
    val sumPmR = new Array[Float](updater.momentum.length)
    current -= updater.momentum.length
    System.arraycopy(reducedData, current, sumPmR, 0, updater.momentum.length)

    // Per-mat sums over nodes of dot(momentum, momentum) — the total energy.
    val sumPmPm = new Array[Float](updater.momentum.length)
    current -= updater.momentum.length
    System.arraycopy(reducedData, current, sumPmPm, 0, updater.momentum.length)

    // Mean momentum across nodes, one mat per momentum mat.
    val meanP = new Array[Mat](updater.momentum.length)
    for (i <- updater.momentum.length - 1 to 0 by -1) {
      current -= updater.momentum(i).length
      val pbar = updater.momentum(i) match {
        case _: GMat =>
          val pbar = GMat.make(updater.momentum(i).dims)
          GMat.CPUtoGPUarraycopy(reducedData, current, pbar.pdata, 0, updater.momentum(i).length, "ElasticAverageCollideBinder dataSink")
          pbar
        case _: FMat =>
          val pbar = FMat.make(updater.momentum(i).dims)
          System.arraycopy(reducedData, current, pbar.contents().data, 0, updater.momentum(i).length)
          pbar
      }
      pbar ~ pbar / aelem.set(nnodes)
      meanP(i) = pbar
    }

    assert(current == 0, "current should be zero after iteration")

    for (j <- updater.momentum.length - 1 to 0 by -1) {
      // TODO: not hold the lock for 1293579813753 years, but also avoid data races
      updater.momentum(j) synchronized {
        // x = momentum pulled toward the cluster mean by `hardness`.
        val x = meanP(j) - updater.momentum(j)
        x ~ x * aelem.set(hardness)
        x ~ x + updater.momentum(j)

        val sumC = randVecs(0)(j).zerosLike
        for (i <- 0 until nnodes) sumC ~ sumC + randVecs(i)(j)
        val sumXR = (1 - hardness) * sumPmR(j) + hardness * dotprod(meanP(j), sumC)
        val sumXX = (1 - hardness * hardness) * sumPmPm(j) - nnodes * (hardness - 1) * (hardness - 1) * dotprod(meanP(j), meanP(j))

        // Solve sumRR * beta^2 + 2 * sumXR * beta + (sumXX - sumPmPm) = 0 so
        // total energy is preserved; pick one of the two roots at random.
        val twoSumXR = 2 * sumXR
        val sumRR = randVecSqNorms.map(_(j)).reduce(_ + _)
        // Discriminant should always be positive for any hardness in [0, 1] (actually, [0, 2]);
        // clamp at zero so floating-point round-off cannot feed sqrt a negative (NaN).
        val discr = math.max(twoSumXR * twoSumXR - 4 * sumRR * (sumXX - sumPmPm(j)), 0f)
        val epsilon = 1e-36f
        val beta = if (Mat.myrand.nextFloat() < 0.5f) {
          (-twoSumXR + math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
        } else {
          (-twoSumXR - math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
        }

        updater.momentum(j) ~ x - aelem.set(beta) * randVecs((rvOffset + inode) % nnodes)(j)

        // Release per-iteration temporaries (matters for GMats, which are not GC-managed).
        x match { case g: GMat => g.free(); case _ => }
        sumC match { case g: GMat => g.free(); case _ => }
      }
      meanP(j) match { case g: GMat => g.free(); case _ => }
    }

    rotateRndVecs
  }

}
273+

0 commit comments

Comments
 (0)