/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import org.apache.mxnet.{Context, ResourceScope, Shape}
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

/**
 * Command-line entry point that pushes one input image through a chain of
 * four generator modules and writes every intermediate result to disk.
 */
object BoostInference {

  private val logger = LoggerFactory.getLogger(classOf[BoostInference])

  /**
   * Runs the generator chain on a single image.
   *
   * Four modules are built (GenV4, GenV3, GenV3, GenV4) for a fixed
   * (1, 3, 480, 640) data shape. Parameters for module `i` are loaded from
   * `modelPath/i/v3_0002-0026000.params`. The first module receives the
   * preprocessed input image; each following module receives the output of
   * the one before it. The output of module `i` is saved to
   * `outputPath/out_i.jpg`.
   *
   * @param modelPath      directory holding one numbered sub-directory of
   *                       parameters per generator
   * @param outputPath     directory the `out_i.jpg` files are written to
   * @param guassianRadius radius forwarded to `DataProcessing.saveImage`
   * @param inputImage     path of the image handed to
   *                       `DataProcessing.preprocessContentImage`
   * @param ctx            device context the modules and input are created on
   */
  def runInference(modelPath: String, outputPath: String, guassianRadius : Int,
                   inputImage : String, ctx : Context): Unit = {
    ResourceScope.using() {
      val dShape = Shape(1, 3, 480, 640)
      // generator
      val gens = Array(
        GenV4.getModule("g0", dShape, ctx, isTrain = false),
        GenV3.getModule("g1", dShape, ctx, isTrain = false),
        GenV3.getModule("g2", dShape, ctx, isTrain = false),
        GenV4.getModule("g3", dShape, ctx, isTrain = false)
      )
      gens.zipWithIndex.foreach { case (gen, i) =>
        gen.loadParams(s"$modelPath/$i/v3_0002-0026000.params")
      }

      val contentNp =
        DataProcessing.preprocessContentImage(inputImage, dShape, ctx)
      // Only the most recent image is ever fed forward, so track just that
      // one instead of accumulating every intermediate output.
      var latest = contentNp
      for (i <- gens.indices) {
        ResourceScope.using() {
          gens(i).forward(Array(latest))
          val newImg = gens(i).getOutputs()(0)
          latest = newImg
          DataProcessing.saveImage(newImg, s"$outputPath/out_$i.jpg", guassianRadius)
          logger.info(s"Converted image: $outputPath/out_$i.jpg")
        }
      }
    }
  }

  /**
   * Parses the command line into a [[BoostInference]] option holder, picks
   * the CPU or the requested GPU, and calls [[runInference]].
   *
   * On any `Exception` (bad option syntax, a missing required option, or a
   * failure during inference) the error is logged, usage is printed to
   * stderr and the process exits with status 1.
   *
   * @param args raw command-line arguments
   */
  def main(args: Array[String]): Unit = {
    val stce = new BoostInference
    val parser: CmdLineParser = new CmdLineParser(stce)
    try {
      parser.parseArgument(args.toList.asJava)
      // `require` (IllegalArgumentException) rather than `assert`
      // (AssertionError): the latter is an Error, so it would skip the
      // handler below and could also be compiled out entirely.
      require(stce.modelPath != null, "--model-path is required")
      require(stce.inputImage != null, "--input-image is required")
      require(stce.outputPath != null, "--output-path is required")

      val ctx = if (stce.gpu == -1) Context.cpu() else Context.gpu(stce.gpu)

      runInference(stce.modelPath, stce.outputPath, stce.guassianRadius, stce.inputImage, ctx)

    } catch {
      case ex: Exception => {
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
      }
    }
  }
}

/**
 * Holder for the command-line options; the fields are populated by args4j
 * from the `@Option` annotations and read by the companion object.
 */
class BoostInference {
  @Option(name = "--model-path", usage = "the saved model path")
  private val modelPath: String = null
  // NOTE(review): the usage text says "style image", but this path is what
  // runInference passes to preprocessContentImage — confirm the wording.
  @Option(name = "--input-image", usage = "the style image")
  private val inputImage: String = null
  @Option(name = "--output-path", usage = "the output result path")
  private val outputPath: String = null
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
  @Option(name = "--guassian-radius", usage = "the gaussian blur filter radius")
  private val guassianRadius: Int = 2
}