1"""
2Train on CIFAR-10 with Mixup
3============================
4
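Example invocation (the script name is illustrative; any CIFAR classification
model from the GluonCV model zoo can be substituted)::

    python train_mixup_cifar10.py --num-gpus 1 --model cifar_resnet20_v1 --num-epochs 200
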
5"""
6
7from __future__ import division
8
9import matplotlib
10matplotlib.use('Agg')
11
12import argparse, time, logging, random, math
13
14import numpy as np
15import mxnet as mx
16
17from mxnet import gluon, nd
18from mxnet import autograd as ag
19from mxnet.gluon import nn
20from mxnet.gluon.data.vision import transforms
21
22import gluoncv as gcv
23gcv.utils.check_version('0.6.0')
24from gluoncv.model_zoo import get_model
25from gluoncv.data import transforms as gcv_transforms
26from gluoncv.utils import makedirs, TrainingHistory
27
# CLI
def parse_args():
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--model', type=str, default='resnet',
                        help='model to use. options are resnet and wrn. default is resnet.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-period', type=int, default=0,
                        help='period in epoch for learning rate decays. default is 0 (has no effect).')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--drop-rate', type=float, default=0.0,
                        help='dropout rate for wide resnet. default is 0.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are imperative, hybrid. default is imperative.')
    parser.add_argument('--save-period', type=int, default=10,
                        help='period in epoch of model saving.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory of saved models')
    parser.add_argument('--logging-dir', type=str, default='logs',
                        help='directory of training logs')
    parser.add_argument('--resume-from', type=str,
                        help='resume training from the model')
    parser.add_argument('--save-plot-dir', type=str, default='.',
                        help='the path to save the history plot')
    opt = parser.parse_args()
    return opt


def main():
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10

    num_gpus = opt.num_gpus
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        kwargs = {'classes': classes,
                  'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    model_name += '_mixup'
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        save_dir = ''
        save_period = 0

    plot_name = opt.save_plot_dir

    logging_handlers = [logging.StreamHandler()]
    if opt.logging_dir:
        logging_dir = opt.logging_dir
        makedirs(logging_dir)
        logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name)))

    logging.basicConfig(level=logging.INFO, handlers=logging_handlers)
    logging.info(opt)

    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

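    # Mixup needs dense label vectors so that two labels can be blended;
    # convert the integer class indices into one-hot vectors.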
    def label_transform(label, classes):
        ind = label.astype('int')
        res = nd.zeros((ind.shape[0], classes), ctx=label.context)
        res[nd.arange(ind.shape[0], ctx=label.context), ind] = 1
        return res

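    # Validation uses the clean (un-mixed) images and plain classification accuracy.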
    def test(ctx, val_data):
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        metric = mx.metric.Accuracy()
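        # Mixed targets are probability vectors, so the loss must accept dense
        # labels (sparse_label=False) and RMSE tracks how far the predicted
        # distribution is from the soft target.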
        train_metric = mx.metric.RMSE()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            metric.reset()
            train_loss = 0
            num_batch = len(train_data)
            alpha = 1

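            # Apply step decay: scale the learning rate by --lr-decay at each
            # epoch listed in --lr-decay-epoch.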
            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate*lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
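                # Sample the mixing coefficient lam ~ Beta(alpha, alpha); mixup is
                # disabled (lam = 1) for the last 20 epochs so training finishes
                # on un-mixed labels.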
                lam = np.random.beta(alpha, alpha)
                if epoch >= epochs - 20:
                    lam = 1

                data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

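                # Mix every image with its counterpart in the reversed batch
                # (X[::-1]) and blend the matching one-hot labels with the same
                # coefficient.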
                data = [lam*X + (1-lam)*X[::-1] for X in data_1]
                label = []
                for Y in label_1:
                    y1 = label_transform(Y, classes)
                    y2 = label_transform(Y[::-1], classes)
                    label.append(lam*y1 + (1-lam)*y2)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

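                # Convert logits to probabilities so RMSE compares them with the soft targets.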
                output_softmax = [nd.SoftmaxActivation(out) for out in output]
                train_metric.update(label, output_softmax)
                name, acc = train_metric.get()
                iteration += 1

            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([acc, 1-val_acc])
            train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                (epoch, acc, val_acc, train_loss, time.time()-tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1))

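    # Hybridizing compiles the network into MXNet's symbolic graph for faster execution.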
    if opt.mode == 'hybrid':
        net.hybridize()
    train(opt.num_epochs, context)

if __name__ == '__main__':
    main()