# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# library() stops immediately when a package is missing; require() would
# only warn and let the script fail later with an unrelated error.
library(argparse)
library(mxnet)

# Join a data directory (local path or URL-like prefix) and a file name
# with exactly one "/" between them, whether or not `data_dir` already
# ends in a slash.
data_path_ <- function(data_dir, file_name) {
  paste0(sub("/+$", "", data_dir), "/", file_name)
}

# Make sure the four MNIST idx files are present in `data_dir`.
#
# If any of them is missing, mnist.zip is fetched with wget, unpacked into
# `data_dir` and the archive is deleted afterwards. The working directory
# of the R session is never changed, so nested or absolute directories
# work and a failed download cannot leave the session in another folder.
#
# data_dir: local directory that should hold the MNIST files.
download_ <- function(data_dir) {
  dir.create(data_dir, showWarnings = FALSE, recursive = TRUE)
  mnist_files <- c("train-images-idx3-ubyte",
                   "train-labels-idx1-ubyte",
                   "t10k-images-idx3-ubyte",
                   "t10k-labels-idx1-ubyte")
  if (!all(file.exists(data_path_(data_dir, mnist_files)))) {
    zip_path <- data_path_(data_dir, "mnist.zip")
    download.file(url = "http://data.mxnet.io/mxnet/data/mnist.zip",
                  destfile = zip_path, method = "wget")
    unzip(zip_path, exdir = data_dir)
    file.remove(zip_path)
  }
  invisible(NULL)
}

# multi-layer perceptron
# Builds the symbol graph: data -> fc(128) -> relu -> fc(64) -> relu ->
# fc(10) -> softmax output.
get_mlp <- function() {
  data <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data = data, name = "fc1", num_hidden = 128)
  act1 <- mx.symbol.Activation(data = fc1, name = "relu1", act_type = "relu")
  fc2 <- mx.symbol.FullyConnected(data = act1, name = "fc2", num_hidden = 64)
  act2 <- mx.symbol.Activation(data = fc2, name = "relu2", act_type = "relu")
  fc3 <- mx.symbol.FullyConnected(data = act2, name = "fc3", num_hidden = 10)
  mx.symbol.SoftmaxOutput(data = fc3, name = "softmax")
}

# LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
# Haffner. "Gradient-based learning applied to document recognition."
# Proceedings of the IEEE (1998)
# Builds the symbol graph: two conv/tanh/max-pool stages, a flatten, a
# 500-unit tanh fully connected layer and a 10-way softmax output.
get_lenet <- function() {
  data <- mx.symbol.Variable("data")
  # first conv
  conv1 <- mx.symbol.Convolution(data = data, kernel = c(5, 5),
                                 num_filter = 20)
  tanh1 <- mx.symbol.Activation(data = conv1, act_type = "tanh")
  pool1 <- mx.symbol.Pooling(data = tanh1, pool_type = "max",
                             kernel = c(2, 2), stride = c(2, 2))
  # second conv
  conv2 <- mx.symbol.Convolution(data = pool1, kernel = c(5, 5),
                                 num_filter = 50)
  tanh2 <- mx.symbol.Activation(data = conv2, act_type = "tanh")
  pool2 <- mx.symbol.Pooling(data = tanh2, pool_type = "max",
                             kernel = c(2, 2), stride = c(2, 2))
  # first fullc
  flatten <- mx.symbol.Flatten(data = pool2)
  fc1 <- mx.symbol.FullyConnected(data = flatten, num_hidden = 500)
  tanh3 <- mx.symbol.Activation(data = fc1, act_type = "tanh")
  # second fullc
  fc2 <- mx.symbol.FullyConnected(data = tanh3, num_hidden = 10)
  # loss
  mx.symbol.SoftmaxOutput(data = fc2, name = "softmax")
}

# Create a data-loader closure for the given input shape.
#
# data_shape: input shape handed to mx.io.MNISTIter; a shape of length 3
#   selects non-flattened input, any other length selects flattened input.
#
# Returns a function(args) that reads args$data_dir and args$batch_size and
# returns list(train = <training iterator>, value = <validation iterator>).
get_iterator <- function(data_shape) {
  get_iterator_impl <- function(args) {
    data_dir <- args$data_dir
    # Only plain local directories are populated by download_(); anything
    # that looks like a URL ("scheme://...") is passed through untouched.
    if (!grepl("://", data_dir, fixed = TRUE)) {
      download_(data_dir)
    }
    flat <- length(data_shape) != 3

    train <- mx.io.MNISTIter(
      image = data_path_(data_dir, "train-images-idx3-ubyte"),
      label = data_path_(data_dir, "train-labels-idx1-ubyte"),
      input_shape = data_shape,
      batch_size = args$batch_size,
      shuffle = TRUE,
      flat = flat)

    val <- mx.io.MNISTIter(
      image = data_path_(data_dir, "t10k-images-idx3-ubyte"),
      label = data_path_(data_dir, "t10k-labels-idx1-ubyte"),
      input_shape = data_shape,
      batch_size = args$batch_size,
      flat = flat)

    list(train = train, value = val)
  }
  get_iterator_impl
}

# Define the command-line interface and parse the current command line.
# Returns the parsed argument list produced by the argparse parser.
parse_args <- function() {
  parser <- ArgumentParser(description = 'train an image classifer on mnist')
  parser$add_argument('--network', type = 'character', default = 'mlp',
                      choices = c('mlp', 'lenet'),
                      help = 'the cnn to use')
  parser$add_argument('--data-dir', type = 'character', default = 'mnist/',
                      help = 'the input data directory')
  parser$add_argument('--gpus', type = 'character',
                      help = 'the gpus will be used, e.g "0,1,2,3"')
  parser$add_argument('--batch-size', type = 'integer', default = 128,
                      help = 'the batch size')
  parser$add_argument('--lr', type = 'double', default = .05,
                      help = 'the initial learning rate')
  parser$add_argument('--mom', type = 'double', default = .9,
                      help = 'momentum for sgd')
  parser$add_argument('--model-prefix', type = 'character',
                      help = 'the prefix of the model to load/save')
  parser$add_argument('--num-round', type = 'integer', default = 10,
                      help = 'the number of iterations over training data to train the model')
  parser$add_argument('--kv-store', type = 'character', default = 'local',
                      help = 'the kvstore type')

  parser$parse_args()
}

args <- parse_args()
# NOTE(review): args$model_prefix and args$kv_store are parsed above but
# never read anywhere in this script.

# The MLP consumes flattened 784-pixel vectors; LeNet consumes 28x28x1 images.
if (args$network == 'mlp') {
  data_shape <- c(784)
  net <- get_mlp()
} else {
  data_shape <- c(28, 28, 1)
  net <- get_lenet()
}

# train
data_loader <- get_iterator(data_shape)
data <- data_loader(args)
train <- data$train
val <- data$value

# Train on the CPU unless --gpus lists one or more comma-separated device ids.
if (is.null(args$gpus)) {
  devs <- mx.cpu()
} else {
  gpu_ids <- unlist(strsplit(args$gpus, ","))
  devs <- lapply(gpu_ids, function(i) {
    mx.gpu(as.integer(i))
  })
}

mx.set.seed(0)

model <- mx.model.FeedForward.create(
  X                  = train,
  eval.data          = val,
  ctx                = devs,
  symbol             = net,
  num.round          = args$num_round,
  array.batch.size   = args$batch_size,
  learning.rate      = args$lr,
  momentum           = args$mom,
  eval.metric        = mx.metric.accuracy,
  initializer        = mx.init.uniform(0.07),
  batch.end.callback = mx.callback.log.train.metric(100))