1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import java.io.File
21
22import org.apache.mxnet.{Context, Executor, NDArray, ResourceScope, Shape, Symbol}
23import org.apache.mxnet.optimizer.SGD
24import org.kohsuke.args4j.{CmdLineParser, Option}
25import org.slf4j.LoggerFactory
26
27import scala.collection.JavaConverters._
28import scala.util.Random
29
30
31object BoostTrain {
32
33  private val logger = LoggerFactory.getLogger(classOf[BoostTrain])
34
35  def getTvGradExecutor(img: NDArray, ctx: Context, tvWeight: Float): Executor = {
36    // create TV gradient executor with input binded on img
37    if (tvWeight <= 0.0f) null
38
39    val nChannel = img.shape(1)
40    val sImg = Symbol.Variable("img")
41    val sKernel = Symbol.Variable("kernel")
42    val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs = nChannel)
43    val toConcat = (0 until nChannel).map( i =>
44      Symbol.api.Convolution(data = Some(channels.get(i)), weight = Some(sKernel),
45        num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)),
46        no_bias = Some(true), stride = Some(Shape(1, 1)))
47    ).toArray
48    val out = Symbol.api.Concat(data = toConcat, num_args = toConcat.length) * tvWeight
49    val kernel = {
50      val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
51      tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
52      tmp / 8.0f
53    }
54    out.bind(ctx, Map("img" -> img, "kernel" -> kernel))
55  }
56
  /**
   * Runs end-to-end training of the boosted neural-style generator chain.
   *
   * Loads pretrained VGG parameters, computes fixed style (gram) targets from
   * the style image once, then trains four chained generator modules
   * (GenV4, GenV3, GenV3, GenV4) for `endEpoch` epochs, where each generator
   * consumes the previous generator's output image.
   *
   * @param dataPath      directory containing the content training images
   * @param vggModelPath  path to the pretrained VGG parameter file
   * @param ctx           device context (CPU or GPU) for all computation
   * @param styleImage    path of the style image providing the gram targets
   * @param saveModelPath directory under which per-generator checkpoints go
   */
  def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
                  styleImage : String, saveModelPath : String) : Unit = {
    // ResourceScope auto-disposes NDArrays allocated inside when the block exits
    ResourceScope.using() {
      // params
      val vggParams = NDArray.load2Map(vggModelPath)
      val styleWeight = 1.2f
      val contentWeight = 10f
      val dShape = Shape(1, 3, 384, 384)
      // gradient-clipping threshold, scaled by the number of image elements
      val clipNorm = 0.05f * dShape.product
      val modelPrefix = "v3"
      // init style: run the style image through the style module once and keep
      // the outputs (copied to CPU) as fixed regression targets, then free it
      val styleNp = DataProcessing.preprocessStyleImage(styleImage, dShape, ctx)
      var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
      styleMod.forward(Array(styleNp))
      val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
      styleMod.dispose()
      styleMod = null

      // content module: extracts the content feature used as the content target
      val contentMod = Basic.getContentModule("content", dShape, ctx, vggParams)

      // loss module plus per-output gradient scales
      val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
      val extraArgs = (0 until styleArray.length)
        .map(i => s"target_gram_$i" -> styleArray(i)).toMap
      loss.setParams(extraArgs)
      // head gradients fed to loss.backward: one weighted entry per style
      // output, followed by one entry for the content loss
      var gradArray = Array[NDArray]()
      for (i <- 0 until styleArray.length) {
        gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight / gScale(i)))
      }
      gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)

      // generator chain: g0..g3 (V4, V3, V3, V4 variants), applied in sequence
      val gens = Array(
        GenV4.getModule("g0", dShape, ctx),
        GenV3.getModule("g1", dShape, ctx),
        GenV3.getModule("g2", dShape, ctx),
        GenV4.getModule("g3", dShape, ctx)
      )
      // each generator gets its own SGD optimizer instance
      gens.foreach { gen =>
        val opt = new SGD(learningRate = 1e-4f,
          momentum = 0.9f,
          wd = 5e-3f,
          clipGradient = 5f)
        gen.initOptimizer(opt)
      }

      var filelist = new File(dataPath).list().toList
      val numImage = filelist.length
      logger.info(s"Dataset size: $numImage")

      val tvWeight = 1e-2f

      val startEpoch = 0
      val endEpoch = 3

      // one checkpoint directory per generator index
      for (k <- 0 until gens.length) {
        val path = new File(s"${saveModelPath}/$k")
        if (!path.exists()) path.mkdir()
      }

      // train
      for (i <- startEpoch until endEpoch) {
        // per-epoch scope so intermediate NDArrays are freed between epochs
        ResourceScope.using() {
          filelist = Random.shuffle(filelist)
          for (idx <- filelist.indices) {
            var dataArray = Array[NDArray]()
            var lossGradArray = Array[NDArray]()
            val data =
              DataProcessing.preprocessContentImage(s"${dataPath}/${filelist(idx)}", dShape, ctx)
            dataArray = dataArray :+ data
            // get content
            contentMod.forward(Array(data))
            // set target content
            loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
            // gen_forward: each generator consumes the latest image in the
            // chain (dataArray.takeRight(1)); record the loss input-gradient
            // w.r.t. that generator's output
            for (k <- 0 until gens.length) {
              gens(k).forward(dataArray.takeRight(1))
              dataArray = dataArray :+ gens(k).getOutputs()(0)
              // loss forward
              loss.forward(dataArray.takeRight(1))
              loss.backward(gradArray)
              lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
            }
            // backward pass through the chain, last generator first
            val grad = NDArray.zeros(data.shape, ctx)
            for (k <- gens.length - 1 to 0 by -1) {
              // total-variation gradient on this generator's output
              val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0), ctx, tvWeight)
              tvGradExecutor.forward()
              // NOTE(review): grad is accumulated across generators (never
              // reset inside this loop) — confirm this is intended rather
              // than a fresh gradient per generator.
              grad += lossGradArray(k) + tvGradExecutor.outputs(0)
              // clip the accumulated gradient by global norm
              val gNorm = NDArray.norm(grad)
              if (gNorm.toScalar > clipNorm) {
                grad *= clipNorm / gNorm.toScalar
              }
              gens(k).backward(Array(grad))
              gens(k).update()
              gNorm.dispose()
              tvGradExecutor.dispose()
            }
            grad.dispose()
            // periodic logging of the per-generator data-gradient magnitude
            if (idx % 20 == 0) {
              logger.info(s"Epoch $i: Image $idx")
              for (k <- 0 until gens.length) {
                val n = NDArray.norm(gens(k).getInputGrads()(0))
                logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
                n.dispose()
              }
            }
            // periodic checkpoint of every generator's parameters
            if (idx % 1000 == 0) {
              for (k <- 0 until gens.length) {
                gens(k).saveParams(
                  s"${saveModelPath}/$k/${modelPrefix}_" +
                    s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
              }
            }
            data.dispose()
          }
        }
      }
    }
  }
177
178  def main(args: Array[String]): Unit = {
179    val stin = new BoostTrain
180    val parser: CmdLineParser = new CmdLineParser(stin)
181    try {
182      parser.parseArgument(args.toList.asJava)
183      assert(stin.dataPath != null
184          && stin.vggModelPath != null
185          && stin.saveModelPath != null
186          && stin.styleImage != null)
187
188      val ctx = if (stin.gpu == -1) Context.cpu() else Context.gpu(stin.gpu)
189      runTraining(stin.dataPath, stin.vggModelPath, ctx, stin.styleImage, stin.saveModelPath)
190    } catch {
191      case ex: Exception => {
192        logger.error(ex.getMessage, ex)
193        parser.printUsage(System.err)
194        sys.exit(1)
195      }
196    }
197  }
198}
199
/**
 * args4j option holder for [[BoostTrain.main]].
 *
 * Fields are populated reflectively by `CmdLineParser`; they are read by the
 * companion object (companion access permits the `private` modifier).
 */
class BoostTrain {
  // directory containing the content training images
  @Option(name = "--data-path", usage = "the input train data path")
  private val dataPath: String = null
  // path to the pretrained VGG parameter file
  @Option(name = "--vgg-model-path", usage = "the pretrained model to use: ['vgg']")
  private val vggModelPath: String = null
  // output directory for generator checkpoints
  @Option(name = "--save-model-path", usage = "the save model path")
  private val saveModelPath: String = null
  // path of the style image providing the style targets
  @Option(name = "--style-image", usage = "the style image")
  private val styleImage: String = null
  // GPU card index; -1 (the default) runs on CPU
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
}
212