1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18library(mxnet)
19
# Build a Convolution -> BatchNorm unit, optionally followed by an activation.
#
# Args:
#   data:       input symbol.
#   num_filter: number of convolution output channels.
#   kernel, stride, pad: convolution geometry, passed straight through to
#               mx.symbol.Convolution.
#   act_type:   activation applied when conv_type == 0 (default "relu").
#   conv_type:  0 = conv + BN + activation; 1 = conv + BN only (no activation,
#               used where the caller adds a residual before activating).
#
# Returns the output symbol. Signals an error for any other conv_type instead
# of silently returning NULL.
conv_factory <- function(data, num_filter, kernel, stride,
                         pad, act_type = "relu", conv_type = 0) {
  if (length(conv_type) != 1 || !(conv_type %in% c(0, 1))) {
    stop("`conv_type` must be 0 (conv-BN-activation) or 1 (conv-BN); got ",
         deparse(conv_type), call. = FALSE)
  }

  conv <- mx.symbol.Convolution(data = data, num_filter = num_filter,
                                kernel = kernel, stride = stride, pad = pad)
  bn <- mx.symbol.BatchNorm(data = conv)

  # conv_type 1 leaves the output pre-activation.
  if (conv_type == 1) {
    return(bn)
  }
  mx.symbol.Activation(data = bn, act_type = act_type)
}
35
# Two-convolution residual unit: relu(shortcut + BN(conv(relu(BN(conv(x)))))).
#
# Args:
#   data:       input symbol.
#   num_filter: output channels of both 3x3 convolutions.
#   dim_match:  TRUE keeps the input untouched as the shortcut (stride 1).
#               FALSE downsamples: the first 3x3 conv uses stride 2, and the
#               shortcut becomes a 1x1 stride-2 conv + BN projection.
#
# Symbols are created in a fixed order: main-path conv 1, main-path conv 2,
# then the projection. That order determines MXNet's auto-generated layer
# names, so it is kept stable.
residual_factory <- function(data, num_filter, dim_match) {
  first_stride <- if (dim_match) c(1, 1) else c(2, 2)

  branch <- conv_factory(data = data, num_filter = num_filter,
                         kernel = c(3, 3), stride = first_stride,
                         pad = c(1, 1), act_type = "relu", conv_type = 0)
  branch <- conv_factory(data = branch, num_filter = num_filter,
                         kernel = c(3, 3), stride = c(1, 1), pad = c(1, 1),
                         conv_type = 1)

  # Projection shortcut when the shape changes (the "project" option from the
  # ResNet paper); identity otherwise.
  shortcut <- if (dim_match) {
    data
  } else {
    conv_factory(data = data, num_filter = num_filter, kernel = c(1, 1),
                 stride = c(2, 2), pad = c(0, 0), conv_type = 1)
  }

  mx.symbol.Activation(data = shortcut + branch, act_type = "relu")
}
61
# Stack length(num_filters) stages of `n` residual units each.
#
# The first stage keeps the identity shortcut in every unit, so it assumes
# `data` already has num_filters[1] channels. Every later stage starts with a
# downsampling unit (dim_match = FALSE) followed by n - 1 identity units.
#
# Args:
#   data:        input symbol.
#   n:           residual units per stage (non-negative). n = 0 returns `data`
#                unchanged.
#   num_filters: channel width of each stage; defaults to the CIFAR ResNet
#                widths 16/32/64.
#
# Returns the output symbol of the last unit.
residual_net <- function(data, n, num_filters = c(16, 32, 64)) {
  stopifnot(is.numeric(n), length(n) == 1, !is.na(n), n >= 0)

  # seq_len() rather than 1:n, so that n = 0 builds no units.
  for (stage in seq_along(num_filters)) {
    for (i in seq_len(n)) {
      dim_match <- stage == 1 || i > 1
      data <- residual_factory(data = data,
                               num_filter = num_filters[[stage]],
                               dim_match = dim_match)
    }
  }
  data
}
87
# Build the full ResNet classification symbol.
#
# Layout: 3x3 stem conv (16 filters) -> residual_net(n) -> 7x7 average pooling
# -> flatten -> fully connected (num_classes) -> softmax output.
#
# Args:
#   num_classes: number of output classes.
#   n:           residual units per stage. Depth is 6n + 2 layers, e.g.
#                n = 3 gives 20 layers and n = 9 gives 56.
#
# NOTE(review): the 7x7 average pool covers the whole final feature map only
# when that map is 7x7. That is the case for 28x28 inputs (28 -> 14 -> 7 after
# the two stride-2 stages). Confirm the training input shape.
get_symbol <- function(num_classes = 10, n = 3) {
  data <- mx.symbol.Variable(name = "data")
  conv <- conv_factory(data = data, num_filter = 16, kernel = c(3, 3),
                       stride = c(1, 1), pad = c(1, 1),
                       act_type = "relu", conv_type = 0)
  resnet <- residual_net(conv, n)
  pool <- mx.symbol.Pooling(data = resnet, kernel = c(7, 7), pool_type = "avg")
  flatten <- mx.symbol.Flatten(data = pool, name = "flatten")
  fc <- mx.symbol.FullyConnected(data = flatten, num_hidden = num_classes,
                                 name = "fc1")
  mx.symbol.SoftmaxOutput(data = fc, name = "softmax")
}
100