/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle

import java.io.File

import com.sksamuel.scrimage.{Image, Pixel}
import com.sksamuel.scrimage.filter.GaussianBlurFilter
import com.sksamuel.scrimage.nio.JpegWriter
import org.apache.mxnet._
import org.apache.mxnet.optimizer.Adam
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.collection.mutable.ListBuffer

/**
 * An Implementation of the paper A Neural Algorithm of Artistic Style
 */
object NeuralStyle {
  case class NSExecutor(executor: Executor, data: NDArray, dataGrad: NDArray)

  private val logger = LoggerFactory.getLogger(classOf[NeuralStyle])

  // Per-channel offsets (R, G, B) subtracted when an image is loaded and
  // added back when an NDArray is converted to an image again.
  private val MeanR = 123.68f
  private val MeanG = 116.779f
  private val MeanB = 103.939f

  /**
   * Converts an image into a (1, 3, height, width) NDArray laid out as all
   * red values, then all green values, then all blue values, each shifted by
   * its channel offset.
   *
   * @param img    image whose pixel count is expected to be height * width
   * @param height number of rows of `img`
   * @param width  number of columns of `img`
   * @param ctx    device on which the NDArray is allocated
   */
  private def imageToNDArray(img: Image, height: Int, width: Int, ctx: Context): NDArray = {
    val pixels = img.iterator.toArray
    val r = pixels.map(_.red - MeanR)
    val g = pixels.map(_.green - MeanG)
    val b = pixels.map(_.blue - MeanB)
    val sample = NDArray.empty(Shape(1, 3, height, width), ctx)
    sample.set(r ++ g ++ b)
    sample
  }

  /**
   * Loads the content image, rescales it so that its longer edge equals
   * `longEdge` (aspect ratio preserved, sizes truncated to Int) and converts
   * it to a (1, 3, newHeight, newWidth) NDArray.
   *
   * @param path     path of the image file
   * @param longEdge target length of the longer image edge, in pixels
   * @param ctx      device on which the NDArray is allocated
   * @return the offset-shifted image data
   */
  def preprocessContentImage(path: String, longEdge: Int, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    logger.info(s"load the content image, size = ${(img.height, img.width)}")
    val factor = longEdge.toFloat / Math.max(img.height, img.width)
    val (newHeight, newWidth) = ((img.height * factor).toInt, (img.width * factor).toInt)
    val sample = imageToNDArray(img.scaleTo(newWidth, newHeight), newHeight, newWidth, ctx)
    logger.info(s"resize the content image to ${(newHeight, newWidth)}")
    sample
  }

  /**
   * Loads the style image and rescales it to the spatial size given by
   * `shape` (index 2 is used as height, index 3 as width).
   *
   * @param path  path of the image file
   * @param shape 4-element shape whose last two dimensions give the target size
   * @param ctx   device on which the NDArray is allocated
   * @return the offset-shifted image data with shape (1, 3, shape(2), shape(3))
   */
  def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    imageToNDArray(img.scaleTo(shape(3), shape(2)), shape(2), shape(3), ctx)
  }

  /** Clamps every element of `array` to the closed range [0, 255]. */
  def clip(array: Array[Float]): Array[Float] = array.map { a =>
    if (a < 0) 0f
    else if (a > 255) 255f
    else a
  }

  /**
   * Converts an NDArray produced by the preprocessing functions back into an
   * image: the channel offsets are added back, values are clamped to
   * [0, 255] and truncated to Int, and alpha is set to 255.
   *
   * @param img NDArray whose dimensions 2 and 3 are used as height and width
   */
  def postprocessImage(img: NDArray): Image = {
    val datas = img.toArray
    val spatialSize = img.shape(2) * img.shape(3)
    val r = clip(datas.take(spatialSize).map(_ + MeanR))
    val g = clip(datas.drop(spatialSize).take(spatialSize).map(_ + MeanG))
    val b = clip(datas.takeRight(spatialSize).map(_ + MeanB))
    val pixels = for (i <- 0 until spatialSize)
      yield Pixel(r(i).toInt, g(i).toInt, b(i).toInt, 255)
    Image(img.shape(3), img.shape(2), pixels.toArray)
  }

  /**
   * Converts `img` to an image, applies a Gaussian blur and writes the result
   * as a JPEG file.
   *
   * @param img      image data in the layout produced by the preprocessing functions
   * @param filename destination file
   * @param radius   radius passed to the Gaussian blur filter
   */
  def saveImage(img: NDArray, filename: String, radius: Int): Unit = {
    logger.info(s"save output to $filename")
    val out = postprocessImage(img)
    val gauss = GaussianBlurFilter(radius).op
    val result = Image(out.width, out.height)
    gauss.filter(out.awt, result.awt)
    result.output(filename)(JpegWriter())
  }

  /**
   * Builds, for every output of `style`, a gram-matrix symbol (the output
   * reshaped to (channels, height * width) and fed to FullyConnected as both
   * data and weight) together with a per-output scale factor.
   *
   * @param inputSize (height, width) of the input image, used for shape inference
   * @param style     symbol grouping the style feature outputs
   * @return the grouped gram symbols and, per output,
   *         channels * height * width * channels
   */
  def styleGramSymbol(inputSize: (Int, Int), style: Symbol): (Symbol, List[Int]) = {
    val (_, outputShape, _) = style.inferShape(
      Map("data" -> Shape(1, 3, inputSize._1, inputSize._2)))
    val gramList = ListBuffer[Symbol]()
    val gradScale = ListBuffer[Int]()
    for (i <- 0 until style.listOutputs().length) {
      val shape = outputShape(i)
      val x = Symbol.api.Reshape(data = Some(style.get(i)),
        target_shape = Some(Shape(shape(1), shape(2) * shape(3))))
      val gram = Symbol.api.FullyConnected(data = Some(x), weight = Some(x),
        no_bias = Some(true), num_hidden = shape(1))
      x.dispose()
      gramList += gram
      // NOTE(review): this product is computed in Int and can overflow for
      // large feature maps; the public return type is List[Int], so it is
      // left unchanged here.
      gradScale += (shape(1) * shape(2) * shape(3) * shape(1))
    }
    (Symbol.Group(gramList: _*), gradScale.toList)
  }

  /**
   * Builds the loss symbols: for each output i of `gram` the sum of squared
   * differences to a variable named `target_gram_$i`, and for `content` the
   * sum of squared differences to a variable named `target_content`.
   * `gram` is disposed once its outputs have been consumed.
   *
   * @return (grouped style losses, content loss)
   */
  def getLoss(gram: Symbol, content: Symbol): (Symbol, Symbol) = {
    val gramLoss = ListBuffer[Symbol]()
    for (i <- 0 until gram.listOutputs().length) {
      val gvar = Symbol.Variable(s"target_gram_$i")
      gramLoss += Symbol.api.sum(
        Some(Symbol.api.square(data = Some(gvar - gram.get(i))))
      )
      gvar.dispose()
    }
    gram.dispose()
    val cvar = Symbol.Variable("target_content")
    val contentLoss = Symbol.api.sum(
      Some(Symbol.api.square(Some(cvar - content)))
    )
    (Symbol.Group(gramLoss: _*), contentLoss)
  }

  /**
   * Creates the total-variation gradient executor with its input bound to
   * `img`: every channel is convolved with a fixed 3x3 kernel and the result
   * is scaled by `tvWeight`.
   *
   * @param img      image NDArray the executor reads from
   * @param ctx      device on which the executor is bound
   * @param tvWeight weight of the TV term
   * @return `None` when `tvWeight` is not positive, otherwise the bound executor
   */
  def getTvGradExecutor(img: NDArray, ctx: Context, tvWeight: Float): scala.Option[Executor] = {
    if (tvWeight <= 0.0f) {
      // TV regularisation disabled: do not build or bind anything.
      None
    } else {
      val nChannel = img.shape(1)
      val sImg = Symbol.Variable("img")
      val sKernel = Symbol.Variable("kernel")
      val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs = nChannel)
      val result = (0 until nChannel).map { i =>
        Symbol.api.Convolution(data = Some(channels.get(i)), weight = Some(sKernel),
          num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)), no_bias = Some(true),
          stride = Some(Shape(1, 1)))
      }.toArray
      val out = Symbol.api.Concat(result, result.length) * tvWeight
      val kernel = {
        val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
        tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
        tmp / 0.8f
      }
      Some(out.bind(ctx, Map("img" -> img, "kernel" -> kernel)))
    }
  }

  /** Euclidean (L2) norm of `array`. */
  def twoNorm(array: Array[Float]): Float = {
    Math.sqrt(array.map(x => x * x).sum.toDouble).toFloat
  }

  //scalastyle:off
  /**
   * Runs the style-transfer optimisation and writes `input.jpg`, `style.jpg`,
   * periodic `tmp_<epoch>.jpg` snapshots and the final `out.jpg` into
   * `outputDir`.
   *
   * @param model          model name (not used by this method)
   * @param contentImage   path of the content image
   * @param styleImage     path of the style image
   * @param dev            device to run on
   * @param modelPath      path of the pretrained model file
   * @param outputDir      directory the images are written to
   * @param styleWeight    weight of the style loss
   * @param contentWeight  weight of the content loss
   * @param tvWeight       weight of the TV term; non-positive disables it
   * @param gaussianRadius blur radius used when saving images
   * @param lr             initial learning rate
   * @param maxNumEpochs   maximum number of iterations
   * @param maxLongEdge    longer edge of the resized content image
   * @param saveEpochs     snapshot interval; non-positive disables snapshots
   * @param stopEps        stop once the relative change falls below this value
   */
  def runTraining(model : String, contentImage : String, styleImage: String, dev : Context,
                  modelPath : String, outputDir : String, styleWeight : Float,
                  contentWeight : Float, tvWeight : Float, gaussianRadius : Int,
                  lr: Float, maxNumEpochs: Int, maxLongEdge: Int,
                  saveEpochs : Int, stopEps: Float) : Unit = {
    ResourceScope.using() {
      val contentNp = preprocessContentImage(contentImage, maxLongEdge, dev)
      val styleNp = preprocessStyleImage(styleImage, contentNp.shape, dev)
      val size = (contentNp.shape(2), contentNp.shape(3))

      val (style, content) = ModelVgg19.getSymbol
      val (gram, gScale) = styleGramSymbol(size, style)

      // First executor: only used to record the target style and content
      // outputs, then released.
      val featureExecutor = ModelVgg19.getExecutor(gram, content, modelPath, size, dev)

      featureExecutor.data.set(styleNp)
      featureExecutor.executor.forward()
      val styleArray = featureExecutor.style.map(_.copyTo(Context.cpu()))

      featureExecutor.data.set(contentNp)
      featureExecutor.executor.forward()
      val contentArray = featureExecutor.content.copyTo(Context.cpu())

      // delete the executor
      featureExecutor.argDict.foreach(ele => ele._2.dispose())
      featureExecutor.content.dispose()
      featureExecutor.data.dispose()
      featureExecutor.dataGrad.dispose()
      featureExecutor.style.foreach(_.dispose())
      featureExecutor.executor.dispose()

      // Second executor: evaluates the losses against the recorded targets.
      val (styleLoss, contentLoss) = getLoss(gram, content)
      val modelExecutor = ModelVgg19.getExecutor(
        styleLoss, contentLoss, modelPath, size, dev)

      // One head gradient per style output (scaled by its gram size),
      // followed by one for the content loss.
      val gradArray: Array[NDArray] = {
        val styleGrads = (0 until styleArray.length).map { i =>
          modelExecutor.argDict(s"target_gram_$i").set(styleArray(i))
          NDArray.ones(Shape(1), dev) * (styleWeight / gScale(i))
        }
        (styleGrads :+ NDArray.ones(Shape(1), dev) * contentWeight).toArray
      }

      modelExecutor.argDict("target_content").set(contentArray)

      // train
      val img = Random.uniform(-0.1f, 0.1f, contentNp.shape, dev)
      val lrFS = new FactorScheduler(step = 10, factor = 0.9f)

      saveImage(contentNp, s"${outputDir}/input.jpg", gaussianRadius)
      saveImage(styleNp, s"${outputDir}/style.jpg", gaussianRadius)

      val optimizer = new Adam(
        learningRate = lr,
        wd = 0.005f,
        lrScheduler = lrFS)
      val optimState = optimizer.createState(0, img)

      logger.info(s"start training arguments")

      val oldImg = img.copyTo(dev)
      val clipNorm = img.shape.toVector.reduce(_ * _)
      val tvGradExecutor = getTvGradExecutor(img, dev, tvWeight)
      var trainingDone = false
      var e = 0
      while (e < maxNumEpochs && !trainingDone) {
        modelExecutor.data.set(img)
        modelExecutor.executor.forward()
        modelExecutor.executor.backward(gradArray)

        // Rescale the gradient when its norm exceeds the element count.
        val gNorm = NDArray.norm(modelExecutor.dataGrad).toScalar
        if (gNorm > clipNorm) {
          modelExecutor.dataGrad.set(modelExecutor.dataGrad * (clipNorm / gNorm))
        }
        val grad = tvGradExecutor match {
          case Some(executor) =>
            executor.forward()
            modelExecutor.dataGrad + executor.outputs(0)
          case None =>
            modelExecutor.dataGrad
        }
        optimizer.update(0, img, grad, optimState)

        val eps = (NDArray.norm(oldImg - img) / NDArray.norm(img)).toScalar
        oldImg.set(img)
        logger.info(s"epoch $e, relative change $eps")

        if (eps < stopEps) {
          logger.info("eps < args.stop_eps, training finished")
          trainingDone = true
        }
        // Guard against a zero interval, which would otherwise throw on `%`.
        if (saveEpochs > 0 && (e + 1) % saveEpochs == 0) {
          saveImage(img, s"${outputDir}/tmp_${e + 1}.jpg", gaussianRadius)
        }
        e = e + 1
      }
      saveImage(img, s"${outputDir}/out.jpg", gaussianRadius)
      logger.info("Finish fit ...")
    }
  }

  /**
   * Command-line entry point: parses the options declared on the
   * [[NeuralStyle]] class, runs the training and, on any exception, logs it,
   * prints the usage and exits with status 1.
   */
  def main(args: Array[String]): Unit = {
    val alle = new NeuralStyle
    val parser: CmdLineParser = new CmdLineParser(alle)
    try {
      parser.parseArgument(args.toList.asJava)
      // `require` throws IllegalArgumentException, which the handler below
      // catches; `assert` throws an Error that would bypass it.
      require(alle.contentImage != null && alle.styleImage != null
        && alle.modelPath != null && alle.outputDir != null,
        "--content-image, --style-image, --model-path and --output-dir are required")

      val dev = if (alle.gpu >= 0) Context.gpu(alle.gpu) else Context.cpu(0)
      runTraining(alle.model, alle.contentImage, alle.styleImage, dev, alle.modelPath,
        alle.outputDir, alle.styleWeight, alle.contentWeight, alle.tvWeight,
        alle.gaussianRadius, alle.lr, alle.maxNumEpochs, alle.maxLongEdge,
        alle.saveEpochs, alle.stopEps)
    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}

/**
 * Command-line options of the neural style example. The fields are annotated
 * for args4j and read by the companion object's `main`.
 */
class NeuralStyle {
  @Option(name = "--model", usage = "the pretrained model to use: ['vgg']")
  private val model: String = "vgg19"
  @Option(name = "--content-image", usage = "the content image")
  private val contentImage: String = null
  @Option(name = "--style-image", usage = "the style image")
  private val styleImage: String = null
  @Option(name = "--model-path", usage = "the model file path")
  private val modelPath: String = null
  @Option(name = "--stop-eps", usage = "stop if the relative chanage is less than eps")
  private val stopEps: Float = 0.0005f
  @Option(name = "--content-weight", usage = "the weight for the content image")
  private val contentWeight: Float = 20f
  @Option(name = "--style-weight", usage = "the weight for the style image")
  private val styleWeight: Float = 1f
  @Option(name = "--tv-weight", usage = "the magtitute on TV loss")
  private val tvWeight: Float = 0.01f
  @Option(name = "--max-num-epochs", usage = "the maximal number of training epochs")
  private val maxNumEpochs: Int = 1000
  @Option(name = "--max-long-edge", usage = "resize the content image")
  private val maxLongEdge: Int = 600
  @Option(name = "--lr", usage = "the initial learning rate")
  private val lr: Float = 10f
  @Option(name = "--gpu", usage = "which gpu card to use, -1 means using cpu")
  private val gpu: Int = 0
  @Option(name = "--output-dir", usage = "the output directory")
  private val outputDir: String = null
  @Option(name = "--save-epochs", usage = "save the output every n epochs")
  private val saveEpochs: Int = 50
  @Option(name = "--gaussian-radius", usage = "the gaussian blur filter radius")
  private val gaussianRadius: Int = 1
}