/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import org.apache.mxnet.{Context, Shape, Symbol, Xavier}


/**
 * Generator network ("v4") for the end-to-end neural style example.
 *
 * The network is a stack of Convolution/BatchNorm/LeakyReLU blocks followed by a
 * 3-channel projection. Its output is mapped back to the pixel range, mean-subtracted
 * per channel, and blended with the input image.
 */
object GenV4 {

  /**
   * Builds one Convolution -> BatchNorm -> LeakyReLU("leaky") block.
   *
   * @param data      input symbol
   * @param numFilter number of convolution filters (output channels)
   * @param workspace workspace value forwarded to the Convolution operator
   * @param kernel    convolution kernel size as (height, width)
   * @param pad       padding as (height, width)
   * @return the LeakyReLU output symbol of the block
   */
  def Conv(data: Symbol, numFilter: Int, workspace: Long, kernel: (Int, Int) = (5, 5),
           pad: (Int, Int) = (2, 2)): Symbol = {
    val conv = Symbol.api.Convolution(data = Some(data), num_filter = numFilter,
      kernel = Shape(kernel._1, kernel._2), workspace = Some(workspace),
      pad = Some(Shape(pad._1, pad._2)), no_bias = Some(false))
    val bn = Symbol.api.BatchNorm(data = Some(conv), fix_gamma = Some(false))
    val act = Symbol.api.LeakyReLU(data = Some(bn), act_type = Some("leaky"))
    // Release the handles of the intermediate symbols; only `act` is returned.
    // NOTE(review): this assumes the composed `act` graph stays valid after its
    // inputs' handles are disposed -- confirm against MXNet's Symbol lifetime rules.
    bn.dispose()
    conv.dispose()
    act
  }

  /**
   * Builds the generator symbol.
   *
   * @param prefix name prefix; the input variable is named `${prefix}_data`
   * @param imHw   image (height, width); currently unused by this generator and
   *               kept only for interface compatibility
   * @return `0.4 * normalizedOutput + 0.6 * input`
   */
  def getGenerator(prefix: String, imHw: (Int, Int)): Symbol = {
    val data = Symbol.Variable(s"${prefix}_data")

    val conv1_1 = Conv(data, 48, 4096)
    val conv2_1 = Conv(conv1_1, 32, 4096)
    val conv3_1 = Conv(conv2_1, 64, 4096, (3, 3), (1, 1))
    val conv4_1 = Conv(conv3_1, 32, 4096)
    val conv5_1 = Conv(conv4_1, 48, 4096)
    val conv6_1 = Conv(conv5_1, 32, 4096)

    // Project to 3 channels (3x3, pad 1, no bias), normalize, and squash with tanh.
    val toRgb = Symbol.api.Convolution(data = Some(conv6_1), num_filter = 3,
      kernel = Shape(3, 3), pad = Some(Shape(1, 1)), no_bias = Some(true),
      workspace = Some(4096))
    val toRgbBn = Symbol.api.BatchNorm(data = Some(toRgb), fix_gamma = Some(false))
    val squashed = Symbol.api.Activation(data = Some(toRgbBn), act_type = "tanh")

    // tanh yields [-1, 1]; scale and shift it into [0, 256].
    val rawOut = (squashed * 128) + 128

    // Split into the three channels and subtract a per-channel mean from each.
    // NOTE(review): these look like the common ImageNet/VGG RGB means; presumably
    // they match the normalization already applied to `data` -- confirm with callers.
    val norm = Symbol.api.SliceChannel(data = Some(rawOut), num_outputs = 3)
    val rCh = norm.get(0) - 123.68f
    val gCh = norm.get(1) - 116.779f
    val bCh = norm.get(2) - 103.939f
    val normOut = Symbol.api.Concat(data = Array(rCh, gCh, bCh), num_args = 3)

    // Blend the generated image with the input image.
    normOut * 0.4f + data * 0.6f
  }

  /**
   * Wraps the generator symbol in a `Module` bound to the given input shape.
   *
   * @param prefix  name prefix; the input is bound under the name `${prefix}_data`
   * @param dShape  input shape; dimensions 2 and 3 are passed to `getGenerator` as `imHw`
   * @param ctx     context the module is created on
   * @param isTrain value used for both the module's `forTraining` and `inputsNeedGrad` flags
   * @return the constructed module, initialized with `Xavier(magnitude = 2f)`
   */
  def getModule(prefix: String, dShape: Shape, ctx: Context, isTrain: Boolean = true): Module = {
    val sym = getGenerator(prefix, (dShape(2), dShape(3)))
    new Module(symbol = sym, context = ctx,
      dataShapes = Map(s"${prefix}_data" -> dShape),
      initializer = new Xavier(magnitude = 2f),
      forTraining = isTrain, inputsNeedGrad = isTrain)
  }
}