/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import java.io.File

import org.apache.mxnet.{Context, Executor, NDArray, ResourceScope, Shape, Symbol}
import org.apache.mxnet.optimizer.SGD
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._
import scala.util.Random

/**
 * End-to-end neural-style "boosted" training: trains a cascade of generator
 * networks (GenV3/GenV4) against VGG-based style/content losses, with an
 * optional total-variation (TV) smoothness regularizer.
 */
object BoostTrain {

  private val logger = LoggerFactory.getLogger(classOf[BoostTrain])

  /**
   * Build an executor that computes the total-variation gradient of `img`
   * (a 3x3 Laplacian applied per channel, scaled by `tvWeight`), with the
   * executor's input bound to `img`.
   *
   * @param img      generated image batch; shape is (batch, channel, h, w) —
   *                 `img.shape(1)` is taken as the channel count
   * @param ctx      device context on which to allocate and bind
   * @param tvWeight TV regularization weight; if <= 0 the regularizer is
   *                 disabled and this method returns null
   * @return bound Executor, or null when tvWeight <= 0
   */
  def getTvGradExecutor(img: NDArray, ctx: Context, tvWeight: Float): Executor = {
    // BUG FIX: the original wrote `if (tvWeight <= 0.0f) null` as a bare
    // statement, whose value was discarded — execution fell through and the
    // executor was built regardless. Made it a real if/else expression.
    if (tvWeight <= 0.0f) {
      null
    } else {
      val nChannel = img.shape(1)
      val sImg = Symbol.Variable("img")
      val sKernel = Symbol.Variable("kernel")
      // Convolve each channel independently with the shared 3x3 kernel.
      val channels = Symbol.api.SliceChannel(data = Some(sImg), num_outputs = nChannel)
      val toConcat = (0 until nChannel).map { i =>
        Symbol.api.Convolution(data = Some(channels.get(i)), weight = Some(sKernel),
          num_filter = 1, kernel = Shape(3, 3), pad = Some(Shape(1, 1)),
          no_bias = Some(true), stride = Some(Shape(1, 1)))
      }.toArray
      val out = Symbol.api.Concat(data = toConcat, num_args = toConcat.length) * tvWeight
      // Discrete Laplacian (0,-1,0 / -1,4,-1 / 0,-1,0), normalized by 8.
      val kernel = {
        val tmp = NDArray.empty(Shape(1, 1, 3, 3), ctx)
        tmp.set(Array[Float](0, -1, 0, -1, 4, -1, 0, -1, 0))
        tmp / 8.0f
      }
      out.bind(ctx, Map("img" -> img, "kernel" -> kernel))
    }
  }

  /**
   * Run the full training loop.
   *
   * @param dataPath      directory of content training images
   * @param vggModelPath  path to the pretrained VGG parameter file
   * @param ctx           device context (cpu or gpu)
   * @param styleImage    path to the single style reference image
   * @param saveModelPath directory under which per-generator checkpoints
   *                      (one subdirectory per generator) are written
   */
  def runTraining(dataPath : String, vggModelPath: String, ctx : Context,
                  styleImage : String, saveModelPath : String) : Unit = {
    ResourceScope.using() {
      // params
      val vggParams = NDArray.load2Map(vggModelPath)
      val styleWeight = 1.2f
      val contentWeight = 10f
      val dShape = Shape(1, 3, 384, 384)
      val clipNorm = 0.05f * dShape.product
      val modelPrefix = "v3"

      // init style: run the style image once through the style network and
      // capture its gram-matrix targets on the CPU.
      val styleNp = DataProcessing.preprocessStyleImage(styleImage, dShape, ctx)
      var styleMod = Basic.getStyleModule("style", dShape, ctx, vggParams)
      styleMod.forward(Array(styleNp))
      val styleArray = styleMod.getOutputs().map(_.copyTo(Context.cpu()))
      styleMod.dispose()
      styleMod = null // release the reference; the module is no longer needed

      // content
      val contentMod = Basic.getContentModule("content", dShape, ctx, vggParams)

      // loss: bind the fixed style targets as extra parameters.
      val (loss, gScale) = Basic.getLossModule("loss", dShape, ctx, vggParams)
      val extraArgs = (0 until styleArray.length)
        .map(i => s"target_gram_$i" -> styleArray(i)).toMap
      loss.setParams(extraArgs)

      // per-output gradient scales: style terms (normalized by gScale),
      // then one content term.
      var gradArray = Array[NDArray]()
      for (i <- 0 until styleArray.length) {
        gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * (styleWeight / gScale(i)))
      }
      gradArray = gradArray :+ (NDArray.ones(Shape(1), ctx) * contentWeight)

      // generator cascade: each generator feeds the next one's input.
      val gens = Array(
        GenV4.getModule("g0", dShape, ctx),
        GenV3.getModule("g1", dShape, ctx),
        GenV3.getModule("g2", dShape, ctx),
        GenV4.getModule("g3", dShape, ctx)
      )
      gens.foreach { gen =>
        val opt = new SGD(learningRate = 1e-4f,
          momentum = 0.9f,
          wd = 5e-3f,
          clipGradient = 5f)
        gen.initOptimizer(opt)
      }

      var filelist = new File(dataPath).list().toList
      val numImage = filelist.length
      logger.info(s"Dataset size: $numImage")

      val tvWeight = 1e-2f

      val startEpoch = 0
      val endEpoch = 3

      // one checkpoint directory per generator
      for (k <- 0 until gens.length) {
        val path = new File(s"${saveModelPath}/$k")
        if (!path.exists()) path.mkdir()
      }

      // train
      for (i <- startEpoch until endEpoch) {
        ResourceScope.using() {
          filelist = Random.shuffle(filelist)
          for (idx <- filelist.indices) {
            var dataArray = Array[NDArray]()
            var lossGradArray = Array[NDArray]()
            val data =
              DataProcessing.preprocessContentImage(s"${dataPath}/${filelist(idx)}", dShape, ctx)
            dataArray = dataArray :+ data
            // get content
            contentMod.forward(Array(data))
            // set target content
            loss.setParams(Map("target_content" -> contentMod.getOutputs()(0)))
            // gen_forward: each generator consumes the previous output.
            for (k <- 0 until gens.length) {
              gens(k).forward(dataArray.takeRight(1))
              dataArray = dataArray :+ gens(k).getOutputs()(0)
              // loss forward
              loss.forward(dataArray.takeRight(1))
              loss.backward(gradArray)
              lossGradArray = lossGradArray :+ loss.getInputGrads()(0)
            }
            // gen_backward, last generator first: accumulate the loss
            // gradient (plus optional TV gradient), clip by global norm,
            // then update each generator.
            val grad = NDArray.zeros(data.shape, ctx)
            for (k <- gens.length - 1 to 0 by -1) {
              val tvGradExecutor = getTvGradExecutor(gens(k).getOutputs()(0), ctx, tvWeight)
              // Guard against the disabled-TV case (executor is null);
              // identical behavior for the tvWeight = 1e-2f used here.
              if (tvGradExecutor != null) {
                tvGradExecutor.forward()
                grad += lossGradArray(k) + tvGradExecutor.outputs(0)
              } else {
                grad += lossGradArray(k)
              }
              val gNorm = NDArray.norm(grad)
              if (gNorm.toScalar > clipNorm) {
                grad *= clipNorm / gNorm.toScalar
              }
              gens(k).backward(Array(grad))
              gens(k).update()
              gNorm.dispose()
              if (tvGradExecutor != null) tvGradExecutor.dispose()
            }
            grad.dispose()
            if (idx % 20 == 0) {
              logger.info(s"Epoch $i: Image $idx")
              for (k <- 0 until gens.length) {
                val n = NDArray.norm(gens(k).getInputGrads()(0))
                logger.info(s"Data Norm : ${n.toScalar / dShape.product}")
                n.dispose()
              }
            }
            if (idx % 1000 == 0) {
              for (k <- 0 until gens.length) {
                gens(k).saveParams(
                  s"${saveModelPath}/$k/${modelPrefix}_" +
                    s"${"%04d".format(i)}-${"%07d".format(idx)}.params")
              }
            }
            data.dispose()
          }
        }
      }
    }
  }

  /**
   * CLI entry point: parses args4j options from `args`, validates the
   * required paths, selects cpu/gpu, and runs training. Exits with status 1
   * (after printing usage) on any failure.
   */
  def main(args: Array[String]): Unit = {
    val stin = new BoostTrain
    val parser: CmdLineParser = new CmdLineParser(stin)
    try {
      parser.parseArgument(args.toList.asJava)
      assert(stin.dataPath != null
        && stin.vggModelPath != null
        && stin.saveModelPath != null
        && stin.styleImage != null)

      val ctx = if (stin.gpu == -1) Context.cpu() else Context.gpu(stin.gpu)
      runTraining(stin.dataPath, stin.vggModelPath, ctx, stin.styleImage, stin.saveModelPath)
    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}

/**
 * args4j option holder for [[BoostTrain.main]]. Fields are populated
 * reflectively by [[CmdLineParser]]; `null` / `-1` mark "not provided".
 */
class BoostTrain {
  @Option(name = "--data-path", usage = "the input train data path")
  private val dataPath: String = null
  @Option(name = "--vgg-model-path", usage = "the pretrained model to use: ['vgg']")
  private val vggModelPath: String = null
  @Option(name = "--save-model-path", usage = "the save model path")
  private val saveModelPath: String = null
  @Option(name = "--style-image", usage = "the style image")
  private val styleImage: String = null
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
}