1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import org.apache.mxnet.{Context, ResourceScope, Shape}
21import org.kohsuke.args4j.{CmdLineParser, Option}
22import org.slf4j.LoggerFactory
23
24import scala.collection.JavaConverters._
25
object BoostInference {

  private val logger = LoggerFactory.getLogger(classOf[BoostInference])

  /**
   * Runs a content image through the chain of pre-trained generator networks and writes the
   * output of every stage to `outputPath`.
   *
   * Stage `i` loads its weights from `modelPath/i/v3_0002-0026000.params`, is fed the output of
   * stage `i - 1` (stage 0 is fed the preprocessed input image) and its result is saved as
   * `outputPath/out_i.jpg`.
   *
   * @param modelPath      directory holding one numbered sub-directory of parameters per stage
   * @param outputPath     directory the per-stage images are written to
   * @param guassianRadius forwarded to `DataProcessing.saveImage`; the command-line help
   *                       describes it as the gaussian blur filter radius
   * @param inputImage     path of the content image to stylize
   * @param ctx            device the generator modules are created on
   */
  def runInference(modelPath: String, outputPath: String, guassianRadius : Int,
                   inputImage : String, ctx : Context): Unit = {
    ResourceScope.using() {
      // Fixed input shape; presumably NCHW (one 3-channel 480x640 image) — TODO confirm.
      val dShape = Shape(1, 3, 480, 640)
      // Generator chain: each stage refines the previous stage's output.
      val gens = Array(
        GenV4.getModule("g0", dShape, ctx, isTrain = false),
        GenV3.getModule("g1", dShape, ctx, isTrain = false),
        GenV3.getModule("g2", dShape, ctx, isTrain = false),
        GenV4.getModule("g3", dShape, ctx, isTrain = false)
      )
      gens.zipWithIndex.foreach { case (gen, i) =>
        gen.loadParams(s"$modelPath/$i/v3_0002-0026000.params")
      }

      // Only the most recent image is ever fed forward, so hold just that one rather than
      // accumulating every stage's output in a growing array.
      var current = DataProcessing.preprocessContentImage(inputImage, dShape, ctx)
      gens.zipWithIndex.foreach { case (gen, i) =>
        // Per-stage scope releases native temporaries created while running/saving this stage.
        ResourceScope.using() {
          gen.forward(Array(current))
          current = gen.getOutputs()(0)
          val outFile = s"$outputPath/out_$i.jpg"
          DataProcessing.saveImage(current, outFile, guassianRadius)
          logger.info(s"Converted image: $outFile")
        }
      }
    }
  }

  /**
   * Command-line entry point. Parses the options declared on the `BoostInference` class and
   * calls [[runInference]]. On any exception the error is logged, usage is printed to stderr
   * and the process exits with status 1.
   */
  def main(args: Array[String]): Unit = {
    val opts = new BoostInference
    val parser: CmdLineParser = new CmdLineParser(opts)
    try {
      parser.parseArgument(args.toList.asJava)
      // `require` (IllegalArgumentException) rather than `assert`: an AssertionError is an
      // Error, not an Exception, so it would bypass the usage handler below; assertions can
      // also be elided at compile time.
      require(opts.modelPath != null, "--model-path is required")
      require(opts.inputImage != null, "--input-image is required")
      require(opts.outputPath != null, "--output-path is required")

      val ctx = if (opts.gpu == -1) Context.cpu() else Context.gpu(opts.gpu)

      runInference(opts.modelPath, opts.outputPath, opts.guassianRadius, opts.inputImage, ctx)
    } catch {
      case ex: Exception =>
        logger.error(ex.getMessage, ex)
        parser.printUsage(System.err)
        sys.exit(1)
    }
  }
}
83
/**
 * Command-line options for the `BoostInference` example.
 *
 * The fields are populated reflectively by args4j's `CmdLineParser`. An initializer is the
 * value seen when the corresponding option is not passed.
 */
class BoostInference {
  @Option(name = "--model-path", required = true, usage = "the saved model path")
  private val modelPath: String = null
  // This path is preprocessed as the *content* image, not the style image.
  @Option(name = "--input-image", required = true, usage = "the content image to stylize")
  private val inputImage: String = null
  @Option(name = "--output-path", required = true, usage = "the output result path")
  private val outputPath: String = null
  @Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
  private val gpu: Int = -1
  // NOTE: the option name keeps its original ("guassian") spelling so existing invocations
  // continue to work.
  @Option(name = "--guassian-radius", usage = "the gaussian blur filter radius")
  private val guassianRadius: Int = 2
}
96