# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: skip-file
import sys
# NOTE(review): this path (and the "data/..." paths in get_iters) is resolved
# relative to the current working directory, not this file, so the script
# presumably has to be run from its own directory.
sys.path.insert(0, '../../python')
import mxnet as mx
from mxnet.test_utils import get_mnist_ubyte
import numpy as np
import os, pickle, gzip, argparse
import logging


def get_model(use_gpu):
    """Build a small MNIST conv-net wrapped in an untrained FeedForward model.

    Architecture: two (conv 3x3/stride 2 -> BatchNorm -> ReLU -> 2x2 max-pool)
    stages, then flatten, a 10-unit fully connected layer and a SoftmaxOutput
    named 'sm'. The iterators from get_iters use label_name='sm_label' to
    line up with that output name.

    Args:
        use_gpu: if true the model runs on mx.gpu(), otherwise on mx.cpu().

    Returns:
        An mx.model.FeedForward configured for 1 epoch of SGD with
        learning_rate=0.1, wd=0.0001, momentum=0.9.
    """
    # symbol net
    data = mx.symbol.Variable('data')
    conv1 = mx.symbol.Convolution(data=data, name='conv1', num_filter=32,
                                  kernel=(3, 3), stride=(2, 2))
    bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1")
    act1 = mx.symbol.Activation(data=bn1, name='relu1', act_type="relu")
    mp1 = mx.symbol.Pooling(data=act1, name='mp1', kernel=(2, 2),
                            stride=(2, 2), pool_type='max')

    conv2 = mx.symbol.Convolution(data=mp1, name='conv2', num_filter=32,
                                  kernel=(3, 3), stride=(2, 2))
    bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2")
    act2 = mx.symbol.Activation(data=bn2, name='relu2', act_type="relu")
    mp2 = mx.symbol.Pooling(data=act2, name='mp2', kernel=(2, 2),
                            stride=(2, 2), pool_type='max')

    fl = mx.symbol.Flatten(data=mp2, name="flatten")
    fc2 = mx.symbol.FullyConnected(data=fl, name='fc2', num_hidden=10)
    softmax = mx.symbol.SoftmaxOutput(data=fc2, name='sm')

    num_epoch = 1
    ctx = mx.gpu() if use_gpu else mx.cpu()
    model = mx.model.FeedForward(softmax, ctx,
                                 num_epoch=num_epoch,
                                 learning_rate=0.1, wd=0.0001,
                                 momentum=0.9)
    return model


def get_iters(batch_size=100):
    """Fetch MNIST (if missing) and return (train_dataiter, val_dataiter).

    Both iterators yield (1, 28, 28) images with the label named 'sm_label'.
    The training iterator is shuffled with a fixed seed (10).

    Args:
        batch_size: number of samples per batch for both iterators.

    Returns:
        A (train, validation) pair of mx.io.MNISTIter objects reading from
        the "data/" directory.
    """
    # check data
    get_mnist_ubyte()

    train_dataiter = mx.io.MNISTIter(
        image="data/train-images-idx3-ubyte",
        label="data/train-labels-idx1-ubyte",
        data_shape=(1, 28, 28),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=False, silent=False, seed=10)
    # NOTE(review): exec_mnist pairs predictions with labels read in a second
    # pass over this iterator, which assumes shuffle=True does not reorder the
    # data on reset() -- confirm against MNISTIter, or use shuffle=False.
    val_dataiter = mx.io.MNISTIter(
        image="data/t10k-images-idx3-ubyte",
        label="data/t10k-labels-idx1-ubyte",
        data_shape=(1, 28, 28),
        label_name='sm_label',
        batch_size=batch_size, shuffle=True, flat=False, silent=False)
    return train_dataiter, val_dataiter


# run default with unit test framework
def test_mnist():
    """Unit-test entry point: train on CPU and check validation accuracy."""
    train_dataiter, val_dataiter = get_iters()
    exec_mnist(get_model(False), train_dataiter, val_dataiter)


def exec_mnist(model, train_dataiter, val_dataiter, min_accuracy=0.94):
    """Train `model`, predict on the validation set and check its accuracy.

    Args:
        model: an mx.model.FeedForward to fit.
        train_dataiter: iterator used for training.
        val_dataiter: iterator used both as eval data during fit and for
            the final prediction / accuracy computation.
        min_accuracy: the final accuracy must be strictly greater than this.

    Raises:
        AssertionError: if the final validation accuracy is not above
            min_accuracy.
    """
    # print logging by default. Only basicConfig is needed: it already puts a
    # stderr handler on the root logger, and attaching a second StreamHandler
    # prints every record twice (and one more time per call).
    logging.basicConfig(level=logging.DEBUG)

    model.fit(X=train_dataiter,
              eval_data=val_dataiter)
    logging.info('Finish fit...')
    prob = model.predict(val_dataiter)
    logging.info('Finish predict...')
    # Rewind and read the ground-truth labels in iteration order so they line
    # up row-for-row with `prob`.
    val_dataiter.reset()
    y = np.concatenate([batch.label[0].asnumpy() for batch in val_dataiter]).astype('int')
    py = np.argmax(prob, axis=1)
    acc1 = float(np.sum(py == y)) / len(y)
    logging.info('final accuracy = %f', acc1)
    assert acc1 > min_accuracy, \
        'final accuracy %f is not above %f' % (acc1, min_accuracy)


# run as a script
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', action='store_true', help='use gpu to train')
    args = parser.parse_args()
    train_dataiter, val_dataiter = get_iters()
    exec_mnist(get_model(args.gpu), train_dataiter, val_dataiter)