from __future__ import division

import os

# disable cuDNN autotune; set before mxnet is imported below
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
import argparse
import glob
import logging

logging.basicConfig(level=logging.INFO)
import numpy as np
import mxnet as mx
from tqdm import tqdm
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
from gluoncv.data import batchify
from gluoncv.data.transforms.presets.rcnn import MaskRCNNDefaultValTransform
from gluoncv.utils.metrics.coco_instance import COCOInstanceMetric


def parse_args():
    """Parse command-line options for Mask R-CNN validation.

    Returns
    -------
    argparse.Namespace
        The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description='Validate Mask RCNN networks.')
    parser.add_argument('--network', type=str, default='resnet50_v1b',
                        help="Base feature extraction network name")
    parser.add_argument('--dataset', type=str, default='coco',
                        help='Training dataset.')
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data workers')
    parser.add_argument('--gpus', type=str, default='0',
                        help='Training with GPUs, you can specify 1,3 for example.')
    parser.add_argument('--pretrained', type=str, default='True',
                        help='Load weights from previously saved parameters.')
    parser.add_argument('--save-prefix', type=str, default='',
                        help='Saving parameter prefix')
    parser.add_argument('--save-json', action='store_true',
                        help='Save coco output json')
    parser.add_argument('--eval-all', action='store_true',
                        help='Eval all models begins with save prefix. Use with pretrained.')
    parser.add_argument('--use-fpn', action='store_true',
                        help='Whether to load model with feature pyramid network.')
    args = parser.parse_args()
    return args


def get_dataset(dataset, args):
    """Build the validation dataset and its instance-segmentation metric.

    Parameters
    ----------
    dataset : str
        Dataset name; only ``'coco'`` (case-insensitive) is supported.
    args : argparse.Namespace
        Reads ``save_prefix``, which is used as the metric output prefix.
        Reads ``save_json``; when it is set, the metric is built with
        ``cleanup=False``.

    Returns
    -------
    tuple
        ``(val_dataset, val_metric)``.

    Raises
    ------
    NotImplementedError
        If ``dataset`` is not ``'coco'``.
    """
    if dataset.lower() == 'coco':
        val_dataset = gdata.COCOInstance(splits='instances_val2017', skip_empty=False)
        val_metric = COCOInstanceMetric(val_dataset, args.save_prefix + '_eval',
                                        cleanup=not args.save_json)
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    return val_dataset, val_metric


def get_dataloader(net, val_dataset, batch_size, num_workers):
    """Build the (unshuffled) validation DataLoader.

    Samples are transformed with ``MaskRCNNDefaultValTransform`` using the
    network's ``short`` and ``max_size`` attributes. Both output fields go
    through ``batchify.Append``, so each field is a sequence of per-sample
    arrays that ``split_and_load`` can distribute across devices.
    ``last_batch='keep'`` retains a final, possibly smaller, batch.
    """
    val_bfn = batchify.Tuple(*[batchify.Append() for _ in range(2)])
    val_loader = mx.gluon.data.DataLoader(
        val_dataset.transform(MaskRCNNDefaultValTransform(net.short, net.max_size)),
        batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
    return val_loader


def split_and_load(batch, ctx_list):
    """Move each sample of every batch field onto its own device.

    Parameters
    ----------
    batch : sequence
        Batch fields; each field is a sequence of per-sample arrays.
    ctx_list : list of mx.Context
        Target devices; sample ``k`` of each field goes to ``ctx_list[k]``.

    Returns
    -------
    list of list
        Same field layout as ``batch``. Because ``zip`` pairs samples with
        devices, samples beyond ``len(ctx_list)`` are dropped.
    """
    return [[x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
            for data in batch]


def validate(net, val_data, ctx, eval_metric, size):
    """Run ``net`` over ``val_data`` and return the accumulated metric.

    Parameters
    ----------
    net : mx.gluon.HybridBlock
        Mask R-CNN network whose forward pass returns
        ``(ids, scores, bboxes, masks)``. It is hybridized here.
    val_data : iterable
        Loader yielding ``(images, im_infos)`` batches, as built by
        ``get_dataloader``.
    ctx : list of mx.Context
        Devices; one loader sample is placed on each device.
    eval_metric : COCOInstanceMetric
        Reset at the start, then updated once per image.
    size : int
        Total number of images; used only for the progress bar.

    Returns
    -------
    tuple
        Whatever ``eval_metric.get()`` returns (metric names and values).
    """
    clipper = gcv.nn.bbox.BBoxClipToImage()
    eval_metric.reset()
    net.hybridize(static_alloc=True)
    with tqdm(total=size) as pbar:
        for batch in val_data:
            batch = split_and_load(batch, ctx_list=ctx)
            det_bboxes = []
            det_ids = []
            det_scores = []
            det_masks = []
            det_infos = []
            for x, im_info in zip(*batch):
                # get prediction results
                ids, scores, bboxes, masks = net(x)
                det_bboxes.append(clipper(bboxes, x))
                det_ids.append(ids)
                det_scores.append(scores)
                det_masks.append(masks)
                det_infos.append(im_info)
            # update metric
            num_images = 0
            for det_bbox, det_id, det_score, det_mask, det_info in \
                    zip(det_bboxes, det_ids, det_scores, det_masks, det_infos):
                for i in range(det_info.shape[0]):
                    # Bind per-image numpy results to fresh names. Rebinding
                    # det_bbox/det_id/... here (as before) would make the next
                    # iteration index an already-sliced array when i > 0.
                    bbox = det_bbox[i].asnumpy()
                    cls_id = det_id[i].asnumpy()
                    score = det_score[i].asnumpy()
                    mask = det_mask[i].asnumpy()
                    info = det_info[i].asnumpy()
                    # filter by conf threshold (negative ids are discarded)
                    im_height, im_width, im_scale = info
                    valid = np.where((cls_id >= 0) & (score >= 0.001))[0]
                    cls_id = cls_id[valid]
                    score = score[valid]
                    # undo the resize so boxes are in original-image coordinates
                    bbox = bbox[valid] / im_scale
                    mask = mask[valid]
                    # fill full mask at the original image size
                    orig_height = int(round(im_height / im_scale))
                    orig_width = int(round(im_width / im_scale))
                    full_masks = gcv.data.transforms.mask.fill(
                        mask, bbox, (orig_width, orig_height), fast_fill=False)
                    eval_metric.update(bbox, cls_id, score, full_masks)
                    num_images += 1
            # Advance by the images actually evaluated. The last batch may be
            # smaller than len(ctx) because the loader uses last_batch='keep'.
            pbar.update(num_images)
    return eval_metric.get()


if __name__ == '__main__':
    args = parse_args()

    # evaluation contexts; fall back to CPU when no GPU ids are given
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]
    args.batch_size = len(ctx)  # 1 batch per device

    # network
    module_list = []
    if args.use_fpn:
        module_list.append('fpn')
    net_name = '_'.join(('mask_rcnn', *module_list, args.network, args.dataset))
    args.save_prefix += net_name
    if args.pretrained.lower() in ['true', '1', 'yes', 't']:
        net = gcv.model_zoo.get_model(net_name, pretrained=True)
    else:
        # any other --pretrained value is treated as a parameter file path
        net = gcv.model_zoo.get_model(net_name, pretrained=False)
        net.load_parameters(args.pretrained.strip(), cast_dtype=True)
    net.collect_params().reset_ctx(ctx)

    # validation data
    val_dataset, eval_metric = get_dataset(args.dataset, args)
    val_data = get_dataloader(
        net, val_dataset, args.batch_size, args.num_workers)

    # validation
    if not args.eval_all:
        names, values = validate(net, val_data, ctx, eval_metric, len(val_dataset))
        for k, v in zip(names, values):
            print(k, v)
    else:
        # evaluate every saved checkpoint matching the prefix; "epoch" is the
        # index in sorted filename order, not parsed from the filename
        saved_models = glob.glob(args.save_prefix + '*.params')
        for epoch, saved_model in enumerate(sorted(saved_models)):
            print('[Epoch {}] Validating from {}'.format(epoch, saved_model))
            net.load_parameters(saved_model)
            net.collect_params().reset_ctx(ctx)
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric, len(val_dataset))
            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            print('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
            with open(args.save_prefix + '_best_map.log', 'a') as f:
                f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map))