"""Train CenterNet"""
import argparse
import os
import logging
import warnings
import time
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet import autograd
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.model_zoo import get_model
from gluoncv.data.batchify import Tuple, Stack, Pad
from gluoncv.data.transforms.presets.center_net import CenterNetDefaultTrainTransform
from gluoncv.data.transforms.presets.center_net import CenterNetDefaultValTransform, get_post_transform

from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
from gluoncv.utils.metrics.accuracy import Accuracy
from gluoncv.utils import LRScheduler, LRSequential


def parse_args():
    """Parse the command-line options of the training script.

    Returns
    -------
    argparse.Namespace
        The parsed options; see the individual ``help`` strings for meaning
        and defaults.
    """
    parser = argparse.ArgumentParser(description='Train CenterNet networks.')
    parser.add_argument('--network', type=str, default='resnet18_v1b',
                        help="Base network name which serves as feature extraction base.")
    parser.add_argument('--data-shape', type=int, default=512,
                        help="Input data shape, use 300, 512.")
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training mini-batch size')
    parser.add_argument('--dataset', type=str, default='voc',
                        help='Training dataset. Now support voc.')
    parser.add_argument('--dataset-root', type=str, default='~/.mxnet/datasets/',
                        help='Path of the directory where the dataset is located.')
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data workers, you can use larger '
                        'number to accelerate data loading, if you CPU and GPUs are powerful.')
    parser.add_argument('--gpus', type=str, default='0',
                        help='Training with GPUs, you can specify 1,3 for example.')
    parser.add_argument('--epochs', type=int, default=140,
                        help='Training epochs.')
    parser.add_argument('--resume', type=str, default='',
                        help='Resume from previously saved parameters if not None. '
                        'For example, you can resume from ./ssd_xxx_0123.params')
    parser.add_argument('--start-epoch', type=int, default=0,
                        help='Starting epoch for resuming, default is 0 for new training.'
                        'You can specify it to 100 for example to start from 100 epoch.')
    parser.add_argument('--lr', type=float, default=1.25e-4,
                        help='Learning rate, default is 0.000125')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='decay rate of learning rate. default is 0.1.')
    parser.add_argument('--lr-decay-epoch', type=str, default='90,120',
                        help='epochs at which learning rate decays. default is 90,120.')
    parser.add_argument('--lr-mode', type=str, default='step',
                        help='learning rate scheduler mode. options are step, poly and cosine.')
    parser.add_argument('--warmup-lr', type=float, default=0.0,
                        help='starting warmup learning rate. default is 0.0.')
    parser.add_argument('--warmup-epochs', type=int, default=0,
                        help='number of warmup epochs.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum, default is 0.9')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='Weight decay, default is 1e-4')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Logging mini-batch interval. Default is 100.')
    parser.add_argument('--num-samples', type=int, default=-1,
                        help='Training images. Use -1 to automatically get the number.')
    parser.add_argument('--save-prefix', type=str, default='',
                        help='Saving parameter prefix')
    parser.add_argument('--save-interval', type=int, default=10,
                        help='Saving parameters epoch interval, best model will always be saved.')
    parser.add_argument('--val-interval', type=int, default=1,
                        help='Epoch interval for validation, increase the number will reduce the '
                        'training time if validation is slow.')
    parser.add_argument('--seed', type=int, default=233,
                        help='Random seed to be fixed.')
    parser.add_argument('--wh-weight', type=float, default=0.1,
                        help='Loss weight for width/height')
    parser.add_argument('--center-reg-weight', type=float, default=1.0,
                        help='Center regression loss weight')
    parser.add_argument('--flip-validation', action='store_true',
                        help='flip data augmentation in validation.')

    args = parser.parse_args()
    return args


def get_dataset(dataset, args):
    """Build the train/val datasets and the matching evaluation metric.

    Parameters
    ----------
    dataset : str
        Dataset name, case-insensitive: ``'voc'`` or ``'coco'``.
    args : argparse.Namespace
        Parsed options. Reads ``dataset_root``, ``save_prefix``,
        ``data_shape``, ``val_interval`` and ``num_samples``.
        NOTE: this function mutates ``args``: ``val_interval`` is raised from
        1 to 10 for COCO, and a negative ``num_samples`` is replaced by the
        training-set length.

    Returns
    -------
    tuple
        ``(train_dataset, val_dataset, val_metric)``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is neither ``'voc'`` nor ``'coco'``.
    """
    name = dataset.lower()
    if name == 'voc':
        # Honour --dataset-root for VOC as well; with the default root this
        # resolves to the same ~/.mxnet/datasets/voc that VOCDetection uses.
        voc_root = os.path.join(args.dataset_root, 'voc')
        train_dataset = gdata.VOCDetection(
            root=voc_root, splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            root=voc_root, splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif name == 'coco':
        coco_root = os.path.join(args.dataset_root, 'coco')
        train_dataset = gdata.COCODetection(root=coco_root, splits='instances_train2017')
        val_dataset = gdata.COCODetection(
            root=coco_root, splits='instances_val2017', skip_empty=False)
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + '_eval', cleanup=True,
            data_shape=(args.data_shape, args.data_shape), post_affine=get_post_transform)
        # coco validation is slow, consider increase the validation interval
        if args.val_interval == 1:
            args.val_interval = 10
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    if args.num_samples < 0:
        args.num_samples = len(train_dataset)
    return train_dataset, val_dataset, val_metric


def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader.

    Parameters
    ----------
    net : network
        Only ``net.scale`` is read; it is forwarded to the training transform
        as ``scale_factor``.
    train_dataset, val_dataset : dataset
        Datasets to wrap; ``train_dataset.classes`` gives the class count.
    data_shape : int
        Used as both width and height of the network input.
    batch_size : int
        Mini-batch size for both loaders.
    num_workers : int
        Number of data-loading workers for both loaders.
    ctx : context
        Unused here; kept so the call signature stays unchanged.

    Returns
    -------
    tuple
        ``(train_loader, val_loader)``. The training loader shuffles and
        rolls an incomplete last batch over; the validation loader keeps it.
    """
    width, height = data_shape, data_shape
    num_class = len(train_dataset.classes)
    # The training transform yields six arrays per sample; the training loop
    # unpacks them as image, heatmap target, wh target/mask and
    # center-regression target/mask. All of them are stacked.
    batchify_fn = Tuple([Stack() for _ in range(6)])
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(CenterNetDefaultTrainTransform(
            width, height, num_class=num_class, scale_factor=net.scale)),
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    # Validation labels are padded with -1 so they can be batched together.
    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(CenterNetDefaultValTransform(width, height)),
        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader


def save_params(net, best_map, current_map, epoch, save_interval, prefix):
    """Save the best-so-far parameters and periodic checkpoints.

    Parameters
    ----------
    net : network
        Network whose ``save_parameters`` is called.
    best_map : list
        One-element list holding the best score so far; updated in place.
    current_map : float
        Score of the current epoch.
    epoch : int
        Current epoch index, used in file names and the log line.
    save_interval : int
        Write a checkpoint whenever ``epoch % save_interval == 0``; a falsy
        value disables periodic checkpoints.
    prefix : str
        Prefix of every file written.
    """
    current_map = float(current_map)
    if current_map > best_map[0]:
        best_map[0] = current_map
        net.save_parameters('{:s}_best.params'.format(prefix))
        with open(prefix+'_best_map.log', 'a') as f:
            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
    if save_interval and epoch % save_interval == 0:
        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))


def validate(net, val_data, ctx, eval_metric, flip_test=False):
    """Test on validation dataset.

    Parameters
    ----------
    net : network
        Called as ``ids, scores, bboxes = net(x)``; its ``flip_test``
        attribute is set from ``flip_test`` and it is hybridized.
    val_data : iterable
        Yields ``(images, labels)`` batches.
    ctx : list
        Contexts the batches are split across.
    eval_metric : metric
        Reset on entry and updated once per batch.
    flip_test : bool
        Value assigned to ``net.flip_test``.

    Returns
    -------
    The result of ``eval_metric.get()``.
    """
    eval_metric.reset()
    net.flip_test = flip_test
    mx.nd.waitall()
    net.hybridize()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
        det_bboxes = []
        det_ids = []
        det_scores = []
        gt_bboxes = []
        gt_ids = []
        gt_difficults = []
        for x, y in zip(data, label):
            # get prediction results
            ids, scores, bboxes = net(x)
            det_ids.append(ids)
            det_scores.append(scores)
            # clip to image size
            det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
            # split ground truths: columns 0-3 are the box, 4 the class id and
            # an optional column 5 the "difficult" flag
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)

        # update metric
        eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
    return eval_metric.get()


def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline

    Moves the parameters to ``ctx``, builds the learning-rate schedule
    (linear warmup followed by ``args.lr_mode``), then for every epoch in
    ``[args.start_epoch, args.epochs)`` runs one pass over ``train_data``,
    logs the averaged losses, optionally validates and saves parameters.

    Parameters
    ----------
    net : network
        Called as ``heatmap_pred, wh_pred, center_reg_pred = net(x)``.
    train_data : iterable
        Yields batches of six arrays (image, heatmap target, wh target,
        wh mask, center-reg target, center-reg mask).
    val_data : iterable
        Passed through to ``validate``.
    eval_metric : metric
        Passed through to ``validate``.
    ctx : list
        Contexts to train on.
    args : argparse.Namespace
        Parsed options from ``parse_args``.
    """
    net.collect_params().reset_ctx(ctx)
    # lr decay policy
    lr_steps = sorted([int(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    # The main scheduler starts counting after warmup, so shift the steps.
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_steps]
    num_batches = args.num_samples // args.batch_size
    lr_scheduler = LRSequential([
        # Start warmup from --warmup-lr (default 0.0) instead of a fixed 0.
        LRScheduler('linear', base_lr=args.warmup_lr, target_lr=args.lr,
                    nepochs=args.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode, base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay, power=2),
    ])

    # no weight decay on bias parameters
    for param in net.collect_params('.*bias').values():
        param.wd_mult = 0.0
    trainer = gluon.Trainer(
        net.collect_params(), 'adam',
        {'learning_rate': args.lr, 'wd': args.wd,
         'lr_scheduler': lr_scheduler})

    heatmap_loss = gcv.loss.HeatmapFocalLoss(from_logits=True)
    wh_loss = gcv.loss.MaskedL1Loss(weight=args.wh_weight)
    center_reg_loss = gcv.loss.MaskedL1Loss(weight=args.center_reg_weight)
    heatmap_loss_metric = mx.metric.Loss('HeatmapFocal')
    wh_metric = mx.metric.Loss('WHL1')
    center_reg_metric = mx.metric.Loss('CenterRegL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]
    batch_size = args.batch_size

    for epoch in range(args.start_epoch, args.epochs):
        # All three loss metrics are reported per epoch, so clear all of them;
        # the heatmap metric used to be left accumulating across epochs.
        wh_metric.reset()
        center_reg_metric.reset()
        heatmap_loss_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize()

        for i, batch in enumerate(train_data):
            split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=ctx, batch_axis=0) for ind in range(6)]
            with autograd.record():
                sum_losses = []
                heatmap_losses = []
                wh_losses = []
                center_reg_losses = []
                # one iteration per device slice
                for x, heatmap_target, wh_target, wh_mask, center_reg_target, center_reg_mask in zip(*split_data):
                    heatmap_pred, wh_pred, center_reg_pred = net(x)
                    wh_losses.append(wh_loss(wh_pred, wh_target, wh_mask))
                    center_reg_losses.append(center_reg_loss(center_reg_pred, center_reg_target, center_reg_mask))
                    heatmap_losses.append(heatmap_loss(heatmap_pred, heatmap_target))
                    curr_loss = heatmap_losses[-1]+ wh_losses[-1] + center_reg_losses[-1]
                    sum_losses.append(curr_loss)
                autograd.backward(sum_losses)
            trainer.step(len(sum_losses))  # step with # gpus

            heatmap_loss_metric.update(0, heatmap_losses)
            wh_metric.update(0, wh_losses)
            center_reg_metric.update(0, center_reg_losses)
            if args.log_interval and not (i + 1) % args.log_interval:
                name2, loss2 = wh_metric.get()
                name3, loss3 = center_reg_metric.get()
                name4, loss4 = heatmap_loss_metric.get()
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, LR={}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                    epoch, i, batch_size/(time.time()-btic), trainer.learning_rate, name2, loss2, name3, loss3, name4, loss4))
            # Restart the timer every batch so the logged speed covers exactly
            # one batch of `batch_size` samples.
            btic = time.time()

        name2, loss2 = wh_metric.get()
        name3, loss3 = center_reg_metric.get()
        name4, loss4 = heatmap_loss_metric.get()
        logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
            epoch, (time.time()-tic), name2, loss2, name3, loss3, name4, loss4))
        # Validate on the validation interval, on every checkpoint epoch (the
        # checkpoint file name embeds the score) and on the final epoch.
        # A val_interval of 0 disables interval-based validation instead of
        # raising ZeroDivisionError.
        on_val_interval = bool(args.val_interval) and epoch % args.val_interval == 0
        on_save_interval = bool(args.save_interval) and epoch % args.save_interval == 0
        if on_val_interval or on_save_interval or (epoch == args.epochs - 1):
            # consider reduce the frequency of validation to save time
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric, flip_test=args.flip_validation)
            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)


if __name__ == '__main__':
    args = parse_args()

    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]

    # network
    net_name = '_'.join(('center_net', args.network, args.dataset))
    args.save_prefix += net_name
    net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm)
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
    else:
        # record=True captures the warnings raised during initialization
        # instead of printing them
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("always")
            net.initialize()
            # needed for net to be first gpu when using AMP
            net.collect_params().reset_ctx(ctx[0])

    # training data
    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    batch_size = args.batch_size
    train_data, val_data = get_dataloader(
        net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, ctx[0])

    # training
    train(net, train_data, val_data, eval_metric, ctx, args)