1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18from __future__ import print_function
19import os
20import sys
21import importlib
22import mxnet as mx
23from dataset.iterator import DetRecordIter
24from config.config import cfg
25from evaluate.eval_metric import MApMetric, VOC07MApMetric
26import logging
27import time
28from symbol.symbol_factory import get_symbol
29from symbol import symbol_builder
30from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
31import ctypes
32from mxnet.contrib.quantization import *
33
def evaluate_net(net, path_imgrec, num_classes, num_batch, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=32,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False):
    """
    Evaluate a network on a validation record file and print the metric results.

    The checkpoint is loaded, bound to a module, warmed up with a few forward
    passes on random data, then scored on the validation iterator. The
    throughput is logged and every (name, value) pair returned by the metric
    is printed to stdout. Nothing is returned.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    num_classes : int
        number of classes, not including background
    num_batch : int
        number of batches to score; forwarded to ``Module.score`` and
        multiplied by ``batch_size`` to get the image count used in the
        logged throughput
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint; ``'_' + str(height)`` is appended
        to it before the checkpoint is loaded
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    path_imglist : str
        path to the list file to replace labels in record file, optional
    nms_thresh : float
        non-maximum suppression threshold (only used when `net` is given)
    force_nms : boolean
        whether suppress different class objects (only used when `net` is given)
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    # set up logger (configures the root logger at INFO level)
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args: normalise an int side length to a (3, H, W) shape
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3, \
        "data_shape must be an int or a (3, height, width) tuple"
    # checkpoints are looked up under '<prefix>_<height>'
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape, mean_pixels=mean_pixels,
                              path_imglist=path_imglist, **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network: reuse the symbol stored with the checkpoint, or rebuild it by name
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    # the module below is created with label_names=('label',), so make sure
    # the symbol exposes a 'label' argument
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module; every argument is listed as fixed since this is inference only
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data, label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    metric_cls = VOC07MApMetric if voc07_metric else MApMetric
    metric = metric_cls(ovp_thresh, use_difficult, class_names)

    # image count used for the throughput figure
    # NOTE(review): assumes the iterator actually yields num_batch full batches
    num = num_batch * batch_size
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])  # empty label

    dry_run = 5                 # use 5 iterations to warm up
    for _ in range(dry_run):
        mod.forward(batch, is_train=False)
        # block until each output is ready so warm-up work finishes before timing
        for output in mod.get_outputs():
            output.wait_to_read()

    tic = time.time()
    results = mod.score(eval_iter, metric, num_batch=num_batch)
    speed = num / (time.time() - tic)
    logger.info('Finished inference with %d images', num)
    logger.info('Finished with %f images per second', speed)

    for k, v in results:
        print("{}: {}".format(k, v))
134