/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import org.apache.mxnet.{Context, Shape, Symbol, Xavier}

/**
 * Generator network (version 3) for the end-to-end neural style transfer
 * example: an encoder-decoder convolutional network whose output is blended
 * with the (mean-subtracted) input image.
 */
object GenV3 {

  /**
   * Convolution -> BatchNorm -> LeakyReLU building block.
   *
   * @param data      input symbol
   * @param numFilter number of convolution filters
   * @param kernel    convolution kernel (height, width)
   * @param pad       spatial padding (height, width)
   * @param stride    convolution stride (height, width)
   * @return the activated symbol
   */
  def Conv(data: Symbol, numFilter: Int, kernel: (Int, Int) = (5, 5),
      pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2)): Symbol = {
    val conv = Symbol.api.Convolution(data = Some(data), num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2), stride = Some(Shape(stride._1, stride._2)),
      pad = Some(Shape(pad._1, pad._2)), no_bias = Some(false))
    val bn = Symbol.api.BatchNorm(data = Some(conv), fix_gamma = Some(false))
    val act = Symbol.api.LeakyReLU(data = Some(bn), act_type = Some("leaky"))
    // NOTE(review): intermediates are disposed after composing the activation;
    // the composed symbol appears to own its graph, but confirm disposal is
    // safe before reusing these intermediates elsewhere.
    bn.dispose()
    conv.dispose()
    act
  }

  /**
   * Deconvolution (transposed convolution) block:
   * Deconvolution -> optional Crop -> BatchNorm -> activation.
   *
   * @param data      input symbol
   * @param numFilter number of deconvolution filters
   * @param imHw      target output (height, width), used when cropping
   * @param kernel    deconvolution kernel (height, width)
   * @param pad       spatial padding (height, width)
   * @param stride    deconvolution stride (height, width)
   * @param crop      if true, crop the output back to `imHw`
   * @param out       if true this is the final layer (tanh activation);
   *                  otherwise a hidden layer (leaky ReLU)
   * @return the activated symbol
   */
  def Deconv(data: Symbol, numFilter: Int, imHw: (Int, Int),
      kernel: (Int, Int) = (7, 7), pad: (Int, Int) = (2, 2), stride: (Int, Int) = (2, 2),
      crop: Boolean = true, out: Boolean = false): Symbol = {
    val deconv = Symbol.api.Deconvolution(data = Some(data), num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2), stride = Some(Shape(stride._1, stride._2)),
      pad = Some(Shape(pad._1, pad._2)), no_bias = Some(true))
    // Deconvolution can overshoot the target spatial size; optionally crop
    // back to exactly (imHw._1, imHw._2) with a (1, 1) offset.
    val cropped =
      if (crop) {
        Symbol.api.Crop(data = Array(deconv), offset = Some(Shape(1, 1)),
          h_w = Some(Shape(imHw._1, imHw._2)), num_args = 1)
      } else {
        deconv
      }
    val bn = Symbol.api.BatchNorm(data = Some(cropped), fix_gamma = Some(false))
    // Final (output) layer squashes with tanh; hidden layers use leaky ReLU.
    if (out) Symbol.api.Activation(data = Some(bn), act_type = "tanh")
    else Symbol.api.LeakyReLU(data = Some(bn), act_type = Some("leaky"))
  }

  /**
   * Builds the full generator symbol: three strided Conv stages (each followed
   * by a stride-1 refinement Conv), mirrored by three Deconv stages with skip
   * connections, then per-channel mean subtraction and a 0.4/0.6 blend with
   * the input.
   *
   * @param prefix name prefix for the data variable (`"<prefix>_data"`)
   * @param imHw   input image (height, width)
   * @return the generator output symbol
   */
  def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
    val data = Symbol.Variable(s"${prefix}_data")
    // Encoder: each strided Conv halves the spatial size (comments show the
    // running size for a 192-pixel input, as in the original example).
    val conv1 = Conv(data, 64) // 192
    val conv1_1 = Conv(conv1, 48, kernel = (3, 3), pad = (1, 1), stride = (1, 1))
    val conv2 = Conv(conv1_1, 128) // 96
    val conv2_1 = Conv(conv2, 96, kernel = (3, 3), pad = (1, 1), stride = (1, 1))
    val conv3 = Conv(conv2_1, 256) // 48
    val conv3_1 = Conv(conv3, 192, kernel = (3, 3), pad = (1, 1), stride = (1, 1))
    // Decoder: each Deconv doubles the spatial size; skip connections add the
    // matching encoder activation.
    val deconv1 = Deconv(conv3_1, 128, (imHw._1 / 4, imHw._2 / 4)) + conv2
    val conv4_1 = Conv(deconv1, 160, kernel = (3, 3), pad = (1, 1), stride = (1, 1))
    val deconv2 = Deconv(conv4_1, 64, (imHw._1 / 2, imHw._2 / 2)) + conv1
    val conv5_1 = Conv(deconv2, 96, kernel = (3, 3), pad = (1, 1), stride = (1, 1))
    val deconv3 = Deconv(conv5_1, 3, imHw, kernel = (8, 8), pad = (3, 3), out = true, crop = false)
    // Map tanh output from [-1, 1] to [0, 256], then subtract the standard
    // per-channel ImageNet means (R, G, B).
    val rawOut = (deconv3 * 128) + 128
    val norm = Symbol.api.SliceChannel(data = Some(rawOut), num_outputs = 3)
    val rCh = norm.get(0) - 123.68f
    val gCh = norm.get(1) - 116.779f
    val bCh = norm.get(2) - 103.939f
    val normOut = Symbol.api.Concat(data = Array(rCh, gCh, bCh), num_args = 3)
    // Blend the generated image with the original input.
    normOut * 0.4f + data * 0.6f
  }

  /**
   * Wraps the generator in a [[Module]] bound to `dShape`.
   *
   * @param prefix  name prefix for the data variable
   * @param dShape  input data shape; dShape(2) and dShape(3) are taken as
   *                the image (height, width)
   * @param ctx     device context to bind on
   * @param isTrain when true, the module is built for training with input
   *                gradients enabled; when false, inference only
   * @return the bound module
   */
  def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean = true): Module = {
    val sym = getGenerator(prefix, (dShape(2), dShape(3)))
    val dataShapes = Map(s"${prefix}_data" -> dShape)
    // forTraining and inputsNeedGrad are both enabled exactly when isTrain.
    new Module(symbol = sym, context = ctx,
      dataShapes = dataShapes,
      initializer = new Xavier(magnitude = 2f),
      forTraining = isTrain, inputsNeedGrad = isTrain)
  }
}