1""" 2Train on CIFAR-10 with Mixup 3============================ 4 5""" 6 7from __future__ import division 8 9import matplotlib 10matplotlib.use('Agg') 11 12import argparse, time, logging, random, math 13 14import numpy as np 15import mxnet as mx 16 17from mxnet import gluon, nd 18from mxnet import autograd as ag 19from mxnet.gluon import nn 20from mxnet.gluon.data.vision import transforms 21 22import gluoncv as gcv 23gcv.utils.check_version('0.6.0') 24from gluoncv.model_zoo import get_model 25from gluoncv.data import transforms as gcv_transforms 26from gluoncv.utils import makedirs, TrainingHistory 27 28# CLI 29def parse_args(): 30 parser = argparse.ArgumentParser(description='Train a model for image classification.') 31 parser.add_argument('--batch-size', type=int, default=32, 32 help='training batch size per device (CPU/GPU).') 33 parser.add_argument('--num-gpus', type=int, default=0, 34 help='number of gpus to use.') 35 parser.add_argument('--model', type=str, default='resnet', 36 help='model to use. options are resnet and wrn. default is resnet.') 37 parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int, 38 help='number of preprocessing workers') 39 parser.add_argument('--num-epochs', type=int, default=3, 40 help='number of training epochs.') 41 parser.add_argument('--lr', type=float, default=0.1, 42 help='learning rate. default is 0.1.') 43 parser.add_argument('--momentum', type=float, default=0.9, 44 help='momentum value for optimizer, default is 0.9.') 45 parser.add_argument('--wd', type=float, default=0.0001, 46 help='weight decay rate. default is 0.0001.') 47 parser.add_argument('--lr-decay', type=float, default=0.1, 48 help='decay rate of learning rate. default is 0.1.') 49 parser.add_argument('--lr-decay-period', type=int, default=0, 50 help='period in epoch for learning rate decays. default is 0 (has no effect).') 51 parser.add_argument('--lr-decay-epoch', type=str, default='40,60', 52 help='epochs at which learning rate decays. default is 40,60.') 53 parser.add_argument('--drop-rate', type=float, default=0.0, 54 help='dropout rate for wide resnet. default is 0.') 55 parser.add_argument('--mode', type=str, 56 help='mode in which to train the model. options are imperative, hybrid') 57 parser.add_argument('--save-period', type=int, default=10, 58 help='period in epoch of model saving.') 59 parser.add_argument('--save-dir', type=str, default='params', 60 help='directory of saved models') 61 parser.add_argument('--logging-dir', type=str, default='logs', 62 help='directory of training logs') 63 parser.add_argument('--resume-from', type=str, 64 help='resume training from the model') 65 parser.add_argument('--save-plot-dir', type=str, default='.', 66 help='the path to save the history plot') 67 opt = parser.parse_args() 68 return opt 69 70 71def main(): 72 opt = parse_args() 73 74 batch_size = opt.batch_size 75 classes = 10 76 77 num_gpus = opt.num_gpus 78 batch_size *= max(1, num_gpus) 79 context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()] 80 num_workers = opt.num_workers 81 82 lr_decay = opt.lr_decay 83 lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf] 84 85 model_name = opt.model 86 if model_name.startswith('cifar_wideresnet'): 87 kwargs = {'classes': classes, 88 'drop_rate': opt.drop_rate} 89 else: 90 kwargs = {'classes': classes} 91 net = get_model(model_name, **kwargs) 92 model_name += '_mixup' 93 if opt.resume_from: 94 net.load_parameters(opt.resume_from, ctx = context) 95 optimizer = 'nag' 96 97 save_period = opt.save_period 98 if opt.save_dir and save_period: 99 save_dir = opt.save_dir 100 makedirs(save_dir) 101 else: 102 save_dir = '' 103 save_period = 0 104 105 plot_name = opt.save_plot_dir 106 107 logging_handlers = [logging.StreamHandler()] 108 if opt.logging_dir: 109 logging_dir = opt.logging_dir 110 makedirs(logging_dir) 111 logging_handlers.append(logging.FileHandler('%s/train_cifar10_%s.log'%(logging_dir, model_name))) 112 113 logging.basicConfig(level=logging.INFO, handlers = logging_handlers) 114 logging.info(opt) 115 116 transform_train = transforms.Compose([ 117 gcv_transforms.RandomCrop(32, pad=4), 118 transforms.RandomFlipLeftRight(), 119 transforms.ToTensor(), 120 transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 121 ]) 122 123 transform_test = transforms.Compose([ 124 transforms.ToTensor(), 125 transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]) 126 ]) 127 128 def label_transform(label, classes): 129 ind = label.astype('int') 130 res = nd.zeros((ind.shape[0], classes), ctx = label.context) 131 res[nd.arange(ind.shape[0], ctx = label.context), ind] = 1 132 return res 133 134 def test(ctx, val_data): 135 metric = mx.metric.Accuracy() 136 for i, batch in enumerate(val_data): 137 data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) 138 label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0) 139 outputs = [net(X) for X in data] 140 metric.update(label, outputs) 141 return metric.get() 142 143 def train(epochs, ctx): 144 if isinstance(ctx, mx.Context): 145 ctx = [ctx] 146 net.initialize(mx.init.Xavier(), ctx=ctx) 147 148 train_data = gluon.data.DataLoader( 149 gluon.data.vision.CIFAR10(train=True).transform_first(transform_train), 150 batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers) 151 152 val_data = gluon.data.DataLoader( 153 gluon.data.vision.CIFAR10(train=False).transform_first(transform_test), 154 batch_size=batch_size, shuffle=False, num_workers=num_workers) 155 156 trainer = gluon.Trainer(net.collect_params(), optimizer, 157 {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum}) 158 metric = mx.metric.Accuracy() 159 train_metric = mx.metric.RMSE() 160 loss_fn = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False) 161 train_history = TrainingHistory(['training-error', 'validation-error']) 162 163 iteration = 0 164 lr_decay_count = 0 165 166 best_val_score = 0 167 168 for epoch in range(epochs): 169 tic = time.time() 170 train_metric.reset() 171 metric.reset() 172 train_loss = 0 173 num_batch = len(train_data) 174 alpha = 1 175 176 if epoch == lr_decay_epoch[lr_decay_count]: 177 trainer.set_learning_rate(trainer.learning_rate*lr_decay) 178 lr_decay_count += 1 179 180 for i, batch in enumerate(train_data): 181 lam = np.random.beta(alpha, alpha) 182 if epoch >= epochs - 20: 183 lam = 1 184 185 data_1 = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0) 186 label_1 = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0) 187 188 data = [lam*X + (1-lam)*X[::-1] for X in data_1] 189 label = [] 190 for Y in label_1: 191 y1 = label_transform(Y, classes) 192 y2 = label_transform(Y[::-1], classes) 193 label.append(lam*y1 + (1-lam)*y2) 194 195 with ag.record(): 196 output = [net(X) for X in data] 197 loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)] 198 for l in loss: 199 l.backward() 200 trainer.step(batch_size) 201 train_loss += sum([l.sum().asscalar() for l in loss]) 202 203 output_softmax = [nd.SoftmaxActivation(out) for out in output] 204 train_metric.update(label, output_softmax) 205 name, acc = train_metric.get() 206 iteration += 1 207 208 train_loss /= batch_size * num_batch 209 name, acc = train_metric.get() 210 name, val_acc = test(ctx, val_data) 211 train_history.update([acc, 1-val_acc]) 212 train_history.plot(save_path='%s/%s_history.png'%(plot_name, model_name)) 213 214 if val_acc > best_val_score: 215 best_val_score = val_acc 216 net.save_parameters('%s/%.4f-cifar-%s-%d-best.params'%(save_dir, best_val_score, model_name, epoch)) 217 218 name, val_acc = test(ctx, val_data) 219 logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' % 220 (epoch, acc, val_acc, train_loss, time.time()-tic)) 221 222 if save_period and save_dir and (epoch + 1) % save_period == 0: 223 net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epoch)) 224 225 if save_period and save_dir: 226 net.save_parameters('%s/cifar10-%s-%d.params'%(save_dir, model_name, epochs-1)) 227 228 if opt.mode == 'hybrid': 229 net.hybridize() 230 train(opt.num_epochs, context) 231 232if __name__ == '__main__': 233 main() 234