# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

library(mxnet)

# Build Convolution -> BatchNorm, optionally followed by an Activation.
#
# data:                input symbol.
# num_filter:          number of convolution output channels.
# kernel, stride, pad: convolution geometry (2-element vectors).
# act_type:            activation applied after BatchNorm (conv_type == 0 only).
# conv_type:           0 = Convolution + BatchNorm + Activation;
#                      1 = Convolution + BatchNorm, no activation.
# Returns the resulting symbol. Any other conv_type is an error (the original
# code silently returned NULL, which only failed later in a confusing way).
conv_factory <- function(data, num_filter, kernel, stride,
                         pad, act_type = "relu", conv_type = 0) {
  if (!(length(conv_type) == 1 && conv_type %in% c(0, 1))) {
    stop("conv_type must be 0 or 1, got: ", conv_type, call. = FALSE)
  }
  conv <- mx.symbol.Convolution(data = data, num_filter = num_filter,
                                kernel = kernel, stride = stride, pad = pad)
  bn <- mx.symbol.BatchNorm(data = conv)
  if (conv_type == 1) {
    return(bn)
  }
  mx.symbol.Activation(data = bn, act_type = act_type)
}

# Build one two-convolution residual unit.
#
# data:       input symbol.
# num_filter: channel count of both 3x3 convolutions.
# dim_match:  TRUE  -> identity shortcut, stride 1 throughout;
#             FALSE -> first conv uses stride 2 and the shortcut is a
#                      1x1 / stride-2 projection (Conv + BatchNorm), the
#                      "projection" option of the ResNet paper.
# Returns relu(shortcut + conv2).
#
# Symbols are created in the same order as before (conv1, conv2, projection,
# plus, activation) so MXNet's auto-generated layer names are unchanged.
residual_factory <- function(data, num_filter, dim_match) {
  first_stride <- if (dim_match) c(1, 1) else c(2, 2)
  conv1 <- conv_factory(data = data, num_filter = num_filter,
                        kernel = c(3, 3), stride = first_stride,
                        pad = c(1, 1), act_type = "relu", conv_type = 0)
  conv2 <- conv_factory(data = conv1, num_filter = num_filter,
                        kernel = c(3, 3), stride = c(1, 1),
                        pad = c(1, 1), conv_type = 1)
  shortcut <- if (dim_match) {
    data
  } else {
    # Project the input so its shape matches conv2 when dimensions change.
    conv_factory(data = data, num_filter = num_filter, kernel = c(1, 1),
                 stride = c(2, 2), pad = c(0, 0), conv_type = 1)
  }
  mx.symbol.Activation(data = shortcut + conv2, act_type = "relu")
}

# Stack three stages of n residual units each (16, 32, 64 filters), i.e.
# 6n convolution layers in total.
#
# The first unit of the second and third stage uses dim_match = FALSE, which
# halves the spatial size and widens the channels; every other unit keeps
# dimensions. seq_len() makes n = 0 add no units (1:0 would iterate twice).
residual_net <- function(data, n) {
  stage_filters <- c(16, 32, 64)
  for (stage in seq_along(stage_filters)) {
    for (i in seq_len(n)) {
      dim_match <- !(stage > 1 && i == 1)
      data <- residual_factory(data = data,
                               num_filter = stage_filters[[stage]],
                               dim_match = dim_match)
    }
  }
  data
}

# Build the full classification network ending in a SoftmaxOutput.
#
# num_classes: number of output classes of the final FullyConnected layer.
# n:           residual units per stage; depth is 6n + 2 layers
#              (n = 3 -> 20 layers, n = 9 -> 56 layers). Defaults to the
#              previously hard-coded value of 3.
get_symbol <- function(num_classes = 10, n = 3) {
  data <- mx.symbol.Variable(name = "data")
  conv <- conv_factory(data = data, num_filter = 16,
                       kernel = c(3, 3), stride = c(1, 1), pad = c(1, 1),
                       act_type = "relu", conv_type = 0)
  resnet <- residual_net(conv, n)
  # NOTE(review): the 7x7 average pool presumably assumes 28x28 inputs, since
  # two stride-2 stages give 28 / 2 / 2 = 7 -- confirm for other input sizes.
  pool <- mx.symbol.Pooling(data = resnet, kernel = c(7, 7), pool_type = "avg")
  flatten <- mx.symbol.Flatten(data = pool, name = "flatten")
  fc <- mx.symbol.FullyConnected(data = flatten, num_hidden = num_classes,
                                 name = "fc1")
  mx.symbol.SoftmaxOutput(data = fc, name = "softmax")
}