1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import org.apache.mxnet.{Context, Initializer, NDArray, Optimizer, Shape, Symbol, Uniform}
21import org.slf4j.LoggerFactory
22
/**
 * A minimal training/inference module wrapping a bound MXNet [[Symbol]] executor.
 *
 * On construction it infers shapes from `dataShapes ++ labelShapes`, allocates
 * argument/gradient/auxiliary NDArrays on `context`, initializes all non-input
 * arguments with `initializer`, and binds a single executor.
 *
 * @param symbol         the network symbol to bind
 * @param context        device context for all allocated NDArrays
 * @param dataShapes     name -> shape for data inputs
 * @param labelShapes    name -> shape for label inputs (may be empty for inference)
 * @param initializer    weight initializer for non-input arguments
 * @param forTraining    whether forward passes run in training mode
 * @param inputsNeedGrad if true, gradients are also allocated for data inputs
 *                       (retrievable via [[getInputGrads]])
 */
class Module(symbol: Symbol,
             context: Context,
             dataShapes: Map[String, Shape],
             labelShapes: Map[String, Shape] = Map[String, Shape](),
             initializer: Initializer = new Uniform(0.01f),
             forTraining: Boolean = true,
             inputsNeedGrad: Boolean = false) {

  private val logger = LoggerFactory.getLogger(classOf[Module])

  // Combined input shapes: both data and labels are fed through argDict.
  private val dataLabelShape = dataShapes ++ labelShapes

  // Allocate argument, gradient and auxiliary-state arrays from inferred shapes.
  private val (argDict, gradDict, auxDict) = {
    // Output shapes are not needed here, hence the wildcard.
    val (argShapes, _, auxShapes) = symbol.inferShape(dataLabelShape)
    val argNames = symbol.listArguments()
    val argDict = argNames.zip(argShapes.map(NDArray.empty(_, context))).toMap

    // Inputs normally get no gradient buffer; when inputsNeedGrad is set,
    // only labels are excluded so data inputs do receive gradients.
    val filterShapes = if (inputsNeedGrad) labelShapes else dataLabelShape
    val gradDict = argNames.zip(argShapes).filter { case (name, _) =>
      !filterShapes.contains(name)
    }.map { case (name, shape) => name -> NDArray.empty(shape, context) }.toMap

    val auxDict = symbol.listAuxiliaryStates().zip(auxShapes.map(NDArray.empty(_, context))).toMap

    (argDict, gradDict, auxDict)
  }

  // Views onto argDict for fast copying of inputs in forward().
  private val dataArrs = dataShapes.keys.toArray.map(argDict(_))
  private val labelArrs = labelShapes.keys.toArray.map(argDict(_))
  // Gradient buffers for data inputs; null (not empty) when gradients were not
  // requested — callers of getInputGrads depend on this convention.
  private val dataGrads = {
    if (inputsNeedGrad) dataShapes.keys.toArray.map(gradDict(_))
    else null
  }

  // Initialize every argument that is not a data/label input (i.e. the weights).
  argDict.foreach { case (name, ndArray) =>
    if (!dataLabelShape.contains(name)) initializer(name, ndArray)
  }

  // Single bound executor; gradients accumulate with "write" semantics.
  private val executor = symbol.bind(context, argDict, gradDict, "write", auxDict, null, null)

  private var optimizer: Optimizer = null
  private var paramsGrads: List[(Int, String, NDArray, AnyRef)] = null
  private var optimizerInitialized: Boolean = false

  /**
   * Installs the optimizer and creates per-parameter optimizer state.
   * Must be called before [[update]].
   */
  def initOptimizer(opt: Optimizer): Unit = {
    this.optimizer = opt
    this.paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
      (idx, name, grad, this.optimizer.createState(idx, argDict(name)))
    }
    this.optimizerInitialized = true
  }

  /**
   * Copies the given data (and optional label) batches into the bound input
   * arrays and runs a forward pass. Arrays are matched to inputs positionally.
   */
  def forward(datas: Array[NDArray], labels: Array[NDArray] = Array[NDArray]()): Unit = {
    datas.zip(this.dataArrs).foreach { case (src, dest) => dest.set(src) }
    labels.zip(this.labelArrs).foreach { case (src, dest) => dest.set(src) }
    this.executor.forward(isTrain = forTraining)
  }

  /** Runs a backward pass with the given head gradients. */
  def backward(outGrads: Array[NDArray]): Unit = {
    this.executor.backward(outGrads)
  }

  /** Applies one optimizer step to every parameter with a gradient buffer. */
  def update(): Unit = {
    assert(this.optimizerInitialized)
    paramsGrads.foreach { case (idx, name, grad, optimState) =>
      this.optimizer.update(idx, argDict(name), grad, optimState)
    }
  }

  /** Releases the executor and all NDArrays owned by this module. */
  def dispose(): Unit = {
    this.executor.dispose()
    this.argDict.foreach(_._2.dispose())
    this.gradDict.foreach(_._2.dispose())
    this.auxDict.foreach(_._2.dispose())
  }

  /**
   * Copies values into matching argument/auxiliary arrays by name.
   * Names present in neither dictionary are logged and skipped.
   */
  def setParams(params: Map[String, NDArray]): Unit = {
    params.foreach { case (name, arr) =>
      if (this.argDict.contains(name)) {
        this.argDict(name).set(arr)
      }
      else if (this.auxDict.contains(name)) {
        this.auxDict(name).set(arr)
      }
      else logger.info(name)
    }
  }

  /**
   * Loads parameters from a file saved in MXNet's "arg:name"/"aux:name" key
   * format and applies them via [[setParams]]. Keys with any other prefix, or
   * without a prefix, are ignored.
   */
  def loadParams(fName: String): Unit = {
    val saveDict = NDArray.load2Map(fName)
    var params = Map[String, NDArray]()
    saveDict.foreach { case (k, v) =>
      // Split on the first ':' only, so parameter names containing ':' survive;
      // a key with no ':' yields a single element and is safely skipped below.
      val tmp = k.split(":", 2)
      if (tmp.length == 2 && (tmp(0) == "arg" || tmp(0) == "aux")) {
        params += tmp(1) -> v
      }
    }
    this.setParams(params)
  }

  /** Saves all weights (arguments minus inputs) and auxiliary states. */
  def saveParams(fName: String): Unit = {
    val saveDict = {
      argDict.filter(x => !dataLabelShape.contains(x._1))
      .map { case (k, v) => s"arg:$k" -> v } ++
      auxDict.map { case (k, v) => s"aux:$k" -> v }
    }
    NDArray.save(fName, saveDict)
  }

  /** Outputs of the most recent forward pass. */
  def getOutputs(): Array[NDArray] = this.executor.outputs

  /** Gradients w.r.t. data inputs, or null if `inputsNeedGrad` was false. */
  def getInputGrads(): Array[NDArray] = this.dataGrads
}
138