# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

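# Train an MNIST digit classifier with the MXNet R API, using either a
# multi-layer perceptron or the LeNet convolutional network (selected via --network).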
require(argparse)
require(mxnet)

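# Fetch the MNIST IDX files into data_dir (skipped if they are already present).
# Note: download.file() is called with method='wget', so wget must be on the PATH.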
download_ <- function(data_dir) {
    dir.create(data_dir, showWarnings = FALSE)
    setwd(data_dir)
    if ((!file.exists('train-images-idx3-ubyte')) ||
        (!file.exists('train-labels-idx1-ubyte')) ||
        (!file.exists('t10k-images-idx3-ubyte')) ||
        (!file.exists('t10k-labels-idx1-ubyte'))) {
        download.file(url='http://data.mxnet.io/mxnet/data/mnist.zip',
                      destfile='mnist.zip', method='wget')
        unzip("mnist.zip")
        file.remove("mnist.zip")
    }
    setwd("..")
}

# multi-layer perceptron
get_mlp <- function() {
    data <- mx.symbol.Variable('data')
    fc1  <- mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
    act1 <- mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
    fc2  <- mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
    act2 <- mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
    fc3  <- mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
    mlp  <- mx.symbol.SoftmaxOutput(data = fc3, name = 'softmax')
    mlp
}

# LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick
# Haffner. "Gradient-based learning applied to document recognition."
# Proceedings of the IEEE (1998)
get_lenet <- function() {
    data <- mx.symbol.Variable('data')
    # first conv
    conv1 <- mx.symbol.Convolution(data=data, kernel=c(5,5), num_filter=20)
    tanh1 <- mx.symbol.Activation(data=conv1, act_type="tanh")
    pool1 <- mx.symbol.Pooling(data=tanh1, pool_type="max",
                               kernel=c(2,2), stride=c(2,2))
    # second conv
    conv2 <- mx.symbol.Convolution(data=pool1, kernel=c(5,5), num_filter=50)
    tanh2 <- mx.symbol.Activation(data=conv2, act_type="tanh")
    pool2 <- mx.symbol.Pooling(data=tanh2, pool_type="max",
                               kernel=c(2,2), stride=c(2,2))
    # first fullc
    flatten <- mx.symbol.Flatten(data=pool2)
    fc1 <- mx.symbol.FullyConnected(data=flatten, num_hidden=500)
    tanh3 <- mx.symbol.Activation(data=fc1, act_type="tanh")
    # second fullc
    fc2 <- mx.symbol.FullyConnected(data=tanh3, num_hidden=10)
    # loss
    lenet <- mx.symbol.SoftmaxOutput(data=fc2, name='softmax')
    lenet
}

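# Returns a closure that builds the train/validation MNIST iterators for the given
# input shape. A 1-d shape (784) yields flattened vectors for the MLP; a 3-d shape
# (28, 28, 1) keeps the image layout for LeNet.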
get_iterator <- function(data_shape) {
    get_iterator_impl <- function(args) {
        data_dir <- args$data_dir
        if (!grepl('://', args$data_dir))
            download_(args$data_dir)
        flat <- TRUE
        if (length(data_shape) == 3) flat <- FALSE

        train <- mx.io.MNISTIter(
            image       = paste0(data_dir, "train-images-idx3-ubyte"),
            label       = paste0(data_dir, "train-labels-idx1-ubyte"),
            input_shape = data_shape,
            batch_size  = args$batch_size,
            shuffle     = TRUE,
            flat        = flat)

        val <- mx.io.MNISTIter(
            image       = paste0(data_dir, "t10k-images-idx3-ubyte"),
            label       = paste0(data_dir, "t10k-labels-idx1-ubyte"),
            input_shape = data_shape,
            batch_size  = args$batch_size,
            flat        = flat)

        list(train = train, value = val)
    }
    get_iterator_impl
}

parse_args <- function() {
    parser <- ArgumentParser(description='train an image classifier on mnist')
    parser$add_argument('--network', type='character', default='mlp',
                        choices = c('mlp', 'lenet'),
                        help = 'the network to use')
    parser$add_argument('--data-dir', type='character', default='mnist/',
                        help='the input data directory')
    parser$add_argument('--gpus', type='character',
                        help='the GPUs to use, e.g. "0,1,2,3"')
    parser$add_argument('--batch-size', type='integer', default=128,
                        help='the batch size')
    parser$add_argument('--lr', type='double', default=.05,
                        help='the initial learning rate')
    parser$add_argument('--mom', type='double', default=.9,
                        help='momentum for sgd')
    parser$add_argument('--model-prefix', type='character',
                        help='the prefix of the model to load/save')
    parser$add_argument('--num-round', type='integer', default=10,
                        help='the number of passes over the training data')
    parser$add_argument('--kv-store', type='character', default='local',
                        help='the kvstore type')

    parser$parse_args()
}

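# Example invocation (illustrative flag values; assumes this script is saved as train_mnist.R):
#   Rscript train_mnist.R --network lenet --gpus 0 --batch-size 64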
args <- parse_args()
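# Pick the network symbol and the matching input shape: the MLP consumes flattened
# 784-element vectors, LeNet takes raw 28x28x1 images.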
if (args$network == 'mlp') {
    data_shape <- c(784)
    net <- get_mlp()
} else {
    data_shape <- c(28, 28, 1)
    net <- get_lenet()
}

# train
data_loader <- get_iterator(data_shape)
data <- data_loader(args)
train <- data$train
val <- data$value

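# Use the CPU unless --gpus is given; a comma-separated list such as "0,1"
# creates one context per GPU for data-parallel training.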
if (is.null(args$gpus)) {
  devs <- mx.cpu()
} else {
  devs <- lapply(unlist(strsplit(args$gpus, ",")), function(i) {
    mx.gpu(as.integer(i))
  })
}

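# Fix MXNet's random number seed so that weight initialization is reproducible.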
mx.set.seed(0)

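# Train with SGD: learning rate and momentum come from the command line, weights are
# initialized uniformly in [-0.07, 0.07], accuracy is reported on the validation set,
# and the training metric is logged every 100 batches.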
model <- mx.model.FeedForward.create(
  X                  = train,
  eval.data          = val,
  ctx                = devs,
  symbol             = net,
  num.round          = args$num_round,
  array.batch.size   = args$batch_size,
  learning.rate      = args$lr,
  momentum           = args$mom,
  eval.metric        = mx.metric.accuracy,
  initializer        = mx.init.uniform(0.07),
  batch.end.callback = mx.callback.log.train.metric(100))
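
# The --model-prefix flag is parsed but not used above. A minimal sketch of saving the
# trained model with it, assuming the standard mx.model.save() helper from the MXNet R
# package (it writes <prefix>-symbol.json and <prefix>-NNNN.params):
#
# if (!is.null(args$model_prefix)) {
#     mx.model.save(model, args$model_prefix, args$num_round)
# }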