# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import print_function
import os
import sys
import importlib
import mxnet as mx
from dataset.iterator import DetRecordIter
from config.config import cfg
from evaluate.eval_metric import MApMetric, VOC07MApMetric
import logging
import time
from symbol.symbol_factory import get_symbol
from symbol import symbol_builder
from mxnet.base import SymbolHandle, check_call, _LIB, mx_uint, c_str_array
import ctypes
from mxnet.contrib.quantization import *

def evaluate_net(net, path_imgrec, num_classes, num_batch, mean_pixels, data_shape,
                 model_prefix, epoch, ctx=mx.cpu(), batch_size=32,
                 path_imglist="", nms_thresh=0.45, force_nms=False,
                 ovp_thresh=0.5, use_difficult=False, class_names=None,
                 voc07_metric=False):
    """
    Evaluate a network on a validation record file and print the metric results.

    The checkpoint is loaded, bound to a module, warmed up with a few forward
    passes on random input, and then scored on the validation iterator. The
    measured throughput is logged and each (name, value) metric pair is printed.

    Parameters:
    ----------
    net : str or None
        Network name or use None to load from json without modifying
    path_imgrec : str
        path to the record validation file
    path_imglist : str
        path to the list file to replace labels in record file, optional
    num_classes : int
        number of classes, not including background
    num_batch : int
        number of batches to score; passed to ``Module.score`` and multiplied
        by ``batch_size`` to obtain the image count used in the speed report
    mean_pixels : tuple
        (mean_r, mean_g, mean_b)
    data_shape : tuple or int
        (3, height, width) or height/width
    model_prefix : str
        model prefix of saved checkpoint; ``'_' + str(height)`` is appended
        before the checkpoint is loaded
    epoch : int
        load model epoch
    ctx : mx.ctx
        mx.gpu() or mx.cpu()
    batch_size : int
        validation batch size
    nms_thresh : float
        non-maximum suppression threshold
    force_nms : boolean
        whether suppress different class objects
    ovp_thresh : float
        AP overlap threshold for true/false positives
    use_difficult : boolean
        whether to use difficult objects in evaluation if applicable
    class_names : comma separated str
        class names in string, must correspond to num_classes if set
    voc07_metric : boolean
        whether to use 11-point evaluation as in VOC07 competition
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    # kept as an assert so the exception type seen by callers does not change
    assert len(data_shape) == 3 and data_shape[0] == 3
    model_prefix += '_' + str(data_shape[1])

    # iterator
    eval_iter = DetRecordIter(path_imgrec, batch_size, data_shape, mean_pixels=mean_pixels,
                              path_imglist=path_imglist, **cfg.valid)
    # model params
    load_net, args, auxs = mx.model.load_checkpoint(model_prefix, epoch)
    # network
    if net is None:
        net = load_net
    else:
        net = get_symbol(net, data_shape[1], num_classes=num_classes,
                         nms_thresh=nms_thresh, force_suppress=force_nms)
    # the module below is created with label_names=('label',), so make sure
    # the symbol exposes a 'label' argument
    if 'label' not in net.list_arguments():
        label = mx.sym.Variable(name='label')
        net = mx.sym.Group([net, label])

    # init module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=net.list_arguments())
    mod.bind(data_shapes=eval_iter.provide_data, label_shapes=eval_iter.provide_label)
    mod.set_params(args, auxs, allow_missing=False, force_init=True)

    # run evaluation
    if voc07_metric:
        metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names)
    else:
        metric = MApMetric(ovp_thresh, use_difficult, class_names)

    num = num_batch * batch_size
    data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes]
    batch = mx.io.DataBatch(data, [])  # empty label

    dry_run = 5  # use 5 iterations to warm up
    for _ in range(dry_run):
        mod.forward(batch, is_train=False)
        # block until every output is computed so the warm-up really finishes
        for output in mod.get_outputs():
            output.wait_to_read()

    tic = time.time()
    results = mod.score(eval_iter, metric, num_batch=num_batch)
    speed = num / (time.time() - tic)
    # logging.getLogger() never returns None, so log unconditionally
    logger.info('Finished inference with %d images', num)
    logger.info('Finished with %f images per second', speed)

    for k, v in results:
        print("{}: {}".format(k, v))