/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.gan

import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, ResourceScope, Shape, Symbol, Xavier}
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

/**
 * GAN example trained on the MNIST training set.
 *
 * [[makeDcganSym]] builds the generator and discriminator symbols, [[runTraining]] trains
 * them through `GANModule`, and [[main]] is the command-line entry point.
 *
 * Note: this file imports args4j's `Option` annotation, which shadows `scala.Option` here.
 */
object GanMnist {

  private val logger = LoggerFactory.getLogger(classOf[GanMnist])

  /**
   * A Deconvolution layer that enlarges the feature map.
   *
   * The output height/width are pinned via `target_shape` to the last two dimensions of
   * `oShape`, and the number of filters is `oShape(0)`. No bias term is used.
   *
   * @param data   input symbol
   * @param iShape input shape; not read here, kept so all deconv helpers share one signature
   * @param oShape desired output shape: first dim = filter count, last two = height, width
   * @param kShape kernel size as (height, width)
   * @param name   name given to the Deconvolution symbol
   * @param stride stride as (height, width); (2, 2) by default
   * @return the Deconvolution symbol
   */
  def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
               kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
    val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 1))
    Symbol.api.Deconvolution(data = Some(data), kernel = Shape(kShape._1, kShape._2),
      stride = Some(Shape(stride._1, stride._2)), target_shape = Some(targetShape),
      num_filter = oShape(0), no_bias = Some(true), name = name)
  }

  /**
   * Deconvolution -> BatchNorm (fixed gamma) -> ReLU.
   *
   * The symbols are named `<prefix>_deconv`, `<prefix>_bn` and `<prefix>_act`.
   *
   * @param eps epsilon passed to BatchNorm
   * @return the ReLU activation symbol
   */
  def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
                     oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
    val deconv = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
    val normalized = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(deconv),
      fix_gamma = Some(true), eps = Some(eps))
    Symbol.api.Activation(data = Some(normalized), act_type = "relu", name = s"${prefix}_act")
  }

  /**
   * Deconvolution followed by an activation of type `actType`.
   *
   * The symbols are named `<prefix>_deconv` and `<prefix>_act`.
   *
   * @param actType activation type passed to `Activation` (e.g. "sigmoid", "relu")
   * @return the activation symbol
   */
  def deconv2DAct(data: Symbol, prefix: String, actType: String,
                  iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
    val deconv = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
    // Use the requested activation; this was previously hard-coded to "relu".
    Symbol.api.Activation(data = Some(deconv), act_type = actType, name = s"${prefix}_act")
  }

  /**
   * Builds the generator and discriminator symbols.
   *
   * Generator: input variable "rand" -> FullyConnected -> ReLU -> reshape to
   * (-1, ngf * 4, 4, 4) -> deconv/BN/ReLU to 8x8 -> deconv/BN/ReLU to 14x14 ->
   * deconv + `finalAct` to the last three dimensions of `oShape`.
   *
   * Discriminator: input variable "data" -> two conv/tanh/max-pool stages ->
   * FullyConnected(500) -> tanh -> FullyConnected(1) -> LogisticRegressionOutput "dloss".
   *
   * @param oShape   shape of the real images; only its last three dims (channels, height,
   *                 width) are used for the generator's final layer
   * @param ngf      base filter count; the generator stages use ngf * 4, ngf * 2, ngf
   * @param finalAct activation type applied to the generator output
   * @param eps      epsilon forwarded to the generator's BatchNorm layers
   * @return (generator output symbol, discriminator loss symbol)
   */
  def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
                   eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
    // ---- generator ----
    val code = Symbol.Variable("rand")
    val projected = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 * ngf * 4,
      no_bias = Some(true), name = "g1")
    val projectedAct = Symbol.api.Activation(data = Some(projected), act_type = "relu",
      name = "gact1")
    // 4 x 4
    val map4 = Symbol.api.Reshape(data = Some(projectedAct),
      shape = Some(Shape(-1, ngf * 4, 4, 4)))
    // 8 x 8
    val map8 = deconv2DBnRelu(map4, prefix = "g2",
      iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = (3, 3), eps = eps)
    // 14 x 14
    val map14 = deconv2DBnRelu(map8, prefix = "g3",
      iShape = Shape(ngf * 2, 8, 8), oShape = Shape(ngf, 14, 14), kShape = (4, 4), eps = eps)
    // 28 x 28 for MNIST-sized oShape
    val gout = deconv2DAct(map14, prefix = "g4", actType = finalAct,
      iShape = Shape(ngf, 14, 14), oShape = Shape(oShape.toArray.takeRight(3)),
      kShape = (4, 4))

    // ---- discriminator ----
    val data = Symbol.Variable("data")
    // first conv stage
    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
      num_filter = 20, name = "conv1")
    val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
    // second conv stage
    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5),
      num_filter = 50, name = "conv2")
    val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
    // classifier head
    val flat = Symbol.api.Flatten(data = Some(pool2))
    val fc1 = Symbol.api.FullyConnected(data = Some(flat), num_hidden = 500, name = "fc1")
    val tanh3 = Symbol.api.Activation(data = Some(fc1), act_type = "tanh")
    val score = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 1,
      name = "fc_dloss")
    val dloss = Symbol.api.LogisticRegressionOutput(data = Some(score), name = "dloss")

    (gout, dloss)
  }

  /**
   * Evaluation metric: fraction of mismatched predictions.
   *
   * Predictions are thresholded at 0.5 to 0/1. The summed absolute difference to the labels
   * is divided by `label.shape(0)`. In [[runTraining]] the labels are all 0 or all 1, so
   * this is the misclassification rate.
   * NOTE(review): assumes one prediction per label entry.
   */
  def ferr(label: NDArray, pred: NDArray): Float = {
    val predicted = pred.toArray.map(p => if (p > 0.5) 1f else 0f)
    val totalError = label.toArray.zip(predicted).map { case (l, p) => Math.abs(l - p) }.sum
    totalError / label.shape(0)
  }

  /**
   * Trains the GAN on the MNIST training files found under `dataPath`.
   *
   * Every 50 iterations of each epoch it:
   *  - logs the `ferr` metric;
   *  - saves "gout", "diff" and "data" images to `outputPath` via `Viz.imSave`.
   *
   * @param dataPath   directory containing train-images-idx3-ubyte / train-labels-idx1-ubyte
   * @param context    device to train on
   * @param outputPath directory passed to `Viz.imSave`
   * @param numEpoch   number of passes over the training set
   * @return the first metric value at the last logging step (0.0 if nothing was logged)
   */
  def runTraining(dataPath : String, context : Context,
                  outputPath : String, numEpoch : Int): Float = {
    ResourceScope.using() {
      val lr = 0.0005f
      val beta1 = 0.5f
      val batchSize = 100
      val randShape = Shape(batchSize, 100)
      val dataShape = Shape(batchSize, 1, 28, 28)

      val (symGen, symDec) =
        makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

      val gMod = new GANModule(
        symGen,
        symDec,
        context = context,
        dataShape = dataShape,
        codeShape = randShape)

      gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
      gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
      gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

      val params = Map(
        "image" -> s"$dataPath/train-images-idx3-ubyte",
        "label" -> s"$dataPath/train-labels-idx1-ubyte",
        "input_shape" -> s"(1, 28, 28)",
        "batch_size" -> s"$batchSize",
        "shuffle" -> "True"
      )
      val mnistIter = IO.MNISTIter(params)
      val metricAcc = new CustomMetric(ferr, "ferr")

      var acc = 0.0f
      for (epoch <- 0 until numEpoch) {
        mnistIter.reset()
        metricAcc.reset()
        var t = 0
        while (mnistIter.hasNext) {
          val dataBatch: DataBatch = mnistIter.next()
          ResourceScope.using() {
            gMod.update(dataBatch)
            // Score the fake outputs against label 0 and the real outputs against label 1.
            gMod.dLabel.set(0f)
            metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
            gMod.dLabel.set(1f)
            metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

            if (t % 50 == 0) {
              val (_, value) = metricAcc.get
              acc = value(0)
              logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
              Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)

              // Standardise the gradient map and shift it to be centred on 0.5 for display.
              // NOTE(review): tempDiffD is presumably the discriminator's input gradient.
              val diff = gMod.tempDiffD
              val values = diff.toArray
              val mean = values.sum / values.length
              val std = Math.sqrt(
                values.map(v => (v - mean) * (v - mean)).sum / values.length).toFloat
              // A constant map has std == 0; avoid dividing by zero (NaN/Inf image).
              val scale = if (std > 0f) std else 1f
              diff.set((diff - mean) / scale + 0.5f)
              Viz.imSave("diff", outputPath, diff, flip = true)
              Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
            }
          }
          dataBatch.dispose()
          t += 1
        }
      }
      acc
    }
  }

  /**
   * Command-line entry point.
   *
   * Parses [[GanMnist]] options and falls back to `MXNET_HOME` when `--mnist-data-path` is
   * absent. It then trains for 100 epochs. On any `Exception` it logs the error, prints
   * usage and exits with status 1.
   */
  def main(args: Array[String]): Unit = {
    val anst = new GanMnist
    val parser: CmdLineParser = new CmdLineParser(anst)
    try {
      parser.parseArgument(args.toList.asJava)

      val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_HOME")
      else anst.mnistDataPath
      // `require` throws IllegalArgumentException, which the handler below catches.
      // `assert` would throw an AssertionError (an Error), skipping the usage message.
      require(dataPath != null,
        "no MNIST data path: pass --mnist-data-path or set MXNET_HOME")

      val context = if (anst.gpu == -1) Context.cpu() else Context.gpu(anst.gpu)

      runTraining(dataPath, context, anst.outputPath, 100)
    } catch {
      case ex: Exception =>
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
    }
  }
}

/** Command-line options for [[GanMnist.main]], populated by args4j's `CmdLineParser`. */
class GanMnist {
  @Option(name = "--mnist-data-path", usage = "the mnist data path")
  private val mnistDataPath: String = null
  @Option(name = "--output-path", usage = "the path to save the generated result")
  private val outputPath: String = null
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
}