1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import org.apache.mxnet.{Context, Shape, Symbol, Xavier}
21
22
/**
 * Generator network (variant 4) for the end-to-end neural style example.
 *
 * The network stacks six Convolution -> BatchNorm -> LeakyReLU units. A 3-channel head
 * follows them, and its output is blended with the network's own input variable.
 */
object GenV4 {

  /** Workspace limit passed to every Convolution in this generator. */
  private val DefaultWorkspace: Long = 4096L

  /**
   * Builds one Convolution -> BatchNorm -> LeakyReLU unit.
   *
   * The convolution keeps its bias (`no_bias = false`). BatchNorm learns gamma
   * (`fix_gamma = false`). The activation is `LeakyReLU` with `act_type = "leaky"`.
   * The intermediate convolution and batch-norm symbol handles are disposed after the
   * LeakyReLU symbol has been composed from them. Only the returned symbol is left undisposed.
   *
   * @param data      input symbol
   * @param numFilter number of convolution output channels
   * @param workspace workspace limit forwarded to the Convolution operator
   * @param kernel    convolution kernel size as (height, width)
   * @param pad       convolution padding as (height, width)
   * @return the LeakyReLU output symbol
   */
  def Conv(data: Symbol, numFilter: Int, workspace : Long, kernel: (Int, Int) = (5, 5),
           pad: (Int, Int) = (2, 2)): Symbol = {
    val sym1 = Symbol.api.Convolution(data = Some(data), num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2), workspace = Some(workspace),
      pad = Some(Shape(pad._1, pad._2)), no_bias = Some(false))
    val sym2 = Symbol.api.BatchNorm(data = Some(sym1), fix_gamma = Some(false))
    val sym3 = Symbol.api.LeakyReLU(data = Some(sym2), act_type = Some("leaky"))
    sym2.dispose()
    sym1.dispose()
    sym3
  }

  /**
   * Builds the generator symbol.
   *
   * The input is a variable named `s"${prefix}_data"`. It goes through six [[Conv]] units,
   * then a 3-filter 3x3 convolution without bias, BatchNorm and tanh. The tanh output is
   * rescaled to `x * 128 + 128` and split into three channels. A constant is subtracted
   * from each channel and the channels are concatenated again. The result is
   * `0.4 * generated + 0.6 * input`.
   *
   * @param prefix name prefix for the input variable
   * @param imHw   image (height, width). NOTE(review): this variant does not use it; it is
   *               kept so the signature stays unchanged for existing callers.
   * @return the generator output symbol
   */
  def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
    val data = Symbol.Variable(s"${prefix}_data")

    val conv1_1 = Conv(data, 48, DefaultWorkspace)
    val conv2_1 = Conv(conv1_1, 32, DefaultWorkspace)
    val conv3_1 = Conv(conv2_1, 64, DefaultWorkspace, (3, 3), (1, 1))
    val conv4_1 = Conv(conv3_1, 32, DefaultWorkspace)
    val conv5_1 = Conv(conv4_1, 48, DefaultWorkspace)
    val conv6_1 = Conv(conv5_1, 32, DefaultWorkspace)

    // 3-channel output head: conv (no bias) -> BatchNorm -> tanh, so values lie in [-1, 1].
    val headConv = Symbol.api.Convolution(data = Some(conv6_1), num_filter = 3,
      kernel = Shape(3, 3), pad = Some(Shape(1, 1)), no_bias = Some(true),
      workspace = Some(DefaultWorkspace))
    val headBn = Symbol.api.BatchNorm(data = Some(headConv), fix_gamma = Some(false))
    val headAct = Symbol.api.Activation(data = Some(headBn), act_type = "tanh")

    // Rescale tanh's [-1, 1] output to [0, 256].
    val rawOut = (headAct * 128) + 128

    // Per-channel offsets; presumably the ImageNet RGB means used by the VGG-style loss
    // network, so the output lands in the same mean-subtracted space as `data` -- confirm.
    val norm = Symbol.api.SliceChannel(data = Some(rawOut), num_outputs = 3)
    val rCh = norm.get(0) - 123.68f
    val gCh = norm.get(1) - 116.779f
    val bCh = norm.get(2) - 103.939f
    val normOut = Symbol.api.Concat(data = Array(rCh, gCh, bCh), num_args = 3)

    // Blend the generated image with the input.
    normOut * 0.4f + data * 0.6f
  }

  /**
   * Wraps the generator symbol in a [[Module]].
   *
   * @param prefix  name prefix; the single data input is registered as `s"${prefix}_data"`
   * @param dShape  shape of that data input. Elements 2 and 3 are passed to [[getGenerator]]
   *                as `imHw` (assumes NCHW layout -- confirm against caller).
   * @param ctx     device context for the module
   * @param isTrain when true the module is built with `forTraining` and `inputsNeedGrad`
   *                enabled; when false both are disabled
   * @return the constructed module, initialized with `Xavier(magnitude = 2f)`
   */
  def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean = true): Module = {
    val sym = getGenerator(prefix, (dShape(2), dShape(3)))
    new Module(symbol = sym, context = ctx,
      dataShapes = Map(s"${prefix}_data" -> dShape),
      initializer = new Xavier(magnitude = 2f),
      forTraining = isTrain, inputsNeedGrad = isTrain)
  }
}
73