# Select a non-interactive backend so training-history plots can be saved without a display.
import matplotlib
matplotlib.use('Agg')

import argparse, time, logging

import numpy as np
import mxnet as mx

from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv.model_zoo import get_model
from gluoncv.utils import makedirs, TrainingHistory
from gluoncv.data import transforms as gcv_transforms
# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--model', type=str, default='cifar_resnet20_v1',
                        help='model to use. options are the cifar_resnet and cifar_wideresnet '
                             'variants from the GluonCV model zoo. default is cifar_resnet20_v1.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers.')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer. default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-period', type=int, default=0,
                        help='interval in epochs between learning rate decays. default is 0 '
                             '(disabled; --lr-decay-epoch is used instead).')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--drop-rate', type=float, default=0.0,
                        help='dropout rate for wide resnet (cifar_wideresnet models only). default is 0.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are imperative and hybrid. '
                             'default is imperative.')
    parser.add_argument('--save-period', type=int, default=10,
                        help='interval in epochs between model checkpoints.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory for saved models.')
    parser.add_argument('--resume-from', type=str,
                        help='path to saved parameters to resume training from.')
    parser.add_argument('--save-plot-dir', type=str, default='.',
                        help='path at which to save the training-history plot.')
    opt = parser.parse_args()
    return opt


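# Parse options, build the model and device context, then launch training.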
def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    # Scale the effective batch size with the number of devices.
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    if opt.lr_decay_period > 0:
        lr_decay_epoch = list(range(opt.lr_decay_period, opt.num_epochs, opt.lr_decay_period))
    else:
        lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')]
    # Append an infinite sentinel so lr_decay_epoch[lr_decay_count] never overruns the list.
    lr_decay_epoch = lr_decay_epoch + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                  'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'  # Nesterov accelerated SGD.

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

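    # Evaluate classification accuracy over the validation set, splitting each
    # batch across all available devices.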
    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for batch in val_data:
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()
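    # Train for the requested number of epochs, decaying the learning rate at the
    # scheduled epochs, checkpointing periodically, and tracking the best validation score.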
    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        # Xavier initialization; parameters loaded via --resume-from are left untouched.
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        lr_decay_count = 0

        best_val_score = 0
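        # One full pass over the training set per epoch.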
        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            # Decay the learning rate once the scheduled epoch is reached.
            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

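            # Split each batch evenly across devices; gradients are aggregated by the trainer.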
            for batch in train_data:
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                # Record the forward pass so autograd can compute gradients.
                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                # Normalize the update by the total (all-device) batch size.
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)

            # Average per-sample training loss over the epoch.
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1-acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_path, model_name))

            # Checkpoint whenever validation accuracy improves.
            if val_acc > best_val_score and save_dir:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))


    # Hybridizing compiles the network into a static graph for faster execution.
    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)

if __name__ == '__main__':
    main()
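# Example invocation (the script filename is illustrative; substitute your own):
#
#   python train_cifar10.py --model cifar_resnet20_v1 --num-gpus 1 \
#       --num-epochs 3 --mode hybrid --save-dir params --save-plot-dir .
#
# With no GPUs available, omit --num-gpus to train on the CPU.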