# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

library(mxnet)

# Build a VGG-11 (VGG configuration "A") classification symbol.
#
# Structure:
#   groups 1-5: 3x3 convolutions (1-pixel padding) each followed by a ReLU,
#               closing with a 2x2 / stride-2 max-pool; filters per group are
#               64 | 128 | 256, 256 | 512, 512 | 512, 512
#   groups 6-7: flatten, then two 4096-unit fully connected layers, each with
#               ReLU and dropout (p = 0.5)
#   output:     a `num_classes`-unit fully connected layer ("fc8") feeding a
#               SoftmaxOutput named "softmax"
#
# Args:
#   num_classes: number of output units of the final classifier layer.
# Returns:
#   The "softmax" SoftmaxOutput symbol at the head of the network.
#
# Layer names (data, conv<g>_<i>, relu<g>_<i>, pool<g>, flatten, fc6, relu6,
# drop6, fc7, relu7, drop7, fc8, softmax) are identical to the original
# hand-unrolled definition.
# NOTE(review): MXNet presumably derives parameter names (e.g.
# "conv3_2_weight") from these, so keep them stable for checkpoint
# compatibility.
get_symbol <- function(num_classes = 1000) {
  # One convolutional group: for every entry of `num_filters`, a 3x3
  # convolution + ReLU named conv<group>_<i> / relu<group>_<i>, then a 2x2
  # stride-2 max-pool named pool<group>.
  conv_group <- function(data, group, num_filters) {
    for (i in seq_along(num_filters)) {
      id <- paste0(group, "_", i)
      data <- mx.symbol.Convolution(
        data = data, kernel = c(3, 3), pad = c(1, 1),
        num_filter = num_filters[[i]], name = paste0("conv", id)
      )
      data <- mx.symbol.Activation(
        data = data, act_type = "relu", name = paste0("relu", id)
      )
    }
    mx.symbol.Pooling(
      data = data, pool_type = "max", kernel = c(2, 2), stride = c(2, 2),
      name = paste0("pool", group)
    )
  }

  # One fully connected group: 4096-unit dense layer + ReLU + dropout(0.5),
  # named fc<group> / relu<group> / drop<group>.
  dense_group <- function(data, group) {
    data <- mx.symbol.FullyConnected(
      data = data, num_hidden = 4096, name = paste0("fc", group)
    )
    data <- mx.symbol.Activation(
      data = data, act_type = "relu", name = paste0("relu", group)
    )
    mx.symbol.Dropout(data = data, p = 0.5, name = paste0("drop", group))
  }

  # Filters of each convolution in groups 1-5 (VGG configuration "A").
  conv_filters <- list(64, 128, c(256, 256), c(512, 512), c(512, 512))

  net <- mx.symbol.Variable(name = "data")
  for (group in seq_along(conv_filters)) {
    net <- conv_group(net, group, conv_filters[[group]])
  }

  # Groups 6 and 7: fully connected classifier head.
  net <- mx.symbol.Flatten(data = net, name = "flatten")
  net <- dense_group(net, 6)
  net <- dense_group(net, 7)

  # Output layer.
  fc8 <- mx.symbol.FullyConnected(
    data = net, num_hidden = num_classes, name = "fc8"
  )
  mx.symbol.SoftmaxOutput(data = fc8, name = "softmax")
}