1from __future__ import division 2from __future__ import print_function 3 4import os 5# disable autotune 6os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0' 7import argparse 8import glob 9import logging 10logging.basicConfig(level=logging.INFO) 11import time 12import numpy as np 13import mxnet as mx 14from tqdm import tqdm 15from mxnet import nd 16from mxnet import gluon 17import gluoncv as gcv 18gcv.utils.check_version('0.6.0') 19from gluoncv import data as gdata 20from gluoncv.data import batchify 21from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultValTransform 22from gluoncv.utils.metrics.voc_detection import VOC07MApMetric 23from gluoncv.utils.metrics.coco_detection import COCODetectionMetric 24 25def parse_args(): 26 parser = argparse.ArgumentParser(description='Validate Faster-RCNN networks.') 27 parser.add_argument('--network', type=str, default='resnet50_v1b', 28 help="Base feature extraction network name") 29 parser.add_argument('--dataset', type=str, default='voc', 30 help='Training dataset.') 31 parser.add_argument('--num-workers', '-j', dest='num_workers', type=int, 32 default=4, help='Number of data workers') 33 parser.add_argument('--gpus', type=str, default='0', 34 help='Training with GPUs, you can specify 1,3 for example.') 35 parser.add_argument('--pretrained', type=str, default='True', 36 help='Load weights from previously saved parameters.') 37 parser.add_argument('--save-prefix', type=str, default='', 38 help='Saving parameter prefix') 39 parser.add_argument('--save-json', action='store_true', 40 help='Save coco output json') 41 parser.add_argument('--eval-all', action='store_true', 42 help='Eval all models begins with save prefix. Use with pretrained.') 43 parser.add_argument('--norm-layer', type=str, default=None, 44 help='Type of normalization layer to use. ' 45 'If set to None, backbone normalization layer will be fixed,' 46 ' and no normalization layer will be used. ' 47 'Currently supports \'bn\', and None, default is None') 48 parser.add_argument('--use-fpn', action='store_true', 49 help='Whether to use feature pyramid network.') 50 args = parser.parse_args() 51 return args 52 53def get_dataset(dataset, args): 54 if dataset.lower() == 'voc': 55 val_dataset = gdata.VOCDetection( 56 splits=[(2007, 'test')]) 57 val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes) 58 elif dataset.lower() == 'coco': 59 val_dataset = gdata.COCODetection(splits='instances_val2017', skip_empty=False) 60 val_metric = COCODetectionMetric(val_dataset, args.save_prefix + '_eval', 61 cleanup=not args.save_json) 62 else: 63 raise NotImplementedError('Dataset: {} not implemented.'.format(dataset)) 64 return val_dataset, val_metric 65 66def get_dataloader(net, val_dataset, batch_size, num_workers): 67 """Get dataloader.""" 68 val_bfn = batchify.Tuple(*[batchify.Append() for _ in range(3)]) 69 val_loader = mx.gluon.data.DataLoader( 70 val_dataset.transform(FasterRCNNDefaultValTransform(net.short, net.max_size)), 71 batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers) 72 return val_loader 73 74def split_and_load(batch, ctx_list): 75 """Split data to 1 batch each device.""" 76 num_ctx = len(ctx_list) 77 new_batch = [] 78 for i, data in enumerate(batch): 79 new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)] 80 new_batch.append(new_data) 81 return new_batch 82 83def validate(net, val_data, ctx, eval_metric, size): 84 """Test on validation dataset.""" 85 clipper = gcv.nn.bbox.BBoxClipToImage() 86 eval_metric.reset() 87 net.hybridize(static_alloc=True) 88 with tqdm(total=size) as pbar: 89 for ib, batch in enumerate(val_data): 90 batch = split_and_load(batch, ctx_list=ctx) 91 det_bboxes = [] 92 det_ids = [] 93 det_scores = [] 94 gt_bboxes = [] 95 gt_ids = [] 96 gt_difficults = [] 97 for x, y, im_scale in zip(*batch): 98 # get prediction results 99 ids, scores, bboxes = net(x) 100 det_ids.append(ids) 101 det_scores.append(scores) 102 # clip to image size 103 det_bboxes.append(clipper(bboxes, x)) 104 # rescale to original resolution 105 im_scale = im_scale.reshape((-1)).asscalar() 106 det_bboxes[-1] *= im_scale 107 # split ground truths 108 gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5)) 109 gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4)) 110 gt_bboxes[-1] *= im_scale 111 gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None) 112 # update metric 113 for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults): 114 eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff) 115 pbar.update(len(ctx)) 116 return eval_metric.get() 117 118if __name__ == '__main__': 119 args = parse_args() 120 121 # contexts 122 ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()] 123 ctx = ctx if ctx else [mx.cpu()] 124 args.batch_size = len(ctx) # 1 batch per device 125 126 # network 127 kwargs = {} 128 module_list = [] 129 if args.use_fpn: 130 module_list.append('fpn') 131 if args.norm_layer is not None: 132 module_list.append(args.norm_layer) 133 if args.norm_layer == 'bn': 134 kwargs['num_devices'] = len(args.gpus.split(',')) 135 net_name = '_'.join(('faster_rcnn', *module_list, args.network, args.dataset)) 136 args.save_prefix += net_name 137 if args.pretrained.lower() in ['true', '1', 'yes', 't']: 138 net = gcv.model_zoo.get_model(net_name, pretrained=True, **kwargs) 139 else: 140 net = gcv.model_zoo.get_model(net_name, pretrained=False, **kwargs) 141 net.load_parameters(args.pretrained.strip(), cast_dtype=True) 142 net.collect_params().reset_ctx(ctx) 143 144 # validation data 145 val_dataset, eval_metric = get_dataset(args.dataset, args) 146 val_data = get_dataloader( 147 net, val_dataset, args.batch_size, args.num_workers) 148 149 # validation 150 if not args.eval_all: 151 names, values = validate(net, val_data, ctx, eval_metric, len(val_dataset)) 152 for k, v in zip(names, values): 153 print(k, v) 154 else: 155 saved_models = glob.glob(args.save_prefix + '*.params') 156 for epoch, saved_model in enumerate(sorted(saved_models)): 157 print('[Epoch {}] Validating from {}'.format(epoch, saved_model)) 158 net.load_parameters(saved_model) 159 net.collect_params().reset_ctx(ctx) 160 map_name, mean_ap = validate(net, val_data, ctx, eval_metric, len(val_dataset)) 161 val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)]) 162 print('[Epoch {}] Validation: \n{}'.format(epoch, val_msg)) 163 current_map = float(mean_ap[-1]) 164 with open(args.save_prefix+'_best_map.log', 'a') as f: 165 f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map)) 166