"""Train a CIFAR-10 image classifier (ResNet / Wide-ResNet) with MXNet Gluon."""
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: plots are saved to disk, never shown

import argparse, time, logging

import numpy as np
import mxnet as mx

from mxnet import gluon, nd
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms

import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv.model_zoo import get_model
from gluoncv.utils import makedirs, TrainingHistory
from gluoncv.data import transforms as gcv_transforms

# CLI
def parse_args():
    """Parse and return command-line arguments for CIFAR-10 training."""
    parser = argparse.ArgumentParser(description='Train a model for image classification.')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='training batch size per device (CPU/GPU).')
    parser.add_argument('--num-gpus', type=int, default=0,
                        help='number of gpus to use.')
    parser.add_argument('--model', type=str, default='resnet',
                        help='model to use. options are resnet and wrn. default is resnet.')
    parser.add_argument('-j', '--num-data-workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--num-epochs', type=int, default=3,
                        help='number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate. default is 0.1.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='momentum value for optimizer, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='weight decay rate. default is 0.0001.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    # NOTE(review): --lr-decay-period is accepted but never read below; only
    # --lr-decay-epoch drives the schedule. Kept for CLI compatibility.
    parser.add_argument('--lr-decay-period', type=int, default=0,
                        help='period in epoch for learning rate decays. default is 0 (has no effect).')
    parser.add_argument('--lr-decay-epoch', type=str, default='40,60',
                        help='epochs at which learning rate decays. default is 40,60.')
    parser.add_argument('--drop-rate', type=float, default=0.0,
                        help='dropout rate for wide resnet. default is 0.')
    parser.add_argument('--mode', type=str,
                        help='mode in which to train the model. options are imperative, hybrid')
    parser.add_argument('--save-period', type=int, default=10,
                        help='period in epoch of model saving.')
    parser.add_argument('--save-dir', type=str, default='params',
                        help='directory of saved models')
    parser.add_argument('--resume-from', type=str,
                        help='resume training from the model')
    parser.add_argument('--save-plot-dir', type=str, default='.',
                        help='the path to save the history plot')
    opt = parser.parse_args()
    return opt


def main():
    """Build the model, data pipeline and optimizer, then run the training loop."""
    opt = parse_args()

    batch_size = opt.batch_size
    classes = 10  # CIFAR-10

    num_gpus = opt.num_gpus
    # Global batch size = per-device batch size * number of devices.
    batch_size *= max(1, num_gpus)
    context = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
    num_workers = opt.num_workers

    lr_decay = opt.lr_decay
    # Trailing inf sentinel keeps the epoch comparison below from running
    # off the end of the list after the last decay point.
    lr_decay_epoch = [int(i) for i in opt.lr_decay_epoch.split(',')] + [np.inf]

    model_name = opt.model
    if model_name.startswith('cifar_wideresnet'):
        # Only the wide-resnet family accepts a dropout rate.
        kwargs = {'classes': classes,
                  'drop_rate': opt.drop_rate}
    else:
        kwargs = {'classes': classes}
    net = get_model(model_name, **kwargs)
    if opt.resume_from:
        net.load_parameters(opt.resume_from, ctx=context)
    optimizer = 'nag'  # Nesterov accelerated gradient

    save_period = opt.save_period
    if opt.save_dir and save_period:
        save_dir = opt.save_dir
        makedirs(save_dir)
    else:
        # Checkpointing disabled when either knob is unset.
        save_dir = ''
        save_period = 0

    plot_path = opt.save_plot_dir

    logging.basicConfig(level=logging.INFO)
    logging.info(opt)

    # Standard CIFAR-10 augmentation and channel normalization statistics.
    transform_train = transforms.Compose([
        gcv_transforms.RandomCrop(32, pad=4),
        transforms.RandomFlipLeftRight(),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    def test(ctx, val_data):
        """Evaluate ``net`` on ``val_data``; return the (name, accuracy) pair."""
        metric = mx.metric.Accuracy()
        for i, batch in enumerate(val_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)
            outputs = [net(X) for X in data]
            metric.update(label, outputs)
        return metric.get()

    def train(epochs, ctx):
        """Run the training loop for ``epochs`` epochs on the device list ``ctx``."""
        if isinstance(ctx, mx.Context):
            ctx = [ctx]
        # No-op for parameters already loaded via --resume-from
        # (force_reinit defaults to False).
        net.initialize(mx.init.Xavier(), ctx=ctx)

        train_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=True).transform_first(transform_train),
            batch_size=batch_size, shuffle=True, last_batch='discard', num_workers=num_workers)

        val_data = gluon.data.DataLoader(
            gluon.data.vision.CIFAR10(train=False).transform_first(transform_test),
            batch_size=batch_size, shuffle=False, num_workers=num_workers)

        trainer = gluon.Trainer(net.collect_params(), optimizer,
                                {'learning_rate': opt.lr, 'wd': opt.wd, 'momentum': opt.momentum})
        train_metric = mx.metric.Accuracy()
        loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
        train_history = TrainingHistory(['training-error', 'validation-error'])

        iteration = 0
        lr_decay_count = 0

        best_val_score = 0

        for epoch in range(epochs):
            tic = time.time()
            train_metric.reset()
            train_loss = 0
            num_batch = len(train_data)

            # Step the learning-rate schedule at the configured decay epochs.
            if epoch == lr_decay_epoch[lr_decay_count]:
                trainer.set_learning_rate(trainer.learning_rate * lr_decay)
                lr_decay_count += 1

            for i, batch in enumerate(train_data):
                data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0)
                label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0)

                with ag.record():
                    output = [net(X) for X in data]
                    loss = [loss_fn(yhat, y) for yhat, y in zip(output, label)]
                for l in loss:
                    l.backward()
                trainer.step(batch_size)
                train_loss += sum([l.sum().asscalar() for l in loss])

                train_metric.update(label, output)
                iteration += 1

            # Mean per-sample loss over the epoch.
            train_loss /= batch_size * num_batch
            name, acc = train_metric.get()
            name, val_acc = test(ctx, val_data)
            train_history.update([1 - acc, 1 - val_acc])
            train_history.plot(save_path='%s/%s_history.png' % (plot_path, model_name))

            if val_acc > best_val_score:
                best_val_score = val_acc
                # Only write the best-model checkpoint when a save directory is
                # configured; with save_dir == '' the path would be rooted at /.
                if save_dir:
                    net.save_parameters('%s/%.4f-cifar-%s-%d-best.params' %
                                        (save_dir, best_val_score, model_name, epoch))

            logging.info('[Epoch %d] train=%f val=%f loss=%f time: %f' %
                         (epoch, acc, val_acc, train_loss, time.time() - tic))

            if save_period and save_dir and (epoch + 1) % save_period == 0:
                net.save_parameters('%s/cifar10-%s-%d.params' % (save_dir, model_name, epoch))

        if save_period and save_dir:
            net.save_parameters('%s/cifar10-%s-%d.params' % (save_dir, model_name, epochs - 1))

    if opt.mode == 'hybrid':
        # Compile the network into a static graph for faster execution.
        net.hybridize()
    train(opt.num_epochs, context)


if __name__ == '__main__':
    main()