|
| 1 | +package BIDMach.allreduce.binder |
| 2 | + |
| 3 | +import java.util.ArrayDeque |
| 4 | +import java.util.concurrent.atomic.AtomicInteger |
| 5 | +import java.util.logging.Logger |
| 6 | +import scala.util.Random |
| 7 | + |
| 8 | +import BIDMach.allreduce.binder.AllreduceBinder.{DataSink, DataSource} |
| 9 | +//import BIDMach.models.Model |
| 10 | +import BIDMach.updaters.Grad |
| 11 | +import BIDMat.{Mat, FMat, GMat} |
| 12 | + |
| 13 | + |
/**
  * Linearize the updater's momentum and model mats, elastic-average update the model,
  * and perform a momentum "collision" exchange among several nodes in a cluster,
  * preserving the total energy of the nodes.
  *
  * @param updater       Grad updater whose model and momentum mats are exchanged
  * @param alphaFromIter maps the allreduce iteration number to the elastic-averaging rate alpha
  * @param hardness      collision hardness in [0, 1]: 0 leaves momenta untouched, 1 is a full collision
  * @param rndseed       seed for the shared random-vector stream (must be identical on every node)
  * @param inode         index of this node, in [0, nnodes)
  * @param nnodes        total number of participating nodes
  * @param logger        destination for throughput log messages
  */
| 21 | +// FIXME: should get rndseed, node num and # nodes from worker |
class ElasticAverageCollideBinder(updater: Grad, alphaFromIter: Int => Float, hardness: Float, rndseed: Long, inode: Int,
                                  nnodes: Int, logger: Logger) extends AllreduceBinder {

  val model = updater.model

  // Timestamp of the last throughput log line (updated in dataSink).
  var tic = System.currentTimeMillis()
  // Total number of completed reduce rounds, used only for throughput logging.
  val reduceCount = new AtomicInteger()

  // Shared random stream: every node must be constructed with the same rndseed so that
  // all nodes generate identical random vectors and the collision stays consistent.
  val random = new Random(rndseed)

  // rawRandVecs(i)(j) is the Gaussian random mat for (node i, momentum mat j).
  // randVecs(i)(j) = rawRandVecs(i)(j) - rawRandVecs((i + 1) % nnodes)(j); summed over i
  // these differences telescope to zero, so the collision preserves total momentum.
  // TODO: make these GMats when applicable
  val rawRandVecs = new Array[Array[FMat]](nnodes)
  val randVecs = new Array[Array[FMat]](nnodes)
  // randVecSqNorms(i)(j) caches dotprod(randVecs(i)(j), randVecs(i)(j)).
  val randVecSqNorms = new Array[Array[Float]](nnodes)
  // Rotation offset: this node currently owns randVecs((rvOffset + inode) % nnodes).
  var rvOffset = 0
  // 1x1 scratch mat for scalar broadcasts and dot-product results.
  // NOTE(review): aelem is shared mutable state; assumes dataSource/dataSink never run
  // concurrently with each other — confirm against the allreduce driver.
  // TODO: think about GMats too
  val aelem = FMat(1, 1)

  /** Fill `out` with i.i.d. standard-Gaussian samples drawn from the shared stream. */
  // TODO: make this more efficient by making use of functionality in SciFunctions etc.
  def genRandomVector(out: FMat) = {
    var i = 0
    val len = out.length
    while (i < len) {
      out.data(i) = random.nextGaussian().toFloat
      i += 1 // BUG FIX: increment was missing, causing an infinite loop on first call
    }
  }

  /** Dot product of the flattened contents of two mats, returned as a Float. */
  def dotprod(a: Mat, b: Mat): Float = {
    aelem ~ a.contents dot b.contents
    aelem.dv.toFloat;
  }

  /**
    * Lazily allocate the per-node random vectors on the first dataSource call.
    *
    * BUG FIX: the original single pass shadowed the node index (writing every node's raw
    * vectors into rawRandVecs(0)) and read the (i+1)-th node's raw vectors before they were
    * allocated, which NPE'd. Two passes fix both: first allocate all raw vectors, then
    * build the telescoping difference vectors.
    */
  // TODO: is synchronization necessary to get updater momentum lengths
  def initRandVecs = {
    if (rawRandVecs(0) eq null) {
      // Pass 1: generate raw Gaussian vectors for every node, mirroring each momentum
      // mat's dims and device placement (GMat vs FMat).
      for (node <- 0 until nnodes) {
        rawRandVecs(node) = new Array(updater.momentum.length)
        for ((pm, j) <- updater.momentum.iterator.zipWithIndex) {
          val fmat = FMat.make(pm.dims)
          genRandomVector(fmat.contents())
          pm match {
            case _: GMat => rawRandVecs(node)(j) = GMat(fmat)
            case _: FMat => rawRandVecs(node)(j) = fmat
          }
        }
      }
      // Pass 2: difference vectors reference the next node's raw vectors, so this must
      // only run after pass 1 completes for all nodes.
      for (node <- 0 until nnodes) {
        randVecs(node) = new Array(updater.momentum.length)
        randVecSqNorms(node) = new Array(updater.momentum.length)
        for (j <- 0 until updater.momentum.length) {
          randVecs(node)(j) = rawRandVecs(node)(j) - rawRandVecs((node + 1) % nnodes)(j)
          randVecSqNorms(node)(j) = dotprod(randVecs(node)(j), randVecs(node)(j))
        }
      }
    }
  }

  /**
    * Retire and regenerate the raw random vectors at rvOffset, refresh the two difference
    * vectors that depend on them (prevOffset and rvOffset), then advance rvOffset.
    * All nodes must call this in lockstep to keep their random state synchronized.
    */
  def rotateRndVecs = {
    val prevOffset = (rvOffset + nnodes - 1) % nnodes

    // Regenerate the recycled raw vectors in place.
    for (randMat <- rawRandVecs(rvOffset)) {
      randMat match {
        case gmat: GMat =>
          val fmat = FMat.make(randMat.dims)
          genRandomVector(fmat)
          gmat <-- fmat
        case fmat: FMat => genRandomVector(fmat)
      }
    }

    // BUG FIX: recompute the affected difference vectors from the *raw* vectors.
    // The original subtracted difference vectors from each other (v1 ~ v1 - v2), which
    // breaks the invariant randVecs(i) = raw(i) - raw(i + 1) after the first rotation.
    for (offset <- Array(prevOffset, rvOffset)) {
      val nextOffset = (offset + 1) % nnodes
      for (j <- 0 until rawRandVecs(offset).length) {
        randVecs(offset)(j) ~ rawRandVecs(offset)(j) - rawRandVecs(nextOffset)(j)
        randVecSqNorms(offset)(j) = dotprod(randVecs(offset)(j), randVecs(offset)(j))
      }
    }

    rvOffset += 1
    if (rvOffset == nnodes) rvOffset = 0
  }

  /**
    * Total number of floats exchanged per round: all momentum mats, one squared norm and
    * one momentum·randvec dot product per momentum mat, and all model mats.
    */
  override lazy val totalDataSize: Int = {
    var ret = 0
    updater.momentum.synchronized {
      // Momentum mats
      for (p <- updater.momentum) ret += p.length
      // Squared magnitudes of momentum mats
      ret += updater.momentum.length
      // Dot product of momentum mats and random mats
      ret += updater.momentum.length
    }
    // Model mats
    model.modelmats.synchronized {
      for (mat <- model.modelmats) ret += mat.length
    }
    ret
  }

  /**
    * Serialize [momentum mats | momentum sq-norms | momentum·randvec dot products | model mats]
    * into one flat float array. Layout must mirror dataSink's unpacking exactly.
    */
  override def dataSource: DataSource = inputRequest => {
    initRandVecs

    val ret: Array[Float] = new Array[Float](totalDataSize)
    var current = totalDataSize
    val myRandVecs = randVecs((rvOffset + inode) % nnodes)

    // TODO: do we need to lock on the model and updater mats

    // backward traversing model mats, assuming forward traversal by the training model
    for (mm <- model.modelmats.reverseIterator) {
      current -= mm.length
      mm match {
        // BUG FIX: error-message tags said "ElasticAverageBinder" (copy-paste from sibling class)
        case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageCollideBinder dataSource")
        case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
      }
    }

    // dot product of momentum and random vectors
    // backward traversing update mats, assuming forward traversal by updater
    for ((pm, r) <- updater.momentum.reverseIterator zip myRandVecs.reverseIterator) {
      current -= 1
      ret(current) = dotprod(pm, r)
    }

    // squared norm of momentums
    for (pm <- updater.momentum.reverseIterator) {
      current -= 1
      ret(current) = dotprod(pm, pm)
    }

    // backward traversing update mats, assuming forward traversal by updater
    for (pm <- updater.momentum.reverseIterator) {
      current -= pm.length
      pm match {
        case gmat: GMat => GMat.GPUtoCPUarraycopy(gmat.pdata, 0, ret, current, gmat.length, "ElasticAverageCollideBinder dataSource")
        case fmat: FMat => System.arraycopy(fmat.contents().data, 0, ret, current, fmat.length)
      }
    }

    assert(current == 0, "current should be zero after iteration")

    AllReduceInput(ret)

  }

  /**
    * Unpack the reduced (summed-over-nodes) buffer produced by dataSource on each node:
    * elastic-average the model mats toward the cluster mean, then perform the
    * energy-preserving momentum collision on the updater's momentum mats.
    */
  override def dataSink: DataSink = reducedOutput => {

    reduceCount.synchronized {
      val currentCount: Int = reduceCount.getAndIncrement()
      val updateCounts = 10
      if (currentCount % updateCounts == 0) {
        val toc = System.currentTimeMillis()
        if (currentCount > 0) {
          logger.info(f"elastic_updates/s=${updateCounts/((toc - tic) / 1.0e3)}%2.2f, total_updates=$currentCount")
        }
        tic = toc
      }
    }
    val reducedData = reducedOutput.data

    assert(reducedData.length == totalDataSize, "Reduced output should be same length as input")

    // backward traversing model mats, assuming forward traversal by the training model
    // using while instead of for loop due to performance
    var current = totalDataSize
    val alpha = alphaFromIter(reducedOutput.iteration)

    // Elastic average: mm <- (1 - alpha) * mm + alpha * (reduced sum / nnodes)
    for (mm <- model.modelmats.reverseIterator) {
      current -= mm.length
      mm.synchronized {
        mm match {
          case gmat: GMat =>
            val gReduced = GMat.make(gmat.dims)
            GMat.CPUtoGPUarraycopy(reducedData, current, gReduced.pdata, 0, gmat.length, "ElasticAverageCollideBinder dataSink")
            gReduced ~ gReduced / aelem.set(nnodes)
            gmat ~ gmat * aelem.set(1 - alpha)
            gReduced ~ gReduced * aelem.set(alpha)
            gmat ~ gReduced + gmat
            gReduced.free()
          case fmat: FMat =>
            val fReduced = FMat.make(fmat.dims)
            System.arraycopy(reducedData, current, fReduced.contents().data, 0, fmat.length)
            fReduced ~ fReduced / aelem.set(nnodes)
            fmat ~ fmat * aelem.set(1 - alpha)
            fReduced ~ fReduced * aelem.set(alpha)
            fmat ~ fReduced + fmat
        }
      }
    }

    // BUG FIX: the remainder of this method operated on updater.modelmats, but dataSource
    // and totalDataSize serialize updater.momentum — and the collision is a *momentum*
    // exchange, so all unpacking and updates below use updater.momentum.

    // Per-mat sums over nodes of (momentum dot randvec)
    val sumPmR = new Array[Float](updater.momentum.length)
    current -= updater.momentum.length
    System.arraycopy(reducedData, current, sumPmR, 0, updater.momentum.length)

    // Per-mat sums over nodes of squared momentum norms
    val sumPmPm = new Array[Float](updater.momentum.length)
    current -= updater.momentum.length
    System.arraycopy(reducedData, current, sumPmPm, 0, updater.momentum.length)

    // Mean momentum over nodes, one mat per momentum mat.
    // TODO: free the GMat entries of meanP after use to avoid GPU memory churn.
    val meanP = new Array[Mat](updater.momentum.length)
    for (i <- updater.momentum.length - 1 to 0 by -1) {
      current -= updater.momentum(i).length
      val pbar = updater.momentum(i) match {
        case _: GMat =>
          val pbar = GMat.make(updater.momentum(i).dims)
          GMat.CPUtoGPUarraycopy(reducedData, current, pbar.pdata, 0, updater.momentum(i).length, "ElasticAverageCollideBinder dataSink")
          pbar
        case _: FMat =>
          val pbar = FMat.make(updater.momentum(i).dims)
          System.arraycopy(reducedData, current, pbar.contents().data, 0, updater.momentum(i).length)
          pbar
      }
      pbar ~ pbar / aelem.set(nnodes)
      meanP(i) = pbar
    }

    assert(current == 0, "current should be zero after iteration")

    for (j <- updater.momentum.length - 1 to 0 by -1) {
      // TODO: not hold the lock for 1293579813753 years, but also avoid data races
      updater.momentum(j) synchronized {
        // x = (1 - hardness) * pm + hardness * meanP : elastic pull toward the mean momentum
        val x = meanP(j) - updater.momentum(j)
        x ~ x * aelem.set(hardness)
        x ~ x + updater.momentum(j)

        // sumC = sum over nodes of randVecs(i)(j); telescopes to zero analytically, kept
        // explicit here for numerical bookkeeping in sumXR.
        val sumC = randVecs(0)(j).zerosLike
        for (i <- 0 until nnodes) sumC ~ sumC + randVecs(i)(j)
        val sumXR = (1 - hardness) * sumPmR(j) + hardness * dotprod(meanP(j), sumC)
        val sumXX = (1 - hardness * hardness) * sumPmPm(j) - nnodes * (hardness - 1) * (hardness - 1) * dotprod(meanP(j), meanP(j))

        // Solve the quadratic sumRR * beta^2 + 2*sumXR * beta + (sumXX - sumPmPm) = 0 for
        // beta so that total kinetic energy is preserved; sign of the root chosen at random.
        val twoSumXR = 2 * sumXR
        val sumRR = randVecSqNorms.map(_(j)).reduce(_ + _)
        // Discriminant should always be positive for any hardness in [0, 1] (actually, [0, 2])
        val discr = twoSumXR*twoSumXR - 4*sumRR*(sumXX - sumPmPm(j))
        val epsilon = 1e-36f
        val beta = if (Mat.myrand.nextFloat() < 0.5f) {
          (-twoSumXR + math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
        } else {
          (-twoSumXR - math.sqrt(discr).toFloat) / (2 * sumRR + epsilon)
        }

        // New momentum: elastic-pulled x plus a scaled zero-sum random kick.
        updater.momentum(j) ~ x - aelem.set(beta) * randVecs((rvOffset + inode) % nnodes)(j)
      }
    }

    // Advance the shared random-vector rotation; must happen exactly once per round on every node.
    rotateRndVecs
  }

}
| 273 | + |
0 commit comments