1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18library(mxnet)
19
#' Build a VGG network symbol for image classification.
#'
#' With the default `num_layers = 11` this builds the VGG-11 (configuration
#' "A") graph. It has five groups of 3x3 / pad 1 convolutions, each followed by
#' a ReLU. Every group ends with a 2x2 / stride 2 max pooling layer. The groups
#' feed two 4096-unit fully connected layers, each with ReLU and dropout
#' (p = 0.5). A final `num_classes`-way fully connected layer feeds the softmax
#' output.
#'
#' Layer names follow `conv<group>_<index>`, `relu<group>_<index>` and
#' `pool<group>`, followed by `flatten`, `fc6`, `relu6`, `drop6`, `fc7`,
#' `relu7`, `drop7`, `fc8` and `softmax`. The default graph therefore keeps
#' exactly the layer names of the original hand-written VGG-11 definition.
#'
#' @param num_classes Number of output classes (units in `fc8`).
#' @param num_layers VGG depth; one of 11, 13, 16 or 19.
#' @return The `SoftmaxOutput` symbol named "softmax".
get_symbol <- function(num_classes = 1000, num_layers = 11) {
  # Number of convolutions in each of the five groups, per supported depth
  # (VGG configurations A, B, D and E). Filter widths are shared by all depths.
  vgg_spec <- list(
    "11" = c(1, 1, 2, 2, 2),
    "13" = c(2, 2, 2, 2, 2),
    "16" = c(2, 2, 3, 3, 3),
    "19" = c(2, 2, 4, 4, 4)
  )
  filters <- c(64, 128, 256, 512, 512)

  key <- as.character(num_layers)
  if (length(key) != 1L || !(key %in% names(vgg_spec))) {
    stop("num_layers must be one of: ",
         paste(names(vgg_spec), collapse = ", "), call. = FALSE)
  }
  convs_per_group <- vgg_spec[[key]]

  # Append `n_conv` (3x3 convolution + ReLU) pairs to `net`, then a 2x2 max
  # pool. Layers are named by group index `g` and position within the group.
  conv_group <- function(net, g, n_conv, num_filter) {
    for (i in seq_len(n_conv)) {
      net <- mx.symbol.Convolution(data = net, kernel = c(3, 3),
                                   pad = c(1, 1), num_filter = num_filter,
                                   name = sprintf("conv%d_%d", g, i))
      net <- mx.symbol.Activation(data = net, act_type = "relu",
                                  name = sprintf("relu%d_%d", g, i))
    }
    mx.symbol.Pooling(data = net, pool_type = "max", kernel = c(2, 2),
                      stride = c(2, 2), name = sprintf("pool%d", g))
  }

  # Feature extractor: groups 1-5
  net <- mx.symbol.Variable(name = "data")
  for (g in seq_along(filters)) {
    net <- conv_group(net, g, convs_per_group[[g]], filters[[g]])
  }

  # Group 6
  net <- mx.symbol.Flatten(data = net, name = "flatten")
  net <- mx.symbol.FullyConnected(data = net, num_hidden = 4096, name = "fc6")
  net <- mx.symbol.Activation(data = net, act_type = "relu", name = "relu6")
  net <- mx.symbol.Dropout(data = net, p = 0.5, name = "drop6")

  # Group 7
  net <- mx.symbol.FullyConnected(data = net, num_hidden = 4096, name = "fc7")
  net <- mx.symbol.Activation(data = net, act_type = "relu", name = "relu7")
  net <- mx.symbol.Dropout(data = net, p = 0.5, name = "drop7")

  # Output
  net <- mx.symbol.FullyConnected(data = net, num_hidden = num_classes,
                                  name = "fc8")
  mx.symbol.SoftmaxOutput(data = net, name = "softmax")
}
76