# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import sys
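# make the in-tree mxnet python package importable when running from the source tree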
sys.path.insert(0, '../../python')
import mxnet as mx
from mxnet.test_utils import get_mnist_ubyte
import numpy as np
import os, pickle, gzip, argparse
import logging

def get_model(use_gpu):
    # symbolic network: two conv / batch-norm / ReLU / max-pool blocks
    # followed by a 10-way fully connected softmax classifier
    data = mx.symbol.Variable('data')
    conv1 = mx.symbol.Convolution(data=data, name='conv1', num_filter=32, kernel=(3, 3), stride=(2, 2))
    bn1 = mx.symbol.BatchNorm(data=conv1, name='bn1')
    act1 = mx.symbol.Activation(data=bn1, name='relu1', act_type='relu')
    mp1 = mx.symbol.Pooling(data=act1, name='mp1', kernel=(2, 2), stride=(2, 2), pool_type='max')

    conv2 = mx.symbol.Convolution(data=mp1, name='conv2', num_filter=32, kernel=(3, 3), stride=(2, 2))
    bn2 = mx.symbol.BatchNorm(data=conv2, name='bn2')
    act2 = mx.symbol.Activation(data=bn2, name='relu2', act_type='relu')
    mp2 = mx.symbol.Pooling(data=act2, name='mp2', kernel=(2, 2), stride=(2, 2), pool_type='max')

    fl = mx.symbol.Flatten(data=mp2, name='flatten')
    fc2 = mx.symbol.FullyConnected(data=fl, name='fc2', num_hidden=10)
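    # the output symbol is named 'sm', so its label is bound as 'sm_label';
    # the MNIST iterators below set label_name='sm_label' to match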
    softmax = mx.symbol.SoftmaxOutput(data=fc2, name='sm')

    num_epoch = 1
    ctx = mx.gpu() if use_gpu else mx.cpu()
    model = mx.model.FeedForward(softmax, ctx,
                                 num_epoch=num_epoch,
                                 learning_rate=0.1, wd=0.0001,
                                 momentum=0.9)
    return model

def get_iters():
    # download the MNIST data files if they are not already present
    get_mnist_ubyte()

    batch_size = 100
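    # training iterator: shuffled with a fixed seed so runs are reproducible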
    train_dataiter = mx.io.MNISTIter(
            image="data/train-images-idx3-ubyte",
            label="data/train-labels-idx1-ubyte",
            data_shape=(1, 28, 28),
            label_name='sm_label',
            batch_size=batch_size, shuffle=True, flat=False, silent=False, seed=10)
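    # validation iterator over the MNIST test set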
    val_dataiter = mx.io.MNISTIter(
            image="data/t10k-images-idx3-ubyte",
            label="data/t10k-labels-idx1-ubyte",
            data_shape=(1, 28, 28),
            label_name='sm_label',
            batch_size=batch_size, shuffle=True, flat=False, silent=False)
    return train_dataiter, val_dataiter

# default entry point when run under the unit test framework (trains on CPU)
def test_mnist():
    iters = get_iters()
    exec_mnist(get_model(False), iters[0], iters[1])

def exec_mnist(model, train_dataiter, val_dataiter):
    # enable console logging by default
    logging.basicConfig(level=logging.DEBUG)
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    logging.getLogger('').addHandler(console)

    model.fit(X=train_dataiter,
              eval_data=val_dataiter)
    logging.info('Finish fit...')
    prob = model.predict(val_dataiter)
    logging.info('Finish predict...')
    val_dataiter.reset()
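    # collect the ground-truth labels from the reset validation iterator and
    # compare them with the argmax of the predicted class probabilities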
    y = np.concatenate([batch.label[0].asnumpy() for batch in val_dataiter]).astype('int')
    py = np.argmax(prob, axis=1)
    acc1 = float(np.sum(py == y)) / len(y)
    logging.info('final accuracy = %f', acc1)
    assert acc1 > 0.94

# run as a script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', action='store_true', help='use gpu to train')
    args = parser.parse_args()
    iters = get_iters()
    exec_mnist(get_model(args.gpu), iters[0], iters[1])