1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet.module
19
20import org.apache.mxnet.DType.DType
21import org.apache.mxnet._
22import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
23import org.slf4j.{Logger, LoggerFactory}
24
25import scala.collection.mutable
26import scala.collection.mutable.ArrayBuffer
27
28private object DataParallelExecutorGroup {
29  private val logger: Logger = LoggerFactory.getLogger(classOf[DataParallelExecutorGroup])
30  // Load a list of arrays into a list of arrays specified by slices
31  private def loadGeneralMulti(data: Seq[NDArray],
32                               targets: Seq[Array[((Int, Int), NDArray)]],
33                               majorAxis: Seq[Int]): Unit = {
34    for (((dSrc, dTargets), axis) <- data zip targets zip majorAxis) {
35      for (((sliceIdxStart, sliceIdxStop), dDst) <- dTargets) {
36        if (axis >= 0 && (sliceIdxStart > 0 || sliceIdxStop < dSrc.shape(axis))) {
37          // copy slice
38          val shape = dSrc.shape
39          val begin = Array.fill(shape.length)(0)
40          val end = shape.toArray
41          begin(axis) = sliceIdxStart
42          end(axis) = sliceIdxStop
43          if (dSrc.context == dDst.context) {
44            NDArray.crop(Map(
45              "begin" -> new Shape(begin),
46              "end" -> new Shape(end),
47              "out" -> dDst))(dSrc)
48          } else {
49            // on different device, crop and then do cross device copy
50            val dDstCopy: NDArray = NDArray.crop(Map(
51              "begin" -> new Shape(begin),
52              "end" -> new Shape(end)))(dSrc)
53            dDstCopy.copyTo(dDst)
54          }
55        } else {
56          dSrc.copyTo(dDst)
57        }
58      }
59    }
60  }
61
62  private def loadGeneral(data: Seq[NDArray], targets: Seq[NDArray]): Unit = {
63    for ((dSrc, dTarget) <- data zip targets) {
64      dSrc.copyTo(dTarget)
65    }
66  }
67
68  // Load data into sliced arrays
69  private def loadData(batch: DataBatch,
70                       targets: Seq[Array[((Int, Int), NDArray)]],
71                       majorAxis: Seq[Int]): Unit = {
72    loadGeneralMulti(batch.data, targets, majorAxis)
73  }
74
75
76  // Load label into sliced arrays
77  private def loadLabel(batch: DataBatch,
78                        targets: Seq[Array[((Int, Int), NDArray)]],
79                        majorAxis: Seq[Int]): Unit = {
80    loadGeneralMulti(batch.label, targets, majorAxis)
81  }
82
83  // Merge outputs that lives on multiple context into one,
84  // so that they look like living on one context.
85  private def mergeMultiContext(outputs: IndexedSeq[IndexedSeq[NDArray]], majorAxis: Seq[Int])
86    : IndexedSeq[NDArray] = {
87    (outputs zip majorAxis).map { case (tensors, axis) =>
88      if (axis >= 0) {
89        NDArray.concatenate(tensors, axis = axis, alwaysCopy = false)
90      } else {
91        // negative axis means the there is no batch_size axis, and all the
92        // results should be the same on each device. We simply take the first one,
93        // without checking they are actually the same
94        tensors(0)
95      }
96    }
97  }
98
99  private object Builder {
100    private[module] def convertGradReq(
101        gradReq: String, argNames: IndexedSeq[String], paramNames: IndexedSeq[String],
102        fixedParamNames: Set[String], dataNames: Seq[String], inputsNeedGrad: Boolean)
103        : Map[String, String] = {
104      require(argNames != null, "Invalid argNames")
105      require(paramNames != null, "Invalid paramNames")
106      require(fixedParamNames != null, "Invalid fixedParamNames")
107      require(dataNames != null, "Invalid dataNames")
108      argNames.map(k => {
109        if (paramNames.contains(k)) {
110          (k, if (fixedParamNames.contains(k)) "null" else gradReq)
111        } else if (dataNames.contains(k)) {
112          (k, if (inputsNeedGrad) gradReq else "null")
113        } else {
114          (k, "null")
115        }
116      }).toMap
117    }
118  }
119
120  class Builder private[module](private val symbol: Symbol,
121                                private val contexts: Array[Context],
122                                private val paramNames: IndexedSeq[String]) {
123
124    private var workLoadList: IndexedSeq[Float] = null
125    private var dataShapes: IndexedSeq[DataDesc] = null
126    private var labelShapes: Option[IndexedSeq[DataDesc]] = None
127    private var forTraining: Boolean = true
128    private var inputsNeedGrad: Boolean = false
129    private var sharedGroup: Option[DataParallelExecutorGroup] = None
130    private var inputTypes: Option[Map[String, DType]] = None
131    private var fixedParamNames: Set[String] = Set.empty[String]
132    private var gradReqs: Map[String, String] = null
133
134    val argNames = symbol.listArguments()
135
136    def setWorkLoadList(workLoad: IndexedSeq[Float]): Builder = {
137      this.workLoadList = workLoad
138      this
139    }
140
141    def setDataShapes(shapes: IndexedSeq[DataDesc]): Builder = {
142      require(shapes != null, "Invalid shapes")
143      this.dataShapes = shapes
144      this
145    }
146
147    def setDataShapesByName(shapes: IndexedSeq[(String, Shape)]): Builder = {
148      require(shapes != null, "Invalid shapes")
149      this.dataShapes = shapes.map { case (k, s) => new DataDesc(k, s) }
150      this
151    }
152
153    def setLabelShapes(shapes: IndexedSeq[DataDesc]): Builder = {
154      this.labelShapes = Option(shapes)
155      this
156    }
157
158    def setLabelShapesByName(shapes: IndexedSeq[(String, Shape)]): Builder = {
159      this.labelShapes = Option(shapes).map(shapesInst =>
160        shapesInst.map { case (k, s) => new DataDesc(k, s) }
161      )
162      this
163    }
164
165    def setForTraining(forTraining: Boolean): Builder = {
166      this.forTraining = forTraining
167      this
168    }
169
170    def setInputsNeedGrad(needGrad: Boolean): Builder = {
171      this.inputsNeedGrad = needGrad
172      this
173    }
174
175    def setSharedGroup(sharedGroup: DataParallelExecutorGroup): Builder = {
176      this.sharedGroup = Option(sharedGroup)
177      this
178    }
179
180    def setInputTypes(inputTypes: Map[String, DType]): Builder = {
181      this.inputTypes = Option(inputTypes)
182      this
183    }
184
185    def setFixedParamNames(fixedParamNames: Set[String]): Builder = {
186      this.fixedParamNames = Option(fixedParamNames).getOrElse(Set.empty[String])
187      this
188    }
189
190    def setGradReq(gradReq: Map[String, String]): Builder = {
191      require(dataShapes != null, "dataShapes must be set first")
192      val gradReqTmp = mutable.HashMap.empty[String, String]
193      val dataNames = dataShapes.map(_.name)
194      for (k <- argNames) {
195        if (paramNames.contains(k)) {
196          gradReqTmp.put(k, if (fixedParamNames.contains(k)) "null" else "write")
197        } else if (dataNames.contains(k)) {
198          gradReqTmp.put(k, if (inputsNeedGrad) "write" else "null")
199        } else {
200          gradReqTmp.put(k, "null")
201          gradReqTmp ++= gradReq
202        }
203      }
204      this.gradReqs = gradReqTmp.toMap
205      this
206    }
207
208    def setGradReq(gradReq: String): Builder = {
209      require(dataShapes != null, "dataShapes must be set first")
210      val dataNames = dataShapes.map(_.name)
211      this.gradReqs = Builder.convertGradReq(
212        gradReq, argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
213      this
214    }
215
216    def setGradReq(gradReq: Seq[(String, String)]): Builder = {
217      require(gradReq.size == argNames.size,
218        s"provided number of gradReq (${gradReq.size}) do not match number of args " +
219          s"(${argNames.size})")
220      this.gradReqs = gradReq.toMap
221      this
222    }
223
224    def build(): DataParallelExecutorGroup = {
225      new DataParallelExecutorGroup(
226        symbol, contexts, workLoadList, dataShapes, labelShapes, paramNames, forTraining,
227        inputsNeedGrad, sharedGroup, inputTypes, fixedParamNames, this.gradReqs)
228    }
229  }
230}
231
232/**
233 * DataParallelExecutorGroup is a group of executors that lives on a group of devices.
234 * This is a helper class used to implement data parallelism. Each mini-batch will
235 * be split and run on the devices.
236 * @param symbol The common symbolic computation graph for all executors.
237 * @param contexts A list of contexts.
238 * @param workLoadList If not `None`, could be a list of numbers that
239 *                     specify the workload to be assigned to different context.
240 *                     Larger number indicate heavier workload.
241 * @param dataShapes Should be a list of (name, shape) tuples, for the shapes of data.
242 *                   Note the order is important and should be the same as the order that
243 *                   the `DataIter` provide the data.
244 * @param labelShapes Should be a list of (name, shape) tuples, for the shapes of label.
245 *                    Note the order is important and should be the same as the order that
246 *                    the `DataIter` provide the label.
247 * @param paramNames A list of strings, indicating the names of parameters
248 *                   (e.g. weights, filters, etc.) in the computation graph.
 * @param forTraining Indicate whether the executors should be bound for training.
250 *                    When not doing training, the memory for gradients will not be allocated.
251 * @param inputsNeedGrad Indicate whether the gradients for the input data should be computed.
252 *                       This is currently not used.
253 *                       It will be useful for implementing composition of modules.
254 * @param sharedGroup Default is `None`. This is used in bucketing. When not `None`,
255 *                    it should be a executor group corresponding to a different bucket.
256 *                    In other words, it will correspond to a different symbol but
257 *                    with the same set of parameters (e.g. unrolled RNNs with different lengths).
 *                    In this case, much of the memory will be shared.
259 * @param inputTypes Default is `None`. When not `None`,
260 *                   can be used to specify the data type for each of the data/label inputs.
261 * @param fixedParamNames Indicate parameters to be fixed during training.
262 *                        Parameters in this list will not allocate space for gradient,
263 *                        nor do gradient calculation.
 * @param gradReq Requirement for gradient accumulation. Can be 'write', 'add', or 'null',
 *                and is specified for each argument.
266 */
267class DataParallelExecutorGroup private[module](
268    symbol: Symbol,
269    contexts: Array[Context],
270    workLoadList: IndexedSeq[Float],
271    var dataShapes: IndexedSeq[DataDesc],
272    var labelShapes: Option[IndexedSeq[DataDesc]] = None,
273    private[module] val paramNames: IndexedSeq[String],
274    forTraining: Boolean,
275    inputsNeedGrad: Boolean,
276    sharedGroup: Option[DataParallelExecutorGroup] = None,
277    inputTypes: Option[Map[String, DType]] = None,
278    fixedParamNames: Set[String] = Set.empty[String],
279    gradReq: Map[String, String] = null) {
280
281  require(symbol != null, "Undefined symbol")
282  require(contexts != null, "Undefined context")
283
284  private val argNames = symbol.listArguments()
285  private val auxNames = symbol.listAuxiliaryStates()
286
287  private val gradReqRun =
288    if (!forTraining) {
289      val dataNames = dataShapes.map(_.name)
290      Builder.convertGradReq("null",
291        argNames, paramNames, fixedParamNames, dataNames, inputsNeedGrad)
292    } else {
293      gradReq
294    }
295
296  private val sharedDataArrays: Array[mutable.Map[String, NDArray]] =
297    sharedGroup.map(_.sharedDataArrays).getOrElse(
298    Array.fill(contexts.length)(mutable.Map.empty[String, NDArray]))
299
300  private var batchSize: Int = -1
301  private var slices: Array[(Int, Int)] = null
302  private var execs: Array[Executor] = null
303  private var dataArrays: Seq[Array[((Int, Int), NDArray)]] = null
304  private var labelArrays: Option[Seq[Array[((Int, Int), NDArray)]]] = None
305  private[module] var paramArrays: IndexedSeq[Array[NDArray]] = null
306  private[module] var gradArrays: IndexedSeq[Array[NDArray]] = null
307  private[module] var auxArrays: IndexedSeq[Array[NDArray]] = null
308  private var inputGradArrays: IndexedSeq[Array[NDArray]] = null
309
310  private var dataLayouts = decideSlices(dataShapes)
311  private var labelLayouts =
312    // call it to make sure labels has the same batch size as data
313    if (labelShapes != None) decideSlices(labelShapes.get)
314    else null
315
316  private val outputLayouts = symbol.listOutputs().map(name => {
317    val sym = symbol.get(name)
318    val layout = sym.attr("__layout__")
319    sym.dispose()
320    DataDesc.getBatchAxis(layout)
321  }
322  )
323  bindExec(dataShapes, labelShapes, sharedGroup)
324
325  def getBatchSize: Int = batchSize
326
327  /**
328   * Decide the slices for each context according to the workload.
329   * @param dataShapes list of DataDesc(name, shape) specifying
330   *                   the shapes for the input data or label.
331   */
332  private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = {
333    require(dataShapes.size > 0, "dataShapes must be non empty")
334    val majorAxis = dataShapes.map(data => DataDesc.getBatchAxis(Option(data.layout)))
335
336    for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) {
337      if (axis != -1) {
338        val batchSize = dataDesc.shape(axis)
339        if (this.batchSize != -1) {
340          require(batchSize == this.batchSize,
341            s"all data must have the same batch size: $batchSize," +
342            s"but ${dataDesc.name} has shape ${dataDesc.shape}")
343        } else {
344          this.batchSize = batchSize
345          require(this.workLoadList != null, "Undefined workLoadList")
346          this.slices = ExecutorManager.splitInputSlice(this.batchSize, this.workLoadList)
347        }
348      }
349    }
350    majorAxis
351  }
352
353  /**
354   * Bind executors on their respective devices.
355   * @param dataShapes DataDesc for input data.
356   * @param labelShapes DataDesc for input labels.
357   * @param sharedGroup
358   * @param reshape
359   */
360  def bindExec(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]],
361               sharedGroup: Option[DataParallelExecutorGroup], reshape: Boolean = false): Unit = {
362    this.batchSize = -1
363    dataLayouts = decideSlices(dataShapes)
364    labelLayouts = {
365      // call it to make sure labels has the same batch size as data
366      if (labelShapes != None) decideSlices(labelShapes.get)
367      else null
368    }
369    if (reshape) {
370      (0 until contexts.length).foreach { i =>
371        val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
372        val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
373        val inputShapes
374          = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
375
376        ResourceScope.usingIfScopeExists(execs(i).scope) {
377          val tmpExec = execs(i).reshape(allowUpSizing = true, kwargs = inputShapes)
378          execs(i).dispose()
379          execs(i) = tmpExec
380        }
381      }
382    } else {
383      execs = (0 until contexts.length).map(i =>
384        bindIthExec(i, dataShapes, labelShapes, sharedGroup)
385      ).toArray
386    }
387
388    this.dataShapes = dataShapes
389    this.labelShapes = labelShapes
390
391    // convenient data structures
392    dataArrays = dataShapes.map(dataDesc =>
393      this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(dataDesc.name)) }
394    )
395
396    labelArrays = labelShapes.map(shapes =>
397      shapes.map(labelDesc =>
398        this.execs.zipWithIndex.map { case (e, i) => (this.slices(i), e.argDict(labelDesc.name)) }
399      )
400    )
401
402    paramArrays = argNames.zipWithIndex.withFilter {
403      case (name, i) => paramNames.contains(name)
404    }.map { case (name, i) =>
405      execs.map(_.argArrays(i))
406    }
407
408    gradArrays =
409      if (forTraining) {
410        argNames.zipWithIndex.withFilter {
411          case (name, i) => paramNames.contains(name)
412        }.map { case (name, i) =>
413          execs.map(_.gradArrays(i))
414        }
415      } else {
416        null
417      }
418
419    val dataNames = dataShapes.map(_.name)
420    inputGradArrays =
421      if (inputsNeedGrad) {
422        argNames.zipWithIndex.withFilter {
423          case (name, i) => dataNames.contains(name)
424        }.map { case (name, i) =>
425          execs.map(_.gradArrays(i))
426        }
427      } else {
428        null
429      }
430
431    auxArrays = (0 until auxNames.length).map(i => execs.map(_.auxArrays(i)))
432  }
433
434  /**
435   * Reshape executors.
436   * @param dataShapes
437   * @param labelShapes
438   */
439  def reshape(dataShapes: IndexedSeq[DataDesc], labelShapes: Option[IndexedSeq[DataDesc]]): Unit = {
440    if (!(dataShapes == this.dataShapes && labelShapes == this.labelShapes)) {
441      this.bindExec(dataShapes, labelShapes, None, reshape = true)
442    }
443  }
444
445  /**
446   * Assign, i.e. copy parameters to all the executors.
447   * @param argParams A dictionary of name to `NDArray` parameter mapping.
448   * @param auxParams A dictionary of name to `NDArray` auxiliary variable mapping.
449   * @param allowExtra hether allow extra parameters that are not needed by symbol.
450   *         If this is True, no error will be thrown when argParams or auxParams
451   *         contain extra parameters that is not needed by the executor.
452   */
453  def setParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray],
454    allowExtra: Boolean = false): Unit = {
455    execs.foreach(_.copyParamsFrom(argParams, auxParams, allowExtraParams = allowExtra))
456  }
457
458  /**
459   * Copy data from each executor to `arg_params` and `aux_params`.
460   * @param argParams target parameter arrays
461   * @param auxParams target aux arrays
462   * Note this function will inplace update the NDArrays in arg_params and aux_params.
463   */
464  def getParams(argParams: Map[String, NDArray], auxParams: Map[String, NDArray]): Unit = {
465    for ((name, block) <- paramNames.zip(paramArrays)) {
466      val weight = (block.map(_.copyTo(Context.cpu())).reduce((a: NDArray, b: NDArray) =>
467        (a + b).disposeDeps()
468      ) / block.length).disposeDeps()
469      val weightNewType = weight.asType(argParams(name).dtype)
470      weightNewType.copyTo(argParams(name))
471      weight.dispose()
472      weightNewType.dispose()
473    }
474    for ((name, block) <- auxNames.zip(auxArrays)) {
475      val weight = (block.map(_.copyTo(Context.cpu())).reduce((a: NDArray, b: NDArray) =>
476        (a + b).disposeDeps()
477      ) / block.length).disposeDeps()
478      val weightNewType = weight.asType(auxParams(name).dtype)
479      weightNewType.copyTo(auxParams(name))
480      weight.dispose()
481      weightNewType.dispose()
482    }
483  }
484
485  /**
486   * Split `dataBatch` according to workload and run forward on each devices.
487   * @param dataBatch
488   * @param isTrain The hint for the backend, indicating whether we are during training phase.
489   *                Default is `None`, then the value `self.for_training` will be used.
490   */
491  def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
492    DataParallelExecutorGroup.loadData(dataBatch, dataArrays, dataLayouts)
493    val isTrainOpt = isTrain.getOrElse(this.forTraining)
494    labelArrays.foreach(labels => {
495      require(!isTrainOpt || dataBatch.label != null, "label must be defined if in training phase")
496      if (dataBatch.label != null) {
497        require(labelLayouts != null, "label layouts are undefined")
498        DataParallelExecutorGroup.loadLabel(dataBatch, labels, labelLayouts)
499      }
500    })
501    execs.foreach(_.forward(isTrainOpt))
502  }
503
504  // Get the shapes of the outputs.
505  def getOutputShapes: IndexedSeq[(String, Shape)] = {
506    val outputs = execs(0).outputs
507    val shapes = outputs.map(_.shape)
508    (symbol.listOutputs() zip shapes zip outputLayouts) map { case ((key, theShape), axis) =>
509      val shape = theShape.toArray
510      if (axis >= 0) {
511        shape(axis) = batchSize
512      }
513      (key, Shape(shape))
514    }
515  }
516
517  /**
518   * Get outputs of the previous forward computation.
519   * @return In the case when data-parallelism is used,
520   *         the outputs will be collected from multiple devices.
521   *         The results will look like `[ [out1_dev1, out1_dev2], [out2_dev1, out2_dev2] ]`,
522   *         those `NDArray` might live on different devices.
523   */
524  def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
525    (0 until execs(0).outputs.length).map(i => execs.map(_.outputs(i)).toIndexedSeq)
526  }
527
528  /**
529   * Get outputs of the previous forward computation.
530   * @return In the case when data-parallelism is used,
531   *         the outputs will be merged from multiple devices,
532   *         as they look like from a single executor.
533   *         The results will look like `[out1, out2]`
534   */
535  def getOutputsMerged(): IndexedSeq[NDArray] = {
536    DataParallelExecutorGroup.mergeMultiContext(getOutputs(), outputLayouts)
537  }
538
539  /**
540   * Get the gradients to the inputs, computed in the previous backward computation.
541   * @return In the case when data-parallelism is used,
542   *         the grads will be collected from multiple devices.
543   *         The results will look like `[ [grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2] ]`,
544   *         those `NDArray` might live on different devices.
545   */
546  def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
547    require(inputsNeedGrad, "Cannot get InputGrads when inputNeedGrad is set to false")
548    inputGradArrays.map(_.toIndexedSeq)
549  }
550
551  /**
552   * Get the gradients to the inputs, computed in the previous backward computation.
553   * @return In the case when data-parallelism is used,
554   *         the grads will be merged from multiple devices,
555   *         as they look like from a single executor.
556   *         The results will look like `[grad1, grad2]`
557   */
558  def getInputGradsMerged(): IndexedSeq[NDArray] = {
559    DataParallelExecutorGroup.mergeMultiContext(getInputGrads(), dataLayouts)
560  }
561
562  /**
563   * Run backward on all devices. A backward should be called after
564   * a call to the forward function. Backward cannot be called unless
565   * `this.for_training` is `True`.
566   * @param outGrads Gradient on the outputs to be propagated back.
567   *                 This parameter is only needed when bind is called
568   *                 on outputs that are not a loss function.
569   */
570  def backward(outGrads: Array[NDArray] = null): Unit = {
571    require(forTraining, "re-bind with forTraining = true to run backward")
572
573    for (((exec, islice), i) <- (execs zip slices).zipWithIndex) {
574      val outGradsSlice =
575        if (outGrads != null) {
576          (outGrads zip outputLayouts).map { case (grad, axis) =>
577            if (axis >= 0) {
578              val ogMySlice: NDArray = NDArray.slice_axis(
579                Map("axis" -> axis, "begin" -> islice._1, "end" -> islice._2))(grad)
580              ogMySlice.asInContext(contexts(i))
581            } else {
582              grad.copyTo(contexts(i))
583            }
584          }
585        } else {
586          Array.empty[NDArray]
587        }
588      exec.backward(outGrads = outGradsSlice)
589    }
590  }
591
592  /**
593   * Accumulate the performance according to `eval_metric` on all devices.
594   * @param evalMetric The metric used for evaluation.
595   * @param labels Typically comes from `label` of a `DataBatch`.
596   */
597  def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
598    for ((texec, islice) <- this.execs zip this.slices) {
599      val labelsSlice =
600        (labels zip this.labelLayouts) map { case (label, axis) =>
601          if (axis == 0) {
602            label.slice(islice)
603          } else if (axis > 0) {
604            val labelMySlice: NDArray = NDArray.slice_axis(Map(
605              "axis" -> axis, "begin" -> islice._1, "end" -> islice._2))(label)
606              .asInContext(label.context)
607            labelMySlice
608          } else {
609            label
610          }
611        }
612
613      evalMetric.update(labelsSlice, texec.outputs)
614
615      // Clear up any slices we created (sometimes we don't slice so check for this)
616      (labels zip labelsSlice).foreach { case (label, labelSlice) =>
617        if (label ne labelSlice) {
618          labelSlice.dispose()
619        }
620      }
621    }
622  }
623
624  // Internal utility function to bind the i-th executor.
625  private def bindIthExec(i: Int, dataShapes: Seq[DataDesc],
626                          labelShapes: Option[Seq[DataDesc]],
627                          sharedGroup: Option[DataParallelExecutorGroup]): Executor = {
628    val dataShapesSliced = slicedShape(dataShapes, i, dataLayouts)
629    val labelShapesSliced = labelShapes.map(slicedShape(_, i, labelLayouts))
630    val sharedExec = sharedGroup.map(_.execs(i))
631    val context = contexts(i)
632    val sharedDataArrays = this.sharedDataArrays(i)
633
634    val inputShapes
635      = dataShapesSliced.toMap ++ labelShapesSliced.getOrElse(Map.empty[String, Shape])
636
637    val (argShapes, _, auxShapes) = symbol.inferShape(inputShapes)
638    require(argShapes != null, "Shape inference failed." +
639      s"Known shapes are $inputShapes for symbol arguments ${symbol.listArguments()} " +
640      s"and aux states ${symbol.listAuxiliaryStates()}")
641
642    val inputTypesGot = inputTypes.getOrElse(inputShapes.map { case (k, v) =>
643      (k, Base.MX_REAL_TYPE)
644    })
645    val (argTypes, _, auxTypes) = symbol.inferType(inputTypesGot)
646    require(argTypes != null, "Type inference failed." +
647      s"Known types as $inputTypes for symbol arguments ${symbol.listArguments()} " +
648      s"and aux states ${symbol.listAuxiliaryStates()}")
649
650    val argArrays = ArrayBuffer.empty[NDArray]
651    val gradArrayMap = mutable.HashMap.empty[String, NDArray]
652
653    // create or borrow arguments and gradients
654    for (j <- 0 until argNames.length) {
655      val name = argNames(j)
656      val argArr =
657        if (paramNames.contains(name)) {
658          // model parameter
659          sharedExec match {
660            case None =>
661              val argArr = NDArray.zeros(argShapes(j), context, dtype = argTypes(j))
662              if (gradReqRun(name) != "null") {
663                val gradArr = NDArray.zeros(argShapes(j), context, dtype = argTypes(j))
664                gradArrayMap.put(name, gradArr)
665              }
666              argArr
667            case Some(sharedExecInst) =>
668              val argArr = sharedExecInst.argDict(name)
669              require(argArr.shape == argShapes(j),
670                s"Shape ${argArr.shape} of argument $name does not match " +
671                  s"inferred shape ${argShapes(j)}")
672              require(argArr.dtype == argTypes(j),
673                s"Type ${argArr.dtype} of argument $name does not match " +
674                  s"inferred type ${argTypes(j)}")
675              if (gradReqRun(name) != "null") {
676                gradArrayMap.put(name, sharedExecInst.gradDict(name))
677              }
678              argArr
679          }
680        } else {
681          // data or label
682          val argArr = getOrReshape(name, sharedDataArrays, argShapes(j), argTypes(j), context)
683          // data might also need grad if inputs_need_grad is True
684          if (gradReqRun(name) != "null") {
685            gradArrayMap.put(name,
686              getOrReshape(s"grad of $name", sharedDataArrays, argShapes(j), argTypes(j), context))
687          }
688          argArr
689        }
690      argArrays.append(argArr)
691    }
692
693    // create or borrow aux variables
694    val auxArrays =
695      sharedExec match {
696        case None => (auxShapes zip auxTypes).map { case (s, t) =>
697          NDArray.zeros(s, context, dtype = t)
698        }.toArray
699        case Some(sharedExecInst) =>
700          for ((arr, j) <- sharedExecInst.auxArrays.zipWithIndex) {
701            require(auxShapes(j) == arr.shape,
702              s"Shape ${arr.shape} of aux variable ${auxNames(j)} does not match " +
703                s"inferred shape ${auxShapes(j)}")
704            require(auxTypes(j) == arr.dtype,
705              s"Type ${arr.dtype} of aux variable ${auxNames(j)} does not match " +
706                s"inferred type ${auxTypes(j)}")
707          }
708          sharedExecInst.auxArrays.map(identity)
709      }
710    symbol.bind(ctx = context, args = argArrays.toSeq, argsGrad = gradArrayMap.toMap,
711      gradsReq = gradReqRun, auxStates = auxArrays.toSeq, group2ctx = null,
712      sharedExec = sharedExec.orNull)
713  }
714
715  /**
716   * Get the sliced shapes for the i-th executor.
717   * @param shapes : The original (name, shape) pairs.
718   * @param i Which executor we are dealing with.
719   * @param majorAxis
720   */
721  private def slicedShape(shapes: Seq[DataDesc], i: Int, majorAxis: Seq[Int])
722    : Seq[(String, Shape)] = {
723    (shapes zip majorAxis).map { case (DataDesc(k, shape, _ , _), axis) =>
724      val shapeArr = shape.toArray
725      if (axis >= 0) {
726        shapeArr(axis) = slices(i)._2 - slices(i)._1
727      }
728      (k, Shape(shapeArr))
729    }
730  }
731
732  // Install monitor on all executors
733  def installMonitor(monitor: Monitor): Unit = {
734    execs.foreach(monitor.install)
735  }
736
737  // Internal helper to get a memory block or re-use by re-shaping
738  private def getOrReshape(name: String,
739                           sharedDataArrays: mutable.Map[String, NDArray],
740                           argShape: Shape,
741                           argType: DType,
742                           context: Context): NDArray = {
743    if (sharedDataArrays.contains(name)) {
744      val argArr = sharedDataArrays(name)
745      if (argArr.shape.product >= argShape.product) {
746        // nice, we can directly re-use this data blob
747        require(argArr.dtype == argType,
748          s"Type ${argArr.dtype} of argument $name does not match infered type ${argType}")
749        argArr.reshape(argShape)
750      } else {
751        DataParallelExecutorGroup.logger.warn(s"bucketing: data $name has a shape $argShape," +
752          s"which is larger than already allocated shape ${argArr.shape}." +
753          "Need to re-allocate. Consider putting default_bucket_key to be the bucket" +
754          "taking the largest input for better memory sharing.")
755        val argArrNew = NDArray.zeros(argShape, context, dtype = argType)
756        // replace existing shared array because the new one is bigger
757        sharedDataArrays.put(name, argArrNew)
758        argArrNew
759      }
760    } else {
761      val argArrNew = NDArray.zeros(argShape, context, dtype = argType)
762      sharedDataArrays.put(name, argArrNew)
763      argArrNew
764    }
765  }
766}
767