1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import org.apache.mxnet.{Context, Shape, Symbol, Xavier}
21
22
/**
 * Symbol-graph builders for the generator network used by the end-to-end
 * neural style example, plus a helper that wraps the graph in a `Module`.
 */
object GenV3 {
  /**
   * Builds a Convolution (with bias) -> BatchNorm -> LeakyReLU block.
   *
   * The Convolution and BatchNorm symbols are disposed once the LeakyReLU symbol
   * has been created; only the final symbol is handed back to the caller.
   *
   * @param data      input symbol
   * @param numFilter number of convolution filters
   * @param kernel    kernel size pair (default 5x5)
   * @param pad       padding pair (default 2, 2)
   * @param stride    stride pair (default 2, 2)
   * @return the LeakyReLU output symbol
   */
  def Conv(data: Symbol, numFilter: Int, kernel: (Int, Int) = (5, 5),
           pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2)): Symbol = {
    val convolved = Symbol.api.Convolution(
      data = Some(data),
      num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2),
      stride = Some(Shape(stride._1, stride._2)),
      pad = Some(Shape(pad._1, pad._2)),
      no_bias = Some(false))
    val normalized = Symbol.api.BatchNorm(data = Some(convolved), fix_gamma = Some(false))
    val activated = Symbol.api.LeakyReLU(data = Some(normalized), act_type = Some("leaky"))
    // Release the intermediate handles: BatchNorm first, then Convolution.
    normalized.dispose()
    convolved.dispose()
    activated
  }

  /**
   * Builds a Deconvolution (no bias) -> optional Crop -> BatchNorm -> activation block.
   *
   * @param data      input symbol
   * @param numFilter number of deconvolution filters
   * @param imHw      size pair passed to Crop as `h_w`; only read when `crop` is true
   * @param kernel    kernel size pair (default 7x7)
   * @param pad       padding pair (default 2, 2)
   * @param stride    stride pair (default 2, 2)
   * @param crop      when true, crop the deconvolution output to `imHw` with offset (1, 1)
   * @param out       when true, finish with a tanh Activation; otherwise with LeakyReLU
   * @return the activated output symbol
   */
  def Deconv(data: Symbol, numFilter: Int, imHw: (Int, Int),
             kernel: (Int, Int) = (7, 7), pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2),
             crop: Boolean = true, out: Boolean = false): Symbol = {
    val upsampled = Symbol.api.Deconvolution(
      data = Some(data),
      num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2),
      stride = Some(Shape(stride._1, stride._2)),
      pad = Some(Shape(pad._1, pad._2)),
      no_bias = Some(true))
    val sized =
      if (crop) {
        Symbol.api.Crop(
          data = Array(upsampled),
          offset = Some(Shape(1, 1)),
          h_w = Some(Shape(imHw._1, imHw._2)),
          num_args = 1)
      } else {
        upsampled
      }
    val normalized = Symbol.api.BatchNorm(data = Some(sized), fix_gamma = Some(false))
    if (out) {
      Symbol.api.Activation(data = Some(normalized), act_type = "tanh")
    } else {
      Symbol.api.LeakyReLU(data = Some(normalized), act_type = Some("leaky"))
    }
  }

  /**
   * Assembles the generator graph: three strided Conv stages (each followed by a
   * 3x3 stride-1 Conv), then three Deconv stages, the first two of which are summed
   * with the matching strided Conv outputs.
   *
   * The final tanh output is mapped through `x * 128 + 128`, split into three channels,
   * shifted by per-channel constants, re-concatenated, and blended with the input as
   * `0.4 * generated + 0.6 * data`.
   *
   * @param prefix name prefix; the input variable is named `<prefix>_data`
   * @param imHw   size pair of the input image; halves and quarters of it are used
   *               as crop targets for the first two Deconv stages
   * @return the blended output symbol
   */
  def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
    // 3x3, pad 1, stride 1 convolution block used between the resampling stages.
    def conv3x3(input: Symbol, filters: Int): Symbol =
      Conv(input, filters, kernel = (3, 3), pad = (1, 1), stride = (1, 1))

    val input = Symbol.Variable(s"${prefix}_data")

    // Downsampling path. Trailing numbers are carried over from the source comments;
    // presumably the spatial size at that stage for a 384-pixel input -- TODO confirm.
    val down1 = Conv(input, 64) // 192
    val down1Refined = conv3x3(down1, 48)
    val down2 = Conv(down1Refined, 128) // 96
    val down2Refined = conv3x3(down2, 96)
    val down3 = Conv(down2Refined, 256) // 48
    val down3Refined = conv3x3(down3, 192)

    // Upsampling path with additive skip connections to the strided Conv outputs.
    val up1 = Deconv(down3Refined, 128, (imHw._1 / 4, imHw._2 / 4)) + down2
    val up1Refined = conv3x3(up1, 160)
    val up2 = Deconv(up1Refined, 64, (imHw._1 / 2, imHw._2 / 2)) + down1
    val up2Refined = conv3x3(up2, 96)
    val up3 = Deconv(up2Refined, 3, imHw, kernel = (8, 8), pad = (3, 3), out = true, crop = false)

    // Map the tanh output through x * 128 + 128.
    val scaled = (up3 * 128) + 128

    // Subtract a constant from each of the three channels
    // (the values look like ImageNet RGB means -- TODO confirm).
    val channels = Symbol.api.SliceChannel(data = Some(scaled), num_outputs = 3)
    val channelOffsets = Array(123.68f, 116.779f, 103.939f)
    val shifted = channelOffsets.indices.map(i => channels.get(i) - channelOffsets(i)).toArray
    val generated = Symbol.api.Concat(data = shifted, num_args = 3)

    // Blend the generated image with the raw input.
    val generatedPart = generated * 0.4f
    val inputPart = input * 0.6f
    generatedPart + inputPart
  }

  /**
   * Wraps the generator graph in a `Module` initialised with `Xavier(magnitude = 2f)`.
   *
   * @param prefix  name prefix forwarded to [[getGenerator]]; also used for the
   *                `<prefix>_data` entry of the data-shape map
   * @param dShape  shape of the data input; indices 2 and 3 are passed to
   *                [[getGenerator]] as the image size pair
   * @param ctx     context handed to the `Module`
   * @param isTrain sets both `forTraining` and `inputsNeedGrad` on the `Module`
   * @return the newly constructed `Module`
   */
  def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean = true): Module = {
    val generator = getGenerator(prefix, (dShape(2), dShape(3)))
    val inputShapes = Map(s"${prefix}_data" -> dShape)
    new Module(
      symbol = generator,
      context = ctx,
      dataShapes = inputShapes,
      initializer = new Xavier(magnitude = 2f),
      forTraining = isTrain,
      inputsNeedGrad = isTrain)
  }
}
85