"""Validate Faster-RCNN networks on Pascal VOC or COCO."""
from __future__ import division
from __future__ import print_function

import os
# disable autotune
os.environ['MXNET_CUDNN_AUTOTUNE_DEFAULT'] = '0'
import argparse
import glob
import logging
logging.basicConfig(level=logging.INFO)
import time
import numpy as np
import mxnet as mx
from tqdm import tqdm
from mxnet import nd
from mxnet import gluon
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
from gluoncv.data import batchify
from gluoncv.data.transforms.presets.rcnn import FasterRCNNDefaultValTransform
from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric

def parse_args():
    parser = argparse.ArgumentParser(description='Validate Faster-RCNN networks.')
    parser.add_argument('--network', type=str, default='resnet50_v1b',
                        help='Base feature extraction network name.')
    parser.add_argument('--dataset', type=str, default='voc',
                        help='Dataset to evaluate on.')
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data workers.')
    parser.add_argument('--gpus', type=str, default='0',
                        help='GPU ids to run on, e.g. "1,3". Leave empty to use CPU.')
    parser.add_argument('--pretrained', type=str, default='True',
                        help='Set to "True" to load pretrained model-zoo weights, '
                             'or pass a path to a saved .params file.')
    parser.add_argument('--save-prefix', type=str, default='',
                        help='Saving parameter prefix.')
    parser.add_argument('--save-json', action='store_true',
                        help='Save COCO output json.')
    parser.add_argument('--eval-all', action='store_true',
                        help='Evaluate all saved models whose filenames begin with '
                             'the save prefix. Use together with --save-prefix.')
    parser.add_argument('--norm-layer', type=str, default=None,
                        help='Type of normalization layer to use. '
                             'If None, the backbone normalization layers are frozen '
                             'and no extra normalization layer is added. '
                             'Currently supports "bn" and None; default is None.')
    parser.add_argument('--use-fpn', action='store_true',
                        help='Whether to use a feature pyramid network.')
    args = parser.parse_args()
    return args

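# Example invocations (script filename assumed; adjust to your checkout):
#   python eval_faster_rcnn.py --dataset voc --gpus 0
#   python eval_faster_rcnn.py --dataset coco --gpus 0,1 --use-fpn --save-json
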
def get_dataset(dataset, args):
    """Return the validation dataset and its evaluation metric."""
    if dataset.lower() == 'voc':
        val_dataset = gdata.VOCDetection(
            splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif dataset.lower() == 'coco':
        val_dataset = gdata.COCODetection(splits='instances_val2017', skip_empty=False)
        val_metric = COCODetectionMetric(val_dataset, args.save_prefix + '_eval',
                                         cleanup=not args.save_json)
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    return val_dataset, val_metric

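# NOTE on get_dataset: VOC is scored with the 11-point VOC07 mAP at IoU 0.5,
# while COCO evaluation writes detections to a JSON file next to the save
# prefix and keeps it only when --save-json is given (cleanup=not save_json).
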
def get_dataloader(net, val_dataset, batch_size, num_workers):
    """Get dataloader."""
    val_bfn = batchify.Tuple(*[batchify.Append() for _ in range(3)])
    val_loader = mx.gluon.data.DataLoader(
        val_dataset.transform(FasterRCNNDefaultValTransform(net.short, net.max_size)),
        batch_size, False, batchify_fn=val_bfn, last_batch='keep', num_workers=num_workers)
    return val_loader

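# NOTE on get_dataloader: batchify.Append keeps each sample as a separate
# NDArray rather than stacking, because FasterRCNNDefaultValTransform resizes
# images while preserving aspect ratio, so shapes differ within a batch.
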
def split_and_load(batch, ctx_list):
    """Split the batch so that each device receives one sample."""
    new_batch = []
    for data in batch:
        new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
        new_batch.append(new_data)
    return new_batch

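# NOTE on split_and_load: with e.g. ctx_list == [mx.gpu(0), mx.gpu(1)], each
# of the batchified (data, label, im_scale) lists is moved element-wise onto
# its device, so zip(*batch) in validate() yields one (x, y, im_scale) per
# device.
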
def validate(net, val_data, ctx, eval_metric, size):
    """Test on validation dataset."""
    clipper = gcv.nn.bbox.BBoxClipToImage()
    eval_metric.reset()
    net.hybridize(static_alloc=True)
    with tqdm(total=size) as pbar:
        for batch in val_data:
            batch = split_and_load(batch, ctx_list=ctx)
            det_bboxes = []
            det_ids = []
            det_scores = []
            gt_bboxes = []
            gt_ids = []
            gt_difficults = []
            for x, y, im_scale in zip(*batch):
                # get prediction results
                ids, scores, bboxes = net(x)
                det_ids.append(ids)
                det_scores.append(scores)
                # clip to image size
                det_bboxes.append(clipper(bboxes, x))
                # rescale to original resolution
                im_scale = im_scale.reshape((-1)).asscalar()
                det_bboxes[-1] *= im_scale
                # split ground truths
                gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
                gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
                gt_bboxes[-1] *= im_scale
                gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)
            # update metric
            for det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff in zip(
                    det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults):
                eval_metric.update(det_bbox, det_id, det_score, gt_bbox, gt_id, gt_diff)
            pbar.update(len(ctx))
    return eval_metric.get()

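# NOTE on validate: hybridize(static_alloc=True) caches memory across forward
# passes for speed. Each label row is [xmin, ymin, xmax, ymax, class_id] plus
# an optional "difficult" flag, hence the slice_axis calls at begin=0/4/5.
# The progress bar advances by len(ctx), one image per device.
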
if __name__ == '__main__':
    args = parse_args()

    # contexts
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]
    args.batch_size = len(ctx)  # 1 batch per device

    # network
    kwargs = {}
    module_list = []
    if args.use_fpn:
        module_list.append('fpn')
    if args.norm_layer is not None:
        module_list.append(args.norm_layer)
        if args.norm_layer == 'bn':
            # synchronized BN needs to know how many devices it runs across
            kwargs['num_devices'] = len(ctx)
    net_name = '_'.join(('faster_rcnn', *module_list, args.network, args.dataset))
    args.save_prefix += net_name
    if args.pretrained.lower() in ['true', '1', 'yes', 't']:
        net = gcv.model_zoo.get_model(net_name, pretrained=True, **kwargs)
    else:
        net = gcv.model_zoo.get_model(net_name, pretrained=False, **kwargs)
        net.load_parameters(args.pretrained.strip(), cast_dtype=True)
    net.collect_params().reset_ctx(ctx)

    # validation data
    val_dataset, eval_metric = get_dataset(args.dataset, args)
    val_data = get_dataloader(
        net, val_dataset, args.batch_size, args.num_workers)

    # validation
    if not args.eval_all:
        names, values = validate(net, val_data, ctx, eval_metric, len(val_dataset))
        for k, v in zip(names, values):
            print(k, v)
    else:
        saved_models = glob.glob(args.save_prefix + '*.params')
        # "epoch" here is just the index of the checkpoint in sorted order
        for epoch, saved_model in enumerate(sorted(saved_models)):
            print('[Epoch {}] Validating from {}'.format(epoch, saved_model))
            net.load_parameters(saved_model)
            net.collect_params().reset_ctx(ctx)
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric, len(val_dataset))
            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            print('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
            with open(args.save_prefix + '_best_map.log', 'a') as f:
                f.write('\n{:04d}:\t{:.4f}'.format(epoch, current_map))