/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet.module

import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
import org.slf4j.{Logger, LoggerFactory}

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

private object DataParallelExecutorGroup {
  private val logger: Logger = LoggerFactory.getLogger(classOf[DataParallelExecutorGroup])

  /**
   * Load a list of arrays into a list of arrays specified by slices.
   * For each source array, every `((start, stop), dst)` target receives either the
   * `[start, stop)` slice along `axis` (when `axis >= 0` and the slice is a strict
   * sub-range) or a full copy of the source.
   * @param data source arrays, one per input
   * @param targets per input, the (slice range, destination array) pairs, one per device
   * @param majorAxis per input, the batch axis (negative means no batch axis)
   */
  private def loadGeneralMulti(data: Seq[NDArray],
                               targets: Seq[Array[((Int, Int), NDArray)]],
                               majorAxis: Seq[Int]): Unit = {
    for (((dSrc, dTargets), axis) <- data zip targets zip majorAxis) {
      for (((sliceIdxStart, sliceIdxStop), dDst) <- dTargets) {
        if (axis >= 0 && (sliceIdxStart > 0 || sliceIdxStop < dSrc.shape(axis))) {
          // copy slice
          val shape = dSrc.shape
          val begin = Array.fill(shape.length)(0)
          val end = shape.toArray
          begin(axis) = sliceIdxStart
          end(axis) = sliceIdxStop
          if (dSrc.context == dDst.context) {
            // same device: crop straight into the destination array
            NDArray.crop(Map(
              "begin" -> new Shape(begin),
              "end" -> new Shape(end),
              "out" -> dDst))(dSrc)
          } else {
            // on different device, crop and then do cross device copy
            val dDstCopy: NDArray = NDArray.crop(Map(
              "begin" -> new Shape(begin),
              "end" -> new Shape(end)))(dSrc)
            dDstCopy.copyTo(dDst)
            // the cropped temporary is only the copy source; release its native handle
            // (same pattern as getParams, which disposes temporaries after copyTo)
            dDstCopy.dispose()
          }
        } else {
          dSrc.copyTo(dDst)
        }
      }
    }
  }

  /** Copy each array in `data` into the array at the same position in `targets`. */
  private def loadGeneral(data: Seq[NDArray], targets: Seq[NDArray]): Unit = {
    for ((dSrc, dTarget) <- data zip targets) {
      dSrc.copyTo(dTarget)
    }
  }

  /** Load the batch's data arrays into the per-device sliced arrays. */
  private def loadData(batch: DataBatch,
                       targets: Seq[Array[((Int, Int), NDArray)]],
                       majorAxis: Seq[Int]): Unit = {
    loadGeneralMulti(batch.data, targets, majorAxis)
  }


  /** Load the batch's label arrays into the per-device sliced arrays. */
  private def loadLabel(batch: DataBatch,
                        targets: Seq[Array[((Int, Int), NDArray)]],
                        majorAxis: Seq[Int]): Unit = {
    loadGeneralMulti(batch.label, targets, majorAxis)
  }

  /**
   * Merge outputs that live on multiple contexts into one,
   * so that they look as if they live on one context.
   * @param outputs per output, the arrays from each device
   * @param majorAxis per output, the axis to concatenate along (negative: take the first)
   */
  private def mergeMultiContext(outputs: IndexedSeq[IndexedSeq[NDArray]], majorAxis: Seq[Int])
    : IndexedSeq[NDArray] = {
    (outputs zip majorAxis).map { case (tensors, axis) =>
      if (axis >= 0) {
        NDArray.concatenate(tensors, axis = axis, alwaysCopy = false)
      } else {
        // negative axis means there is no batch_size axis, and all the
        // results should be the same on each device. We simply take the first one,
        // without checking they are actually the same
        tensors(0)
      }
    }
  }

  private object Builder {
    /**
     * Expand a single gradient requirement into a per-argument map.
     * Parameters get `gradReq` unless they are fixed (then "null"); data inputs get
     * `gradReq` only when `inputsNeedGrad`; every other argument gets "null".
     */
    private[module] def convertGradReq(
        gradReq: String, argNames: IndexedSeq[String], paramNames: IndexedSeq[String],
        fixedParamNames: Set[String], dataNames: Seq[String], inputsNeedGrad: Boolean)
      : Map[String, String] = {
      require(argNames != null, "Invalid argNames")
      require(paramNames != null, "Invalid paramNames")
      require(fixedParamNames != null, "Invalid fixedParamNames")
      require(dataNames != null, "Invalid dataNames")
      argNames.map(k => {
        if (paramNames.contains(k)) {
          (k, if (fixedParamNames.contains(k)) "null" else gradReq)
        } else if (dataNames.contains(k)) {
          (k, if (inputsNeedGrad) gradReq else "null")
        } else {
          (k, "null")
        }
      }).toMap
    }
  }

  /**
   * Fluent builder for [[DataParallelExecutorGroup]].
   * Note: the `setGradReq` overloads that take a String or a Map read the current
   * `dataShapes`, `fixedParamNames` and `inputsNeedGrad`, so set those first.
   */
  class Builder private[module](private val symbol: Symbol,
                                private val contexts: Array[Context],
                                private val paramNames: IndexedSeq[String]) {

    private var workLoadList: IndexedSeq[Float] = null
    private var dataShapes: IndexedSeq[DataDesc] = null
    private var labelShapes: Option[IndexedSeq[DataDesc]] = None
    private var forTraining: Boolean = true
    private var inputsNeedGrad: Boolean = false
    private var sharedGroup: Option[DataParallelExecutorGroup] = None
    private var inputTypes: Option[Map[String, DType]] = None
    private var fixedParamNames: Set[String] = Set.empty[String]
    private var gradReqs: Map[String, String] = null

    val argNames = symbol.listArguments()

    def setWorkLoadList(workLoad: IndexedSeq[Float]): Builder = {
      this.workLoadList = workLoad
      this
    }

    def setDataShapes(shapes: IndexedSeq[DataDesc]): Builder = {
      require(shapes != null, "Invalid shapes")
      this.dataShapes = shapes
      this
    }

    def setDataShapesByName(shapes: IndexedSeq[(String, Shape)]): Builder = {
      require(shapes != null, "Invalid shapes")
      this.dataShapes = shapes.map { case (k, s) => new DataDesc(k, s) }
      this
    }

    def setLabelShapes(shapes: IndexedSeq[DataDesc]): Builder = {
      this.labelShapes = Option(shapes)
      this
    }

    def setLabelShapesByName(shapes: IndexedSeq[(String, Shape)]): Builder = {
      this.labelShapes = Option(shapes).map(shapesInst =>
        shapesInst.map { case (k, s) => new DataDesc(k, s) }
      )
      this
    }

    def setForTraining(forTraining: Boolean): Builder = {
      this.forTraining = forTraining
      this
    }

    def setInputsNeedGrad(needGrad: Boolean): Builder = {
      this.inputsNeedGrad = needGrad
      this
    }

    def setSharedGroup(sharedGroup: DataParallelExecutorGroup): Builder = {
      this.sharedGroup = Option(sharedGroup)
      this
    }

    def setInputTypes(inputTypes: Map[String, DType]): Builder = {
      this.inputTypes = Option(inputTypes)
      this
    }

    def setFixedParamNames(fixedParamNames: Set[String]): Builder = {
      this.fixedParamNames = Option(fixedParamNames).getOrElse(Set.empty[String])
      this
    }

    /**
     * Set per-argument gradient requirements.
     * Starts from the defaults ("write" for non-fixed parameters, "write" for data inputs
     * only if `inputsNeedGrad`, "null" otherwise) and then applies every entry of
     * `gradReq` on top, so user-specified values always win.
     * @param gradReq argument name -> "write" / "add" / "null" overrides
     */
    def setGradReq(gradReq: Map[String, String]): Builder = {
      require(dataShapes != null, "dataShapes must be set first")
      require(gradReq != null, "Invalid gradReq")
      val dataNames = dataShapes.map(_.name)
      val defaults = Builder.convertGradReq(
        "write", argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
      // overrides are applied once, after all defaults are in place
      this.gradReqs = defaults ++ gradReq
      this
    }

    /** Set the same gradient requirement for all arguments that need gradients. */
    def setGradReq(gradReq: String): Builder = {
      require(dataShapes != null, "dataShapes must be set first")
      val dataNames = dataShapes.map(_.name)
      this.gradReqs = Builder.convertGradReq(
        gradReq, argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
      this
    }

    /** Set gradient requirements from explicit (name, req) pairs, one per argument. */
    def setGradReq(gradReq: Seq[(String, String)]): Builder = {
      require(gradReq.size == argNames.size,
        s"provided number of gradReq (${gradReq.size}) do not match number of args " +
          s"(${argNames.size})")
      this.gradReqs = gradReq.toMap
      this
    }

    def build(): DataParallelExecutorGroup = {
      new DataParallelExecutorGroup(
        symbol, contexts, workLoadList, dataShapes, labelShapes, paramNames, forTraining,
        inputsNeedGrad, sharedGroup, inputTypes, fixedParamNames, this.gradReqs)
    }
  }
}

/**
 * DataParallelExecutorGroup is a group of executors that lives on a group of devices.
 * This is a helper class used to implement data parallelism. Each mini-batch will
 * be split and run on the devices.
 * @param symbol The common symbolic computation graph for all executors.
 * @param contexts A list of contexts.
 * @param workLoadList A list of numbers that specify the workload to be assigned to
 *                     each context. Larger numbers indicate heavier workload.
 *                     Must be non-null whenever any input has a batch axis.
 * @param dataShapes Should be a list of (name, shape) tuples, for the shapes of data.
 *                   Note the order is important and should be the same as the order in
 *                   which the `DataIter` provides the data.
 * @param labelShapes Should be a list of (name, shape) tuples, for the shapes of label.
 *                    Note the order is important and should be the same as the order in
 *                    which the `DataIter` provides the label.
 * @param paramNames A list of strings, indicating the names of parameters
 *                   (e.g. weights, filters, etc.) in the computation graph.
 * @param forTraining Indicate whether the executors should be bound for training.
 *                    When not doing training, the memory for gradients will not be allocated.
 * @param inputsNeedGrad Indicate whether the gradients for the input data should be computed.
 *                       It is useful for implementing composition of modules.
 * @param sharedGroup Default is `None`. This is used in bucketing. When not `None`,
 *                    it should be an executor group corresponding to a different bucket.
 *                    In other words, it will correspond to a different symbol but
 *                    with the same set of parameters (e.g. unrolled RNNs with different lengths).
 *                    In this case, much memory will be shared.
 * @param inputTypes Default is `None`. When not `None`,
 *                   can be used to specify the data type for each of the data/label inputs.
 * @param fixedParamNames Indicate parameters to be fixed during training.
 *                        Parameters in this list will not allocate space for gradient,
 *                        nor do gradient calculation.
 * @param gradReq Requirement for gradient accumulation, per argument name:
 *                'write', 'add', or 'null'. Ignored when `forTraining` is false (all
 *                arguments then use 'null'). When null during training, defaults to
 *                'write' for non-fixed parameters (and for data inputs if `inputsNeedGrad`).
 */
class DataParallelExecutorGroup private[module](
    symbol: Symbol,
    contexts: Array[Context],
    workLoadList: IndexedSeq[Float],
    var dataShapes: IndexedSeq[DataDesc],
    var labelShapes: Option[IndexedSeq[DataDesc]] = None,
    private[module] val paramNames: IndexedSeq[String],
    forTraining: Boolean,
    inputsNeedGrad: Boolean,
    sharedGroup: Option[DataParallelExecutorGroup] = None,
    inputTypes: Option[Map[String, DType]] = None,
    fixedParamNames: Set[String] = Set.empty[String],
    gradReq: Map[String, String] = null) {

  require(symbol != null, "Undefined symbol")
  require(contexts != null, "Undefined context")

  private val argNames = symbol.listArguments()
  private val auxNames = symbol.listAuxiliaryStates()

  // Effective per-argument gradient requirement used when binding.
  private val gradReqRun: Map[String, String] =
    if (!forTraining) {
      val dataNames = dataShapes.map(_.name)
      Builder.convertGradReq("null",
        argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
    } else {
      // bindIthExec looks up every argument in this map, so a missing map
      // falls back to the "write" default instead of failing with an NPE
      Option(gradReq).getOrElse {
        val dataNames = dataShapes.map(_.name)
        Builder.convertGradReq("write",
          argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
      }
    }

  // When sharing a group (bucketing), reuse its per-device data buffers.
  private val sharedDataArrays: Array[mutable.Map[String, NDArray]] =
    sharedGroup.map(_.sharedDataArrays).getOrElse(
      Array.fill(contexts.length)(mutable.Map.empty[String, NDArray]))

  private var batchSize: Int = -1
  private var slices: Array[(Int, Int)] = null
  private var execs: Array[Executor] = null
  private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
  private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
  private[module] var paramArrays: IndexedSeq[Array[NDArray]] = null
  private[module] var gradArrays: IndexedSeq[Array[NDArray]] = null
  private[module] var auxArrays: IndexedSeq[Array[NDArray]] = null
  private var inputGradArrays: IndexedSeq[Array[NDArray]] = null

  private var dataLayouts = decideSlices(dataShapes)
  private var labelLayouts =
    // call it to make sure labels has the same batch size as data
    labelShapes.map(shapes => decideSlices(shapes)).orNull

  private val outputLayouts = symbol.listOutputs().map(name => {
      val sym = symbol.get(name)
      val layout = sym.attr("__layout__")
      sym.dispose()
      DataDesc.getBatchAxis(layout)
    }
  )
  bindExec(dataShapes, labelShapes, sharedGroup)

  def getBatchSize: Int = batchSize

  /**
   * Decide the slices for each context according to the workload.
   * Sets `batchSize` from the first input with a batch axis and checks that all
   * other inputs agree; also recomputes `slices` from `workLoadList`.
   * @param dataShapes list of DataDesc(name, shape) specifying
   *                   the shapes for the input data or label.
   * @return the batch axis of each input (-1 when the layout has none)
   */
  private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = {
    require(dataShapes.size > 0, "dataShapes must be non empty")
    val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout)))

    for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) {
      if (axis != -1) {
        val batchSize = dataDesc.shape(axis)
        if (this.batchSize != -1) {
          // report the expected batch size, not the offending one
          require(batchSize == this.batchSize,
            s"all data must have the same batch size: ${this.batchSize}, " +
              s"but ${dataDesc.name} has shape ${dataDesc.shape}")
        } else {
          this.batchSize = batchSize
          require(this.workLoadList != null, "Undefined workLoadList")
          this.slices = ExecutorManager.splitInputSlice(this.batchSize, this.workLoadList)
        }
      }
    }
    majorAxis
  }

  /**
   * Bind executors on their respective devices.
   * @param dataShapes DataDesc for input data.
   * @param labelShapes DataDesc for input labels.
   * @param sharedGroup group whose executors are passed to `bindIthExec` for memory sharing
   *                    (only used when `reshape` is false).
   * @param reshape when true, reshape the existing executors in place instead of
   *                binding new ones.
   */
  def bindExec(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]],
               sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
    this.batchSize = -1
    dataLayouts = decideSlices(dataShapes)
    labelLayouts = {
      // call it to make sure labels has the same batch size as data
      labelShapes.map(shapes => decideSlices(shapes)).orNull
    }
    if (reshape) {
      (0 until contexts.length).foreach { i =>
        val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
        val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
        val inputShapes
          = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])

        ResourceScope.usingIfScopeExists(execs(i).scope) {
          val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
          execs(i).dispose()
          execs(i) = tmpExec
        }
      }
    } else {
      execs = (0 until contexts.length).map(i =>
        bindIthExec(i, dataShapes, labelShapes, sharedGroup)
      ).toArray
    }

    this.dataShapes = dataShapes
    this.labelShapes = labelShapes

    // convenient data structures
    dataArrays = dataShapes.map(dataDesc =>
      this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(dataDesc.name)) }
    )

    labelArrays = labelShapes.map(shapes =>
      shapes.map(labelDesc =>
        this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(labelDesc.name)) }
      )
    )

    paramArrays = argNames.zipWithIndex.withFilter {
      case (name, i) => paramNames.contains(name)
    }.map { case (name, i) =>
      execs.map(_.argArrays(i))
    }

    gradArrays =
      if (forTraining) {
        argNames.zipWithIndex.withFilter {
          case (name, i) => paramNames.contains(name)
        }.map { case (name, i) =>
          execs.map(_.gradArrays(i))
        }
      } else {
        null
      }

    val dataNames = dataShapes.map(_.name)
    inputGradArrays =
      if (inputsNeedGrad) {
        argNames.zipWithIndex.withFilter {
          case (name, i) => dataNames.contains(name)
        }.map { case (name, i) =>
          execs.map(_.gradArrays(i))
        }
      } else {
        null
      }

    auxArrays = (0 until auxNames.length).map(i => execs.map(_.auxArrays(i)))
  }

  /**
   * Reshape executors. Does nothing when both shapes equal the current ones.
   * @param dataShapes new DataDesc for input data
   * @param labelShapes new DataDesc for input labels
   */
  def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
    if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
      this.bindExec(dataShapes, labelShapes, None, reshape = true)
    }
  }

  /**
   * Assign, i.e. copy parameters to all the executors.
   * @param argParams A dictionary of name to `NDArray` parameter mapping.
   * @param auxParams A dictionary of name to `NDArray` auxiliary variable mapping.
   * @param allowExtra Whether to allow extra parameters that are not needed by symbol.
   *                   If this is true, no error will be thrown when argParams or auxParams
   *                   contain extra parameters that are not needed by the executor.
   */
  def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray],
                allowExtra: Boolean = false): Unit = {
    execs.foreach(_.copyParamsFrom(argParams, auxParams, allowExtraParams = allowExtra))
  }

  /**
   * Copy data from each executor to `argParams` and `auxParams`.
   * Each value written is the average of that array across all devices.
   * @param argParams target parameter arrays
   * @param auxParams target aux arrays
   * Note this function will inplace update the NDArrays in argParams and auxParams.
   */
  def getParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
    for ((name, block) <- paramNames.zip(paramArrays)) {
      copyAverageTo(block, argParams(name))
    }
    for ((name, block) <- auxNames.zip(auxArrays)) {
      copyAverageTo(block, auxParams(name))
    }
  }

  // Average the per-device copies of one array on the CPU and write it into `target`,
  // converting to the target's dtype and disposing all temporaries.
  private def copyAverageTo(block: Array[NDArray], target: NDArray): Unit = {
    val weight = (block.map(_.copyTo(Context.cpu())).reduce((a: NDArray, b: NDArray) =>
      (a + b).disposeDeps()
    ) / block.length).disposeDeps()
    val weightNewType = weight.asType(target.dtype)
    weightNewType.copyTo(target)
    weight.dispose()
    weightNewType.dispose()
  }

  /**
   * Split `dataBatch` according to workload and run forward on each devices.
   * @param dataBatch the batch to load into the executors
   * @param isTrain The hint for the backend, indicating whether we are during training phase.
   *                Default is `None`, then the value `forTraining` will be used.
   */
  def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
    DataParallelExecutorGroup.loadData(dataBatch, dataArrays, dataLayouts)
    val isTrainOpt = isTrain.getOrElse(this.forTraining)
    labelArrays.foreach(labels => {
      require(!isTrainOpt || dataBatch.label != null, "label must be defined if in training phase")
      if (dataBatch.label != null) {
        require(labelLayouts != null, "label layouts are undefined")
        DataParallelExecutorGroup.loadLabel(dataBatch, labels, labelLayouts)
      }
    })
    execs.foreach(_.forward(isTrainOpt))
  }

  // Get the shapes of the outputs, with the batch axis set to the full batch size.
  def getOutputShapes: IndexedSeq[(String, Shape)] = {
    val outputs = execs(0).outputs
    val shapes = outputs.map(_.shape)
    (symbol.listOutputs() zip shapes zip outputLayouts) map { case ((key, theShape), axis) =>
      val shape = theShape.toArray
      if (axis >= 0) {
        shape(axis) = batchSize
      }
      (key, Shape(shape))
    }
  }

  /**
   * Get outputs of the previous forward computation.
   * @return In the case when data-parallelism is used,
   *         the outputs will be collected from multiple devices.
   *         The results will look like `[ [out1_dev1, out1_dev2], [out2_dev1, out2_dev2] ]`,
   *         those `NDArray` might live on different devices.
   */
  def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
    (0 until execs(0).outputs.length).map(i => execs.map(_.outputs(i)).toIndexedSeq)
  }

  /**
   * Get outputs of the previous forward computation.
   * @return In the case when data-parallelism is used,
   *         the outputs will be merged from multiple devices,
   *         as they look like from a single executor.
   *         The results will look like `[out1, out2]`
   */
  def getOutputsMerged(): IndexedSeq[NDArray] = {
    DataParallelExecutorGroup.mergeMultiContext(getOutputs(), outputLayouts)
  }

  /**
   * Get the gradients to the inputs, computed in the previous backward computation.
   * @return In the case when data-parallelism is used,
   *         the grads will be collected from multiple devices.
   *         The results will look like `[ [grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2] ]`,
   *         those `NDArray` might live on different devices.
   */
  def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
    require(inputsNeedGrad, "Cannot get InputGrads when inputNeedGrad is set to false")
    inputGradArrays.map(_.toIndexedSeq)
  }

  /**
   * Get the gradients to the inputs, computed in the previous backward computation.
   * @return In the case when data-parallelism is used,
   *         the grads will be merged from multiple devices,
   *         as they look like from a single executor.
   *         The results will look like `[grad1, grad2]`
   */
  def getInputGradsMerged(): IndexedSeq[NDArray] = {
    DataParallelExecutorGroup.mergeMultiContext(getInputGrads(), dataLayouts)
  }

  /**
   * Run backward on all devices. A backward should be called after
   * a call to the forward function. Backward cannot be called unless
   * `forTraining` is true.
   * @param outGrads Gradient on the outputs to be propagated back.
   *                 This parameter is only needed when bind is called
   *                 on outputs that are not a loss function.
   */
  def backward(outGrads: Array[NDArray] = null): Unit = {
    require(forTraining, "re-bind with forTraining = true to run backward")

    for (((exec, islice), i) <- (execs zip slices).zipWithIndex) {
      val outGradsSlice =
        if (outGrads != null) {
          (outGrads zip outputLayouts).map { case (grad, axis) =>
            if (axis >= 0) {
              // take this device's slice of the batch axis and move it to the device
              val ogMySlice: NDArray = NDArray.slice_axis(
                Map("axis" -> axis, "begin" -> islice._1, "end" -> islice._2))(grad)
              ogMySlice.asInContext(contexts(i))
            } else {
              grad.copyTo(contexts(i))
            }
          }
        } else {
          Array.empty[NDArray]
        }
      exec.backward(outGrads = outGradsSlice)
    }
  }

  /**
   * Accumulate the performance according to `evalMetric` on all devices.
   * @param evalMetric The metric used for evaluation.
   * @param labels Typically comes from `label` of a `DataBatch`.
   */
  def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
    for ((texec, islice) <- this.execs zip this.slices) {
      val labelsSlice =
        (labels zip this.labelLayouts) map { case (label, axis) =>
          if (axis == 0) {
            label.slice(islice)
          } else if (axis > 0) {
            val labelMySlice: NDArray = NDArray.slice_axis(Map(
              "axis" -> axis, "begin" -> islice._1, "end" -> islice._2))(label)
              .asInContext(label.context)
            labelMySlice
          } else {
            label
          }
        }

      evalMetric.update(labelsSlice, texec.outputs)

      // Clear up any slices we created (sometimes we don't slice so check for this)
      (labels zip labelsSlice).foreach { case (label, labelSlice) =>
        if (label ne labelSlice) {
          labelSlice.dispose()
        }
      }
    }
  }

  // Internal utility function to bind the i-th executor.
  // Parameters (and their gradients) are allocated fresh or borrowed from the shared
  // group's i-th executor; data/label buffers come from `sharedDataArrays` via getOrReshape.
  private def bindIthExec(i: Int, dataShapes: Seq[DataDesc],
                          labelShapes: Option[Seq[DataDesc]],
                          sharedGroup: Option[DataParallelExecutorGroup]): Executor = {
    val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
    val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
    val sharedExec = sharedGroup.map(_.execs(i))
    val context = contexts(i)
    val sharedDataArrays = this.sharedDataArrays(i)

    val inputShapes
      = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])

    val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
    require(argShapes != null, "Shape inference failed. " +
      s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
      s"and aux states ${symbol.listAuxiliaryStates()}")

    val inputTypesGot = inputTypes.getOrElse(inputShapes.map { case (k, v) =>
      (k, Base.MX_REAL_TYPE)
    })
    val (argTypes, _, auxTypes) = symbol.inferType(inputTypesGot)
    // report the types actually used for inference (inputTypes may be None)
    require(argTypes != null, "Type inference failed. " +
      s"Known types are $inputTypesGot for symbol arguments ${symbol.listArguments()} " +
      s"and aux states ${symbol.listAuxiliaryStates()}")

    val argArrays = ArrayBuffer.empty[NDArray]
    val gradArrayMap = mutable.HashMap.empty[String, NDArray]

    // create or borrow arguments and gradients
    for (j <- 0 until argNames.length) {
      val name = argNames(j)
      val argArr =
        if (paramNames.contains(name)) {
          // model parameter
          sharedExec match {
            case None =>
              val argArr = NDArray.zeros(argShapes(j), context, dtype = argTypes(j))
              if (gradReqRun(name) != "null") {
                val gradArr = NDArray.zeros(argShapes(j), context, dtype = argTypes(j))
                gradArrayMap.put(name, gradArr)
              }
              argArr
            case Some(sharedExecInst) =>
              val argArr = sharedExecInst.argDict(name)
              require(argArr.shape == argShapes(j),
                s"Shape ${argArr.shape} of argument $name does not match " +
                  s"inferred shape ${argShapes(j)}")
              require(argArr.dtype == argTypes(j),
                s"Type ${argArr.dtype} of argument $name does not match " +
                  s"inferred type ${argTypes(j)}")
              if (gradReqRun(name) != "null") {
                gradArrayMap.put(name, sharedExecInst.gradDict(name))
              }
              argArr
          }
        } else {
          // data or label
          val argArr = getOrReshape(name, sharedDataArrays, argShapes(j), argTypes(j), context)
          // data might also need grad if inputs_need_grad is True
          if (gradReqRun(name) != "null") {
            gradArrayMap.put(name,
              getOrReshape(s"grad of $name", sharedDataArrays, argShapes(j), argTypes(j), context))
          }
          argArr
        }
      argArrays.append(argArr)
    }

    // create or borrow aux variables
    val auxArrays =
      sharedExec match {
        case None => (auxShapes zip auxTypes).map { case (s, t) =>
          NDArray.zeros(s, context, dtype = t)
        }.toArray
        case Some(sharedExecInst) =>
          for ((arr, j) <- sharedExecInst.auxArrays.zipWithIndex) {
            require(auxShapes(j) == arr.shape,
              s"Shape ${arr.shape} of aux variable ${auxNames(j)} does not match " +
                s"inferred shape ${auxShapes(j)}")
            require(auxTypes(j) == arr.dtype,
              s"Type ${arr.dtype} of aux variable ${auxNames(j)} does not match " +
                s"inferred type ${auxTypes(j)}")
          }
          sharedExecInst.auxArrays.map(identity)
      }
    symbol.bind(ctx = context, args = argArrays.toSeq, argsGrad = gradArrayMap.toMap,
      gradsReq = gradReqRun, auxStates = auxArrays.toSeq, group2ctx = null,
      sharedExec = sharedExec.orNull)
  }

  /**
   * Get the sliced shapes for the i-th executor.
   * @param shapes The original (name, shape) pairs.
   * @param i Which executor we are dealing with.
   * @param majorAxis per input, the batch axis to replace with the i-th slice length
   *                  (negative: shape left unchanged)
   */
  private def slicedShape(shapes: Seq[DataDesc], i: Int, majorAxis: Seq[Int])
    : Seq[(String, Shape)] = {
    (shapes zip majorAxis).map { case (DataDesc(k, shape, _ , _), axis) =>
      val shapeArr = shape.toArray
      if (axis >= 0) {
        shapeArr(axis) = slices(i)._2 - slices(i)._1
      }
      (k, Shape(shapeArr))
    }
  }

  // Install monitor on all executors
  def installMonitor(monitor: Monitor): Unit = {
    execs.foreach(monitor.install)
  }

  // Internal helper to get a memory block or re-use by re-shaping.
  // Reuses the cached array under `name` when it has at least as many elements as needed;
  // otherwise allocates a new zero array and replaces the cached entry with it.
  private def getOrReshape(name: String,
                           sharedDataArrays: mutable.Map[String, NDArray],
                           argShape: Shape,
                           argType: DType,
                           context: Context): NDArray = {
    if (sharedDataArrays.contains(name)) {
      val argArr = sharedDataArrays(name)
      if (argArr.shape.product >= argShape.product) {
        // nice, we can directly re-use this data blob
        require(argArr.dtype == argType,
          s"Type ${argArr.dtype} of argument $name does not match inferred type ${argType}")
        argArr.reshape(argShape)
      } else {
        DataParallelExecutorGroup.logger.warn(s"bucketing: data $name has a shape $argShape, " +
          s"which is larger than already allocated shape ${argArr.shape}. " +
          "Need to re-allocate. Consider putting default_bucket_key to be the bucket " +
          "taking the largest input for better memory sharing.")
        val argArrNew = NDArray.zeros(argShape, context, dtype = argType)
        // replace existing shared array because the new one is bigger
        sharedDataArrays.put(name, argArrNew)
        argArrNew
      }
    } else {
      val argArrNew = NDArray.zeros(argShape, context, dtype = argType)
      sharedDataArrays.put(name, argArrNew)
      argArrNew
    }
  }
}