1from __future__ import division
2
3import os
4
5# disable autotune
6os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
7import argparse
8import glob
9import logging
10
11logging.basicConfig(level=logging.INFO)
12import numpy as np
13import mxnet as mx
14from tqdm import tqdm
15import gluoncv as gcv
16gcv.utils.check_version('0.6.0')
17from gluoncv import data as gdata
18from gluoncv.data import batchify
19from gluoncv.data.transforms.presets.rcnn import MaskRCNNDefaultValTransform
20from gluoncv.utils.metrics.coco_instance import COCOInstanceMetric
21
22
23def parse_args():
24    parser = argparse.ArgumentParser(description='Validate Mask RCNN networks.')
25    parser.add_argument('--network', type=str, default='resnet50_v1b',
26                        help="Base feature extraction network name")
27    parser.add_argument('--dataset', type=str, default='coco',
28                        help='Training dataset.')
29    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
30                        default=4, help='Number of data workers')
31    parser.add_argument('--gpus', type=str, default='0',
32                        help='Training with GPUs, you can specify 1,3 for example.')
33    parser.add_argument('--pretrained', type=str, default='True',
34                        help='Load weights from previously saved parameters.')
35    parser.add_argument('--save-prefix', type=str, default='',
36                        help='Saving parameter prefix')
37    parser.add_argument('--save-json', action='store_true',
38                        help='Save coco output json')
39    parser.add_argument('--eval-all', action='store_true',
40                        help='Eval all models begins with save prefix. Use with pretrained.')
41    parser.add_argument('--use-fpn', action='store_true',
42                        help='Whether to load model with feature pyramid network.')
43    args = parser.parse_args()
44    return args
45
46
47def get_dataset(dataset, args):
48    if dataset.lower() == 'coco':
49        val_dataset = gdata.COCOInstance(splits='instances_val2017', skip_empty=False)
50        val_metric = COCOInstanceMetric(val_dataset, args.save_prefix + '_eval',
51                                        cleanup=not args.save_json)
52    else:
53        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
54    return val_dataset, val_metric
55
56
57def get_dataloader(net, val_dataset, batch_size, num_workers):
58    """Get dataloader."""
59    val_bfn = batchify.Tuple(*[batchify.Append() for _ in range(2)])
60    val_loader = mx.gluon.data.DataLoader(
61        val_dataset.transform(MaskRCNNDefaultValTransform(net.short, net.max_size)),
62        batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
63    return val_loader
64
65
66def split_and_load(batch, ctx_list):
67    """Split data to 1 batch each device."""
68    num_ctx = len(ctx_list)
69    new_batch = []
70    for i, data in enumerate(batch):
71        new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
72        new_batch.append(new_data)
73    return new_batch
74
75
76def validate(net, val_data, ctx, eval_metric, size):
77    """Test on validation dataset."""
78    clipper = gcv.nn.bbox.BBoxClipToImage()
79    eval_metric.reset()
80    net.hybridize(static_alloc=True)
81    with tqdm(total=size) as pbar:
82        for ib, batch in enumerate(val_data):
83            batch = split_and_load(batch, ctx_list=ctx)
84            det_bboxes = []
85            det_ids = []
86            det_scores = []
87            det_masks = []
88            det_infos = []
89            for x, im_info in zip(*batch):
90                # get prediction results
91                ids, scores, bboxes, masks = net(x)
92                det_bboxes.append(clipper(bboxes, x))
93                det_ids.append(ids)
94                det_scores.append(scores)
95                det_masks.append(masks)
96                det_infos.append(im_info)
97            # update metric
98            for det_bbox, det_id, det_score, det_mask, det_info in \
99                    zip(det_bboxes, det_ids, det_scores, det_masks, det_infos):
100                for i in range(det_info.shape[0]):
101                    # numpy everything
102                    det_bbox = det_bbox[i].asnumpy()
103                    det_id = det_id[i].asnumpy()
104                    det_score = det_score[i].asnumpy()
105                    det_mask = det_mask[i].asnumpy()
106                    det_info = det_info[i].asnumpy()
107                    # filter by conf threshold
108                    im_height, im_width, im_scale = det_info
109                    valid = np.where(((det_id >= 0) & (det_score >= 0.001)))[0]
110                    det_id = det_id[valid]
111                    det_score = det_score[valid]
112                    det_bbox = det_bbox[valid] / im_scale
113                    det_mask = det_mask[valid]
114                    # fill full mask
115                    im_height, im_width = int(round(im_height / im_scale)), int(
116                        round(im_width / im_scale))
117                    full_masks = gcv.data.transforms.mask.fill(det_mask, det_bbox,
118                                                               (im_width, im_height),
119                                                               fast_fill=False)
120                    eval_metric.update(det_bbox, det_id, det_score, full_masks)
121            pbar.update(len(ctx))
122    return eval_metric.get()
123
124
125if __name__ == '__main__':
126    args = parse_args()
127
128    # training contexts
129    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
130    ctx = ctx if ctx else [mx.cpu()]
131    args.batch_size = len(ctx)  # 1 batch per device
132
133    # network
134    module_list = []
135    if args.use_fpn:
136        module_list.append('fpn')
137    net_name = '_'.join(('mask_rcnn', *module_list, args.network, args.dataset))
138    args.save_prefix += net_name
139    if args.pretrained.lower() in ['true', '1', 'yes', 't']:
140        net = gcv.model_zoo.get_model(net_name, pretrained=True)
141    else:
142        net = gcv.model_zoo.get_model(net_name, pretrained=False)
143        net.load_parameters(args.pretrained.strip(), cast_dtype=True)
144    net.collect_params().reset_ctx(ctx)
145
146    # training data
147    val_dataset, eval_metric = get_dataset(args.dataset, args)
148    val_data = get_dataloader(
149        net, val_dataset, args.batch_size, args.num_workers)
150
151    # validation
152    if not args.eval_all:
153        names, values = validate(net, val_data, ctx, eval_metric, len(val_dataset))
154        for k, v in zip(names, values):
155            print(k, v)
156    else:
157        saved_models = glob.glob(args.save_prefix + '*.params')
158        for epoch, saved_model in enumerate(sorted(saved_models)):
159            print('[Epoch {}] Validating from {}'.format(epoch, saved_model))
160            net.load_parameters(saved_model)
161            net.collect_params().reset_ctx(ctx)
162            map_name, mean_ap = validate(net, val_data, ctx, eval_metric, len(val_dataset))
163            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
164            print('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
165            current_map = float(mean_ap[-1])
166            with open(args.save_prefix + '_best_map.log', 'a') as f:
167                f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map))
168