1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.gan
19
20import org.apache.mxnet.{Context, CustomMetric, DataBatch, IO, NDArray, ResourceScope, Shape, Symbol, Xavier}
21import org.apache.mxnet.optimizer.Adam
22import org.kohsuke.args4j.{CmdLineParser, Option}
23import org.slf4j.LoggerFactory
24
25import scala.collection.JavaConverters._
26
/**
 * DCGAN-style example on MNIST: builds the generator and discriminator symbols,
 * trains them through a GANModule and periodically saves sample images.
 */
object GanMnist {

  private val logger = LoggerFactory.getLogger(classOf[GanMnist])

  /**
   * Builds a deconvolution layer that enlarges the feature map.
   *
   * @param data   input symbol
   * @param iShape declared input shape; currently not read by this method
   * @param oShape output shape; element 0 is used as the number of filters and the
   *               last two elements as the deconvolution target shape
   * @param kShape kernel size
   * @param name   name given to the Deconvolution symbol
   * @param stride stride, (2, 2) by default
   * @return the Deconvolution symbol, created with no_bias = true
   */
  def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
               kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
    val targetShape = Shape(oShape(oShape.length - 2), oShape(oShape.length - 1))
    Symbol.api.Deconvolution(data = Some(data), kernel = Shape(kShape._1, kShape._2),
      stride = Some(Shape(stride._1, stride._2)), target_shape = Some(targetShape),
      num_filter = oShape(0), no_bias = Some(true), name = name)
  }

  /**
   * Deconvolution followed by BatchNorm (fix_gamma = true) and a ReLU activation.
   * The three symbols are named "&lt;prefix&gt;_deconv", "&lt;prefix&gt;_bn" and "&lt;prefix&gt;_act".
   *
   * @param data   input symbol
   * @param prefix name prefix for the created symbols
   * @param iShape forwarded to [[deconv2D]]
   * @param oShape forwarded to [[deconv2D]]
   * @param kShape kernel size, forwarded to [[deconv2D]]
   * @param eps    epsilon passed to BatchNorm
   * @return the activation symbol
   */
  def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
                     oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
    val deconv = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
    val normed = Symbol.api.BatchNorm(name = s"${prefix}_bn", data = Some(deconv),
      fix_gamma = Some(true), eps = Some(eps))
    Symbol.api.Activation(data = Some(normed), act_type = "relu", name = s"${prefix}_act")
  }

  /**
   * Deconvolution followed by an activation of the requested type.
   * The two symbols are named "&lt;prefix&gt;_deconv" and "&lt;prefix&gt;_act".
   *
   * @param data    input symbol
   * @param prefix  name prefix for the created symbols
   * @param actType activation type handed to the Activation operator (e.g. "sigmoid")
   * @param iShape  forwarded to [[deconv2D]]
   * @param oShape  forwarded to [[deconv2D]]
   * @param kShape  kernel size, forwarded to [[deconv2D]]
   * @return the activation symbol
   */
  def deconv2DAct(data: Symbol, prefix: String, actType: String,
                  iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
    val deconv = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
    // Honour the caller's activation type; it used to be hard-coded to "relu",
    // which silently discarded the actType argument.
    Symbol.api.Activation(data = Some(deconv), act_type = actType, name = s"${prefix}_act")
  }

  /**
   * Builds the two symbols used by the example.
   *
   * The first symbol starts from the variable "rand" and goes through a fully connected
   * layer, a reshape to (ngf * 4, 4, 4) and three deconvolution stages. The second starts
   * from the variable "data" and stacks two conv/tanh/max-pool stages, two fully connected
   * layers and a LogisticRegressionOutput named "dloss".
   *
   * @param oShape   shape of the generated output; its last three dimensions are used
   * @param ngf      base number of generator filters
   * @param finalAct activation type of the last generator layer
   * @param eps      epsilon forwarded to the BatchNorm layers of the generator
   * @return (generator output symbol, discriminator loss symbol)
   */
  def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
                   eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {

    val code = Symbol.Variable("rand")
    // NOTE(review): the layer name below literally starts with a space. It is kept
    // byte-for-byte because it is a runtime name; confirm whether "g1" was intended.
    val g1 = Symbol.api.FullyConnected(data = Some(code), num_hidden = 4 * 4 * ngf * 4,
      no_bias = Some(true), name = " g1")
    val gact1 = Symbol.api.Activation(data = Some(g1), act_type = "relu", name = "gact1")
    // 4 x 4
    val g1Maps = Symbol.api.Reshape(data = Some(gact1), shape = Some(Shape(-1, ngf * 4, 4, 4)))
    // 8 x 8
    val g2 = deconv2DBnRelu(g1Maps, prefix = "g2",
      iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = (3, 3),
      eps = eps)
    // 14x14
    val g3 = deconv2DBnRelu(g2, prefix = "g3",
      iShape = Shape(ngf * 2, 8, 8), oShape = Shape(ngf, 14, 14), kShape = (4, 4),
      eps = eps)
    // 28x28
    val gout = deconv2DAct(g3, prefix = "g4", actType = finalAct, iShape = Shape(ngf, 14, 14),
      oShape = Shape(oShape.toArray.takeRight(3)), kShape = (4, 4))

    val data = Symbol.Variable("data")
    // 28 x 28
    val conv1 = Symbol.api.Convolution(data = Some(data), kernel = Shape(5, 5),
      num_filter = 20, name = "conv1")
    val tanh1 = Symbol.api.Activation(data = Some(conv1), act_type = "tanh")
    val pool1 = Symbol.api.Pooling(data = Some(tanh1), pool_type = Some("max"),
      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
    // second conv
    val conv2 = Symbol.api.Convolution(data = Some(pool1), kernel = Shape(5, 5),
      num_filter = 50, name = "conv2")
    val tanh2 = Symbol.api.Activation(data = Some(conv2), act_type = "tanh")
    val pool2 = Symbol.api.Pooling(data = Some(tanh2), pool_type = Some("max"),
      kernel = Some(Shape(2, 2)), stride = Some(Shape(2, 2)))
    val flat = Symbol.api.Flatten(data = Some(pool2))
    val fc1 = Symbol.api.FullyConnected(data = Some(flat), num_hidden = 500, name = "fc1")
    val tanh3 = Symbol.api.Activation(data = Some(fc1), act_type = "tanh")
    val fcDloss = Symbol.api.FullyConnected(data = Some(tanh3), num_hidden = 1,
      name = "fc_dloss")
    val dloss = Symbol.api.LogisticRegressionOutput(data = Some(fcDloss), name = "dloss")

    (gout, dloss)
  }

  /**
   * Evaluation function used by the custom metric.
   *
   * Predictions are thresholded at 0.5 to 0/1, the absolute differences to the labels
   * are summed, and the sum is divided by the first dimension of `label`.
   *
   * @param label label array
   * @param pred  prediction array
   * @return summed absolute error divided by label.shape(0)
   */
  def ferr(label: NDArray, pred: NDArray): Float = {
    val hardPred = pred.toArray.map(p => if (p > 0.5) 1f else 0f)
    val totalError = label.toArray.zip(hardPred).map { case (l, p) => Math.abs(l - p) }.sum
    totalError / label.shape(0)
  }

  /**
   * Trains the GAN on MNIST and periodically writes sample images.
   *
   * Every 50th iteration the metric is read and logged, and the images "gout", "diff"
   * and "data" are saved through Viz.imSave.
   *
   * @param dataPath   directory containing train-images-idx3-ubyte and
   *                   train-labels-idx1-ubyte
   * @param context    device context handed to the GANModule
   * @param outputPath directory passed to Viz.imSave
   * @param numEpoch   number of passes over the MNIST iterator
   * @return the metric value read at the most recent logging step, or 0 if the metric
   *         was never read
   */
  def runTraining(dataPath : String, context : Context,
                  outputPath : String, numEpoch : Int): Float = {
    ResourceScope.using() {
      val lr = 0.0005f
      val beta1 = 0.5f
      val batchSize = 100
      val randShape = Shape(batchSize, 100)
      val dataShape = Shape(batchSize, 1, 28, 28)

      val (symGen, symDec) =
        makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")

      val gMod = new GANModule(
        symGen,
        symDec,
        context = context,
        dataShape = dataShape,
        codeShape = randShape)

      gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
      gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))

      gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))

      val params = Map(
        "image" -> s"$dataPath/train-images-idx3-ubyte",
        "label" -> s"$dataPath/train-labels-idx1-ubyte",
        "input_shape" -> "(1, 28, 28)",
        "batch_size" -> s"$batchSize",
        "shuffle" -> "True"
      )

      val mnistIter = IO.MNISTIter(params)

      val metricAcc = new CustomMetric(ferr, "ferr")

      var acc = 0.0f
      for (epoch <- 0 until numEpoch) {
        mnistIter.reset()
        metricAcc.reset()
        // iteration counter, restarted for every epoch
        var t = 0
        while (mnistIter.hasNext) {
          val dataBatch = mnistIter.next()
          ResourceScope.using() {
            gMod.update(dataBatch)
            // score the fake outputs against label 0, then the real outputs against label 1
            gMod.dLabel.set(0f)
            metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
            gMod.dLabel.set(1f)
            metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)

            if (t % 50 == 0) {
              val (_, value) = metricAcc.get
              acc = value(0)
              logger.info(s"epoch: $epoch, iter $t, metric=${value.mkString(" ")}")
              Viz.imSave("gout", outputPath, gMod.tempOutG(0), flip = true)
              val diff = gMod.tempDiffD
              val arr = diff.toArray
              val mean = arr.sum / arr.length
              val std = {
                val squaredDev = arr.map(a => (a - mean) * (a - mean))
                Math.sqrt(squaredDev.sum / squaredDev.length).toFloat
              }
              // A constant diff has zero deviation; dividing by zero would produce
              // non-finite pixel values, so fall back to a scale of 1 in that case.
              val scale = if (std > 0f) std else 1f
              // standardize the diff and shift it by 0.5 before saving it as an image
              diff.set((diff - mean) / scale + 0.5f)
              Viz.imSave("diff", outputPath, diff, flip = true)
              Viz.imSave("data", outputPath, dataBatch.data(0), flip = true)
            }
          }
          dataBatch.dispose()
          t += 1
        }
      }
      acc
    }
  }

  /**
   * Command-line entry point: parses the options declared on the GanMnist class and
   * runs 100 epochs of training.
   *
   * The data path is taken from --mnist-data-path, falling back to the MXNET_HOME
   * environment variable. On any exception the error is logged, the usage is printed
   * to stderr and the process exits with status 1.
   */
  def main(args: Array[String]): Unit = {
    val anst = new GanMnist
    val parser: CmdLineParser = new CmdLineParser(anst)
    try {
      parser.parseArgument(args.toList.asJava)

      val dataPath = scala.Option(anst.mnistDataPath).getOrElse(System.getenv("MXNET_HOME"))

      // require (not assert): it cannot be elided and throws IllegalArgumentException,
      // which the handler below catches so the usage message is actually shown.
      require(dataPath != null,
        "MNIST data path is not set: pass --mnist-data-path or define MXNET_HOME")
      val context = if (anst.gpu == -1) Context.cpu() else Context.gpu(anst.gpu)

      runTraining(dataPath, context, anst.outputPath, 100)
    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}
205
/**
 * Holder for the command-line options of the GanMnist example.
 *
 * An instance is handed to args4j's CmdLineParser in GanMnist.main, which fills the
 * annotated fields from the program arguments; the values below are the defaults used
 * when an option is not given.
 */
class GanMnist {
  // Directory with the MNIST files; when left null, main falls back to the
  // MXNET_HOME environment variable.
  @Option(name = "--mnist-data-path", usage = "the mnist data path")
  private val mnistDataPath: String = null
  // Directory that main forwards to runTraining as the image output path.
  @Option(name = "--output-path", usage = "the path to save the generated result")
  private val outputPath: String = null
  // GPU index; -1 makes main select Context.cpu() instead of Context.gpu(gpu).
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
}
214