# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import mxnet as mx
import numpy as np


def conv_act_layer(from_layer, name, num_filter, kernel=(1, 1), pad=(0, 0),
                   stride=(1, 1), act_type="relu", use_batchnorm=False):
    """
    wrapper for a small Convolution group: Convolution -> [BatchNorm] -> Activation

    Parameters:
    ----------
    from_layer : mx.symbol
        continue on which layer
    name : str
        base name of the new layers; symbols are named "{name}_conv",
        "{name}_bn" (only when use_batchnorm) and "{name}_{act_type}"
    num_filter : int
        how many filters to use in Convolution layer
    kernel : tuple (int, int)
        kernel size (h, w)
    pad : tuple (int, int)
        padding size (h, w)
    stride : tuple (int, int)
        stride size (h, w)
    act_type : str
        activation type passed to mx.symbol.Activation, e.g. "relu"
    use_batchnorm : bool
        whether to insert batch normalization between convolution and activation

    Returns:
    ----------
    mx.Symbol
        the activation output only (the convolution symbol is not returned)
    """
    conv = mx.symbol.Convolution(data=from_layer, kernel=kernel, pad=pad,
                                 stride=stride, num_filter=num_filter,
                                 name="{}_conv".format(name))
    if use_batchnorm:
        conv = mx.symbol.BatchNorm(data=conv, name="{}_bn".format(name))
    relu = mx.symbol.Activation(data=conv, act_type=act_type,
                                name="{}_{}".format(name, act_type))
    return relu


def legacy_conv_act_layer(from_layer, name, num_filter, kernel=(1, 1), pad=(0, 0),
                          stride=(1, 1), act_type="relu", use_batchnorm=False):
    """
    wrapper for a small Convolution group using the legacy naming scheme
    ("conv{name}", "{act_type}{name}") and an explicit bias variable

    Parameters:
    ----------
    from_layer : mx.symbol
        continue on which layer
    name : str
        suffix appended to the legacy layer names
    num_filter : int
        how many filters to use in Convolution layer
    kernel : tuple (int, int)
        kernel size (h, w)
    pad : tuple (int, int)
        padding size (h, w)
    stride : tuple (int, int)
        stride size (h, w)
    act_type : str
        activation type passed to mx.symbol.Activation, e.g. "relu"
    use_batchnorm : bool
        must be False; batch normalization is not supported here

    Returns:
    ----------
    (conv, relu) mx.Symbols
    """
    assert not use_batchnorm, "batchnorm not yet supported"
    # bias is declared explicitly so it can get zero init and a 2x learning rate
    bias = mx.symbol.Variable(name="conv{}_bias".format(name),
                              init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
    conv = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=kernel, pad=pad,
                                 stride=stride, num_filter=num_filter,
                                 name="conv{}".format(name))
    relu = mx.symbol.Activation(data=conv, act_type=act_type,
                                name="{}{}".format(act_type, name))
    # NOTE(review): only reachable when asserts are stripped (python -O); in that
    # case BatchNorm is applied *after* the activation, unlike conv_act_layer.
    if use_batchnorm:
        relu = mx.symbol.BatchNorm(data=relu, name="bn{}".format(name))
    return conv, relu


def multi_layer_feature(body, from_layers, num_filters, strides, pads, min_filter=128):
    """Wrapper function to extract features from base network, attaching extra
    layers and SSD specific layers

    Parameters
    ----------
    body : mx.Symbol
        base network; named layers are looked up via ``body.get_internals()``
        using the key ``<layer name> + '_output'``
    from_layers : list of str
        feature extraction layers, use '' to add an extra layer
        For example:
        from_layers = ['relu4_3', 'fc7', '', '', '', '']
        which means extract feature from relu4_3 and fc7, adding 4 extra layers
        on top of fc7
    num_filters : list of int
        number of filters for extra layers, you can use -1 for extracted features,
        however, if normalization and scale is applied, the number of filter for
        that layer must be provided.
        For example:
        num_filters = [512, -1, 512, 256, 256, 256]
    strides : list of int
        strides for the 3x3 convolution appended, -1 can be used for extracted
        feature layers
    pads : list of int
        paddings for the 3x3 convolution, -1 can be used for extracted layers
    min_filter : int
        minimum number of filters used in 1x1 convolution

    Returns
    -------
    list of mx.Symbols
        one feature symbol per entry of from_layers, in the same order
    """
    # arguments check
    assert len(from_layers) > 0
    assert isinstance(from_layers[0], str) and len(from_layers[0].strip()) > 0
    assert len(from_layers) == len(num_filters) == len(strides) == len(pads)

    internals = body.get_internals()
    layers = []
    for k, params in enumerate(zip(from_layers, num_filters, strides, pads)):
        from_layer, num_filter, s, p = params
        if from_layer.strip():
            # extract from base network
            layer = internals[from_layer.strip() + '_output']
            layers.append(layer)
        else:
            # attach a 1x1 bottleneck + 3x3 conv on top of the last feature layer
            assert len(layers) > 0
            assert num_filter > 0
            layer = layers[-1]
            num_1x1 = max(min_filter, num_filter // 2)
            conv_1x1 = conv_act_layer(layer, 'multi_feat_%d_conv_1x1' % (k),
                                      num_1x1, kernel=(1, 1), pad=(0, 0),
                                      stride=(1, 1), act_type='relu')
            conv_3x3 = conv_act_layer(conv_1x1, 'multi_feat_%d_conv_3x3' % (k),
                                      num_filter, kernel=(3, 3), pad=(p, p),
                                      stride=(s, s), act_type='relu')
            layers.append(conv_3x3)
    return layers


def multibox_layer(from_layers, num_classes, sizes=None,
                   ratios=None, normalization=-1, num_channels=None,
                   clip=False, interm_layer=0, steps=None):
    """
    the basic aggregation module for SSD detection. Takes in multiple layers,
    generate multiple object detection targets by customized layers

    Parameters:
    ----------
    from_layers : list of mx.symbol
        generate multibox detection from layers
    num_classes : int
        number of classes excluding background, will automatically handle
        background in this function
    sizes : list or list of list, default [.2, .95]
        [min_size, max_size] range for all layers, or [[], [], []...] (lists or
        tuples) for specific layers. A range requires at least two layers.
    ratios : list or list of list, default [1]
        [ratio1, ratio2...] for all layers or [[], [], ...] (lists or tuples)
        for specific layers
    normalization : int or list of int
        use normalization value for all layers or [...] for specific layers,
        -1 indicate no normalization and scale
    num_channels : list of int, default []
        number of input layer channels, used when normalization is enabled, the
        length of list should equal the number of normalized layers.
        The caller's list is not modified.
    clip : bool
        whether to clip out-of-image boxes
    interm_layer : int
        if > 0, will add an intermediate 3x3 Convolution + relu layer with
        this many filters
    steps : list, default []
        specify steps for each MultiBoxPrior layer, leave empty, it will calculate
        according to layer dimensions

    Returns:
    ----------
    list of outputs, as [loc_preds, cls_preds, anchor_boxes]
    loc_preds : localization regression prediction
    cls_preds : classification prediction
    anchor_boxes : generated anchor boxes
    """
    # resolve defaults here to avoid shared mutable default arguments
    if sizes is None:
        sizes = [.2, .95]
    if ratios is None:
        ratios = [1]
    if steps is None:
        steps = []
    # private copy: channel counts are consumed with pop(0) below, which must
    # not empty the caller's list (a second call with it would otherwise fail)
    num_channels = [] if num_channels is None else list(num_channels)

    assert len(from_layers) > 0, "from_layers must not be empty list"
    assert num_classes > 0, \
        "num_classes {} must be larger than 0".format(num_classes)

    assert len(ratios) > 0, "aspect ratios must not be empty list"
    if not isinstance(ratios[0], (list, tuple)):
        # provided only one ratio list, broadcast to all from_layers
        ratios = [ratios] * len(from_layers)
    assert len(ratios) == len(from_layers), \
        "ratios and from_layers must have same length"

    assert len(sizes) > 0, "sizes must not be empty list"
    if len(sizes) == 2 and not isinstance(sizes[0], (list, tuple)):
        # provided size range, we need to compute the sizes for each layer
        assert sizes[0] > 0 and sizes[0] < 1
        assert sizes[1] > 0 and sizes[1] < 1 and sizes[1] > sizes[0]
        # linspace below needs at least one point to derive the last max size
        assert len(from_layers) > 1, \
            "a [min_size, max_size] range requires at least 2 from_layers"
        tmp = np.linspace(sizes[0], sizes[1], num=(len(from_layers) - 1))
        # Ref for start_offset value:
        # https://arxiv.org/abs/1512.02325
        start_offset = 0.1
        min_sizes = [start_offset] + tmp.tolist()
        max_sizes = tmp.tolist() + [float(tmp[-1]) + start_offset]
        # materialize: zip() is a one-shot iterator in Python 3 and has no len()
        sizes = list(zip(min_sizes, max_sizes))
    assert len(sizes) == len(from_layers), \
        "sizes and from_layers must have same length"

    if not isinstance(normalization, list):
        normalization = [normalization] * len(from_layers)
    assert len(normalization) == len(from_layers)

    assert sum(x > 0 for x in normalization) <= len(num_channels), \
        "must provide number of channels for each normalized layer"

    if steps:
        assert len(steps) == len(from_layers), "provide steps for all layers or leave empty"

    loc_pred_layers = []
    cls_pred_layers = []
    anchor_layers = []
    num_classes += 1  # always use background as label 0

    for k, from_layer in enumerate(from_layers):
        from_name = from_layer.name
        # normalize: channel-wise L2 norm followed by a learnable per-channel scale
        if normalization[k] > 0:
            from_layer = mx.symbol.L2Normalization(data=from_layer,
                                                   mode="channel",
                                                   name="{}_norm".format(from_name))
            scale = mx.symbol.Variable(name="{}_scale".format(from_name),
                                       shape=(1, num_channels.pop(0), 1, 1),
                                       init=mx.init.Constant(normalization[k]),
                                       attr={'__wd_mult__': '0.1'})
            from_layer = mx.symbol.broadcast_mul(lhs=scale, rhs=from_layer)
        if interm_layer > 0:
            from_layer = mx.symbol.Convolution(data=from_layer, kernel=(3, 3),
                                               stride=(1, 1), pad=(1, 1),
                                               num_filter=interm_layer,
                                               name="{}_inter_conv".format(from_name))
            from_layer = mx.symbol.Activation(data=from_layer, act_type="relu",
                                              name="{}_inter_relu".format(from_name))

        # estimate number of anchors per location
        # here I follow the original version in caffe
        # TODO: better way to shape the anchors??
        size = sizes[k]
        assert len(size) > 0, "must provide at least one size"
        size_str = "(" + ",".join([str(x) for x in size]) + ")"
        ratio = ratios[k]
        assert len(ratio) > 0, "must provide at least one ratio"
        ratio_str = "(" + ",".join([str(x) for x in ratio]) + ")"
        num_anchors = len(size) - 1 + len(ratio)

        # create location prediction layer (4 box offsets per anchor)
        num_loc_pred = num_anchors * 4
        bias = mx.symbol.Variable(name="{}_loc_pred_conv_bias".format(from_name),
                                  init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
        loc_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3, 3),
                                         stride=(1, 1), pad=(1, 1),
                                         num_filter=num_loc_pred,
                                         name="{}_loc_pred_conv".format(from_name))
        # move channels last before flattening so predictions are grouped per location
        loc_pred = mx.symbol.transpose(loc_pred, axes=(0, 2, 3, 1))
        loc_pred = mx.symbol.Flatten(data=loc_pred)
        loc_pred_layers.append(loc_pred)

        # create class prediction layer (num_classes scores per anchor)
        num_cls_pred = num_anchors * num_classes
        bias = mx.symbol.Variable(name="{}_cls_pred_conv_bias".format(from_name),
                                  init=mx.init.Constant(0.0), attr={'__lr_mult__': '2.0'})
        cls_pred = mx.symbol.Convolution(data=from_layer, bias=bias, kernel=(3, 3),
                                         stride=(1, 1), pad=(1, 1),
                                         num_filter=num_cls_pred,
                                         name="{}_cls_pred_conv".format(from_name))
        cls_pred = mx.symbol.transpose(cls_pred, axes=(0, 2, 3, 1))
        cls_pred = mx.symbol.Flatten(data=cls_pred)
        cls_pred_layers.append(cls_pred)

        # create anchor generation layer; (-1, -1) lets MultiBoxPrior derive steps
        if steps:
            step = (steps[k], steps[k])
        else:
            step = '(-1.0, -1.0)'
        anchors = mx.symbol.contrib.MultiBoxPrior(from_layer, sizes=size_str, ratios=ratio_str,
                                                  clip=clip, name="{}_anchors".format(from_name),
                                                  steps=step)
        anchors = mx.symbol.Flatten(data=anchors)
        anchor_layers.append(anchors)

    loc_preds = mx.symbol.Concat(*loc_pred_layers, num_args=len(loc_pred_layers),
                                 dim=1, name="multibox_loc_pred")
    cls_preds = mx.symbol.Concat(*cls_pred_layers, num_args=len(cls_pred_layers),
                                 dim=1)
    cls_preds = mx.symbol.Reshape(data=cls_preds, shape=(0, -1, num_classes))
    cls_preds = mx.symbol.transpose(cls_preds, axes=(0, 2, 1), name="multibox_cls_pred")
    anchor_boxes = mx.symbol.Concat(*anchor_layers,
                                    num_args=len(anchor_layers), dim=1)
    anchor_boxes = mx.symbol.Reshape(data=anchor_boxes, shape=(0, -1, 4), name="multibox_anchors")
    return [loc_preds, cls_preds, anchor_boxes]