/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import org.apache.mxnet.{Context, Initializer, NDArray, Optimizer, Shape, Symbol, Uniform}
import org.slf4j.LoggerFactory

/**
 * Minimal wrapper that binds a [[Symbol]] to an executor and owns all of its buffers
 * (arguments, gradients, auxiliary states and optimizer states).
 *
 * @param symbol         network to bind
 * @param context        device on which every argument/gradient/aux array is allocated
 * @param dataShapes     input name -> shape; these arguments are fed by the caller via forward()
 * @param labelShapes    label name -> shape; fed by the caller and never given a gradient buffer
 * @param initializer    applied to every argument that is not a data or label input
 * @param forTraining    passed as `isTrain` to the executor's forward pass
 * @param inputsNeedGrad when true, gradient buffers are also allocated for the data inputs and
 *                       exposed through getInputGrads()
 */
class Module(symbol: Symbol,
             context: Context,
             dataShapes: Map[String, Shape],
             labelShapes: Map[String, Shape] = Map[String, Shape](),
             initializer: Initializer = new Uniform(0.01f),
             forTraining: Boolean = true,
             inputsNeedGrad: Boolean = false) {

  private val logger = LoggerFactory.getLogger(classOf[Module])

  // All caller-fed inputs (data + labels): excluded from initialization and from saveParams().
  private val dataLabelShape = dataShapes ++ labelShapes

  private val (argDict, gradDict, auxDict) = {
    val (argShapes, _, auxShapes) = symbol.inferShape(dataLabelShape)
    val argNames = symbol.listArguments()
    val argDict = argNames.zip(argShapes.map(NDArray.empty(_, context))).toMap

    // Labels never get gradients; data inputs only do when input gradients were requested.
    val filterShapes = if (inputsNeedGrad) labelShapes else dataLabelShape
    val gradDict = argNames.zip(argShapes).filter { case (name, _) =>
      !filterShapes.contains(name)
    }.map { case (name, shape) => name -> NDArray.empty(shape, context) }.toMap

    val auxDict = symbol.listAuxiliaryStates().zip(auxShapes.map(NDArray.empty(_, context))).toMap

    (argDict, gradDict, auxDict)
  }

  // NOTE(review): forward() copies inputs positionally in `dataShapes.keys` / `labelShapes.keys`
  // iteration order; for immutable Maps with more than four entries that order is hash-based,
  // so callers must supply arrays in that same order — confirm at call sites.
  private val dataArrs = dataShapes.keys.toArray.map(argDict(_))
  private val labelArrs = labelShapes.keys.toArray.map(argDict(_))
  // Left null (not empty) when input gradients were not requested; returned by getInputGrads().
  private val dataGrads = {
    if (inputsNeedGrad) dataShapes.keys.toArray.map(gradDict(_))
    else null
  }

  // Initialize learnable arguments only; data/label buffers are overwritten by forward().
  argDict.foreach { case (name, ndArray) =>
    if (!dataLabelShape.contains(name)) initializer(name, ndArray)
  }

  private val executor = symbol.bind(context, argDict, gradDict, "write", auxDict, null, null)

  private var optimizer: Optimizer = null
  // (index, argument name, gradient buffer, optimizer state) for every argument with a gradient.
  private var paramsGrads: List[(Int, String, NDArray, AnyRef)] = Nil
  private var optimizerInitialized: Boolean = false

  /**
   * Attaches `opt` and creates one optimizer state per argument that has a gradient buffer.
   * States from a previous call are released first so re-initializing does not leak them.
   */
  def initOptimizer(opt: Optimizer): Unit = {
    disposeOptimizerStates()
    this.optimizer = opt
    this.paramsGrads = gradDict.toList.zipWithIndex.map { case ((name, grad), idx) =>
      (idx, name, grad, opt.createState(idx, argDict(name)))
    }
    this.optimizerInitialized = true
  }

  /** Releases every optimizer state created by initOptimizer() and marks the optimizer unset. */
  private def disposeOptimizerStates(): Unit = {
    if (this.optimizer != null) {
      paramsGrads.foreach { case (_, _, _, state) => this.optimizer.disposeState(state) }
    }
    this.paramsGrads = Nil
    this.optimizerInitialized = false
  }

  /**
   * Copies `datas` and `labels` into the bound input buffers and runs the forward pass.
   * Arrays are matched positionally; surplus or missing elements are silently ignored (zip).
   */
  def forward(datas: Array[NDArray], labels: Array[NDArray] = Array[NDArray]()): Unit = {
    datas.zip(this.dataArrs).foreach { case (src, dest) => dest.set(src) }
    labels.zip(this.labelArrs).foreach { case (src, dest) => dest.set(src) }
    this.executor.forward(isTrain = forTraining)
  }

  /** Runs the executor's backward pass with `outGrads` (presumably grads w.r.t. the outputs). */
  def backward(outGrads: Array[NDArray]): Unit = {
    this.executor.backward(outGrads)
  }

  /** Applies one optimizer step to every argument that has a gradient buffer. */
  def update(): Unit = {
    assert(this.optimizerInitialized, "initOptimizer must be called before update()")
    paramsGrads.foreach { case (idx, name, grad, optimState) =>
      this.optimizer.update(idx, argDict(name), grad, optimState)
    }
  }

  /** Frees the executor, all argument/gradient/aux buffers and any optimizer states. */
  def dispose(): Unit = {
    this.executor.dispose()
    disposeOptimizerStates()
    this.argDict.foreach(_._2.dispose())
    this.gradDict.foreach(_._2.dispose())
    this.auxDict.foreach(_._2.dispose())
  }

  /**
   * Copies each entry of `params` into the argument or auxiliary buffer of the same name.
   * Names matching neither are logged and ignored. The source arrays are copied, not retained.
   */
  def setParams(params: Map[String, NDArray]): Unit = {
    params.foreach { case (name, arr) =>
      if (this.argDict.contains(name)) {
        this.argDict(name).set(arr)
      }
      else if (this.auxDict.contains(name)) {
        this.auxDict(name).set(arr)
      }
      else logger.info(name)
    }
  }

  /**
   * Loads parameters saved by saveParams(): keys of the form `arg:<name>` or `aux:<name>`.
   * Keys with any other prefix, or without a `:` separator, are skipped. The loaded arrays
   * are freed once their contents have been copied into this module's buffers.
   */
  def loadParams(fName: String): Unit = {
    val saveDict = NDArray.load2Map(fName)
    try {
      val params = saveDict.toSeq.flatMap { case (key, arr) =>
        // Split only on the first ':' so names that themselves contain ':' stay intact.
        key.split(":", 2) match {
          case Array(argType, name) if argType == "arg" || argType == "aux" => Seq(name -> arr)
          case _ => Seq.empty
        }
      }.toMap
      this.setParams(params)
    } finally {
      // setParams() copies via set(), so the freshly loaded arrays are no longer needed.
      saveDict.values.foreach(_.dispose())
    }
  }

  /** Saves every non-input argument as `arg:<name>` and every aux state as `aux:<name>`. */
  def saveParams(fName: String): Unit = {
    val saveDict = {
      argDict.collect { case (k, v) if !dataLabelShape.contains(k) => s"arg:$k" -> v } ++
        auxDict.map { case (k, v) => s"aux:$k" -> v }
    }
    NDArray.save(fName, saveDict)
  }

  /** Output arrays of the bound executor. */
  def getOutputs(): Array[NDArray] = this.executor.outputs

  /** Gradient buffers of the data inputs, or null when `inputsNeedGrad` was false. */
  def getInputGrads(): Array[NDArray] = this.dataGrads
}