/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle

import java.io.File

import com.sksamuel.scrimage.{Image, Pixel}
import com.sksamuel.scrimage.filter.GaussianBlurFilter
import com.sksamuel.scrimage.nio.JpegWriter
import org.apache.mxnet._
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

/**
  * An implementation of the paper "A Neural Algorithm of Artistic Style"
  * by Gatys et al. (2015).
  */
object NeuralStyle {
  case class NSExecutor(executor: Executor, data: NDArray, dataGrad: NDArray)

  private val logger = LoggerFactory.getLogger(classOf[NeuralStyle])

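  /**
    * Loads the content image, scales it so that its longer edge equals `longEdge`,
    * subtracts the per-channel VGG mean (R 123.68, G 116.779, B 103.939) and
    * returns it as a 1 x 3 x H x W NDArray on the given context.
    */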
  def preprocessContentImage(path: String, longEdge: Int, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    logger.info(s"load the content image, size = ${(img.height, img.width)}")
    val factor = longEdge.toFloat / Math.max(img.height, img.width)
    val (newHeight, newWidth) = ((img.height * factor).toInt, (img.width * factor).toInt)
    val resizedImg = img.scaleTo(newWidth, newHeight)
    val sample = NDArray.empty(Shape(1, 3, newHeight, newWidth), ctx)
    val datas = {
      val rgbs = resizedImg.iterator.toArray.map { p =>
        (p.red, p.green, p.blue)
      }
      val r = rgbs.map(_._1 - 123.68f)
      val g = rgbs.map(_._2 - 116.779f)
      val b = rgbs.map(_._3 - 103.939f)
      r ++ g ++ b
    }
    sample.set(datas)
    logger.info(s"resize the content image to ${(newHeight, newWidth)}")
    sample
  }

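  /**
    * Loads the style image, resizes it to the spatial dimensions of the content
    * image (taken from `shape`) and applies the same VGG mean subtraction as
    * `preprocessContentImage`.
    */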
  def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    val resizedImg = img.scaleTo(shape(3), shape(2))
    val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx)
    val datas = {
      val rgbs = resizedImg.iterator.toArray.map { p =>
        (p.red, p.green, p.blue)
      }
      val r = rgbs.map(_._1 - 123.68f)
      val g = rgbs.map(_._2 - 116.779f)
      val b = rgbs.map(_._3 - 103.939f)
      r ++ g ++ b
    }
    sample.set(datas)
    sample
  }

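  /** Clamps pixel values to the displayable [0, 255] range. */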
  def clip(array: Array[Float]): Array[Float] = array.map { a =>
    if (a < 0) 0f
    else if (a > 255) 255f
    else a
  }

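  /**
    * Converts a 1 x 3 x H x W NDArray back into an RGB image by adding the VGG
    * channel means, clipping to [0, 255] and re-interleaving the colour planes
    * into pixels.
    */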
  def postprocessImage(img: NDArray): Image = {
    val datas = img.toArray
    val spatialSize = img.shape(2) * img.shape(3)
    val r = clip(datas.take(spatialSize).map(_ + 123.68f))
    val g = clip(datas.drop(spatialSize).take(spatialSize).map(_ + 116.779f))
    val b = clip(datas.takeRight(spatialSize).map(_ + 103.939f))
    val pixels = for (i <- 0 until spatialSize)
      yield Pixel(r(i).toInt, g(i).toInt, b(i).toInt, 255)
    Image(img.shape(3), img.shape(2), pixels.toArray)
  }

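  /**
    * Post-processes the generated image, smooths it with a Gaussian blur of the
    * given radius and writes it to `filename` as a JPEG.
    */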
  def saveImage(img: NDArray, filename: String, radius: Int): Unit = {
    logger.info(s"save output to $filename")
    val out = postprocessImage(img)
    val gauss = GaussianBlurFilter(radius).op
    val result = Image(out.width, out.height)
    gauss.filter(out.awt, result.awt)
    result.output(filename)(JpegWriter())
  }

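  /**
    * Builds the Gram-matrix symbols used for the style representation. Each style
    * layer output of shape (1, C, H, W) is reshaped to (C, H * W) and multiplied
    * with its own transpose through a bias-free FullyConnected layer, yielding a
    * C x C Gram matrix. Also returns the per-layer gradient scaling factors.
    */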
  def styleGramSymbol(inputSize: (Int, Int), style: Symbol): (Symbol, List[Int]) = {
    val (_, outputShape, _) = style.inferShape(
      Map("data" -> Shape(1, 3, inputSize._1, inputSize._2)))
    var gramList = List[Symbol]()
    var gradScale = List[Int]()
    for (i <- 0 until style.listOutputs().length) {
      val shape = outputShape(i)
      val x = Symbol.api.Reshape(data = Some(style.get(i)),
        target_shape = Some(Shape(shape(1), shape(2) * shape(3))))
      val gram = Symbol.api.FullyConnected(data = Some(x), weight = Some(x),
        no_bias = Some(true), num_hidden = shape(1))
      x.dispose()
      gramList = gramList :+ gram
      gradScale = gradScale :+ (shape(1) * shape(2) * shape(3) * shape(1))
    }
    (Symbol.Group(gramList: _*), gradScale)
  }

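  /**
    * Builds the loss symbols: for every style layer, the sum of squared
    * differences between the generated Gram matrix and a `target_gram_i`
    * variable; for the content layer, the sum of squared differences against a
    * `target_content` variable.
    */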
  def getLoss(gram: Symbol, content: Symbol): (Symbol, Symbol) = {
    var gramLoss = ListBuffer[Symbol]()
    for (i <- 0 until gram.listOutputs().length) {
      val gvar = Symbol.Variable(s"target_gram_$i")
      gramLoss += Symbol.api.sum(
        Some(Symbol.api.square(data = Some(gvar - gram.get(i))))
      )
      gvar.dispose()
    }
    gram.dispose()
    val cvar = Symbol.Variable("target_content")
    val contentLoss = Symbol.api.sum(
      Some(Symbol.api.square(Some(cvar - content)))
    )
    (Symbol.Group(gramLoss: _*), contentLoss)
  }

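  /**
    * Creates an executor that computes the total-variation gradient of `img`:
    * each channel is convolved with a 3 x 3 Laplacian-style kernel and the result
    * is scaled by `tvWeight`. Returns None when `tvWeight` is not positive.
    */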
  def getTvGradExecutor(img: NDArray, ctx: Context, tvWeight: Float): scala.Option[Executor] = {
    // create the TV gradient executor with its input bound to img
    if (tvWeight <= 0.0f) {
      None
    } else {
      val nChannel = img.shape(1)
      val sImg = Symbol.Variable("img")
      val sKernel = Symbol.Variable("kernel")
      val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs = nChannel)
      val result = (0 until nChannel).map { i =>
        Symbol.api.Convolution(data = Some(channels.get(i)), weight = Some(sKernel),
          num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)), no_bias = Some(true),
          stride = Some(Shape(1, 1)))
      }.toArray
      val out = Symbol.api.Concat(result, result.length) * tvWeight
      val kernel = {
        val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
        tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
        tmp / 0.8f
      }
      Some(out.bind(ctx, Map("img" -> img, "kernel" -> kernel)))
    }
  }

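  /** Computes the L2 norm of the given array. */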
  def twoNorm(array: Array[Float]): Float = {
    Math.sqrt(array.map(x => x * x).sum.toDouble).toFloat
  }

  //scalastyle:off
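  /**
    * Runs the style-transfer optimization: extracts the target Gram matrices and
    * content activations from VGG-19, then iteratively updates a noise image with
    * Adam so that its combined style, content and total-variation losses decrease,
    * saving intermediate and final images to `outputDir`.
    */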
  def runTraining(model: String, contentImage: String, styleImage: String, dev: Context,
                  modelPath: String, outputDir: String, styleWeight: Float,
                  contentWeight: Float, tvWeight: Float, gaussianRadius: Int,
                  lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
                  saveEpochs: Int, stopEps: Float): Unit = {
    ResourceScope.using() {
      val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
      val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
      val size = (contentNp.shape(2), contentNp.shape(3))

      val (style, content) = ModelVgg19.getSymbol
      val (gram, gScale) = styleGramSymbol(size, style)
      var modelExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size, dev)

      modelExecutor.data.set(styleNp)
      modelExecutor.executor.forward()

      val styleArray = modelExecutor.style.map(_.copyTo(Context.cpu()))
      modelExecutor.data.set(contentNp)
      modelExecutor.executor.forward()
      val contentArray = modelExecutor.content.copyTo(Context.cpu())

      // delete the executor
      modelExecutor.argDict.foreach(ele => ele._2.dispose())
      modelExecutor.content.dispose()
      modelExecutor.data.dispose()
      modelExecutor.dataGrad.dispose()
      modelExecutor.style.foreach(_.dispose())
      modelExecutor.executor.dispose()
      modelExecutor = null

      val (styleLoss, contentLoss) = getLoss(gram, content)
      modelExecutor = ModelVgg19.getExecutor(
        styleLoss, contentLoss, modelPath, size, dev)

      val gradArray = {
        var tmpGA = Array[NDArray]()
        for (i <- 0 until styleArray.length) {
          modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
          tmpGA = tmpGA :+ NDArray.ones(Shape(1), dev) * (styleWeight / gScale(i))
        }
        tmpGA :+ NDArray.ones(Shape(1), dev) * contentWeight
      }

      modelExecutor.argDict("target_content").set(contentArray)

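      // Initialize the generated image with uniform noise and optimize it directly:
      // every epoch runs a forward/backward pass through the loss executor, clips the
      // image gradient, optionally adds the total-variation gradient, applies an Adam
      // update, and stops once the relative change falls below stopEps.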
      // train
      val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
      val lrFS = new FactorScheduler(step = 10, factor = 0.9f)

      saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
      saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)

      val optimizer = new Adam(
        learningRate = lr,
        wd = 0.005f,
        lrScheduler = lrFS)
      val optimState = optimizer.createState(0, img)

      logger.info("start training")

      var oldImg = img.copyTo(dev)
      val clipNorm = img.shape.toVector.reduce(_ * _)
      val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
      var eps = 0f
      var trainingDone = false
      var e = 0
      while (e < maxNumEpochs && !trainingDone) {
        modelExecutor.data.set(img)
        modelExecutor.executor.forward()
        modelExecutor.executor.backward(gradArray)

        val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
        if (gNorm > clipNorm) {
          modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
        }
        tvGradExecutor match {
          case Some(executor) => {
            executor.forward()
            optimizer.update(0, img,
              modelExecutor.dataGrad + executor.outputs(0),
              optimState)
          }
          case None =>
            optimizer.update(0, img, modelExecutor.dataGrad, optimState)
        }
        eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
        oldImg.set(img)
        logger.info(s"epoch $e, relative change $eps")

        if (eps < stopEps) {
          logger.info("relative change < stopEps, training finished")
          trainingDone = true
        }
        if ((e + 1) % saveEpochs == 0) {
          saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
        }
        e = e + 1
      }
      saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
      logger.info("Finish fit ...")
    }
  }

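  /**
    * Parses the command-line options, picks the GPU or CPU context and starts
    * training; prints usage and exits with a non-zero status on error.
    */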
  def main(args: Array[String]): Unit = {
    val alle = new NeuralStyle
    val parser: CmdLineParser = new CmdLineParser(alle)
    try {
      parser.parseArgument(args.toList.asJava)
      assert(alle.contentImage != null && alle.styleImage != null
        && alle.modelPath != null && alle.outputDir != null)

      val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
      runTraining(alle.model, alle.contentImage, alle.styleImage, dev, alle.modelPath,
        alle.outputDir, alle.styleWeight, alle.contentWeight, alle.tvWeight,
        alle.gaussianRadius, alle.lr, alle.maxNumEpochs, alle.maxLongEdge,
        alle.saveEpochs, alle.stopEps)
    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}

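/**
  * Command-line options parsed by args4j. A typical invocation (the file names
  * below are placeholders, not files shipped with this example) would look like:
  *
  *   --content-image input.jpg --style-image style.jpg \
  *   --model-path vgg19.params --output-dir output --gpu 0
  */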
class NeuralStyle {
  @Option(name = "--model", usage = "the pretrained model to use: ['vgg19']")
  private val model: String = "vgg19"
  @Option(name = "--content-image", usage = "the content image")
  private val contentImage: String = null
  @Option(name = "--style-image", usage = "the style image")
  private val styleImage: String = null
  @Option(name = "--model-path", usage = "the model file path")
  private val modelPath: String = null
  @Option(name = "--stop-eps", usage = "stop if the relative change is less than eps")
  private val stopEps: Float = 0.0005f
  @Option(name = "--content-weight", usage = "the weight for the content image")
  private val contentWeight: Float = 20f
  @Option(name = "--style-weight", usage = "the weight for the style image")
  private val styleWeight: Float = 1f
  @Option(name = "--tv-weight", usage = "the magnitude of the TV loss")
  private val tvWeight: Float = 0.01f
  @Option(name = "--max-num-epochs", usage = "the maximal number of training epochs")
  private val maxNumEpochs: Int = 1000
  @Option(name = "--max-long-edge", usage = "the target length of the content image's longer edge")
  private val maxLongEdge: Int = 600
  @Option(name = "--lr", usage = "the initial learning rate")
  private val lr: Float = 10f
  @Option(name = "--gpu", usage = "which gpu card to use, -1 means using cpu")
  private val gpu: Int = 0
  @Option(name = "--output-dir", usage = "the output directory")
  private val outputDir: String = null
  @Option(name = "--save-epochs", usage = "save the output every n epochs")
  private val saveEpochs: Int = 50
  @Option(name = "--gaussian-radius", usage = "the gaussian blur filter radius")
  private val gaussianRadius: Int = 1
}