1# Modified work: 2# ----------------------------------------------------------------------------- 3# Copyright (c) 2015 Preferred Infrastructure, Inc. 4# Copyright (c) 2015 Preferred Networks, Inc. 5# ----------------------------------------------------------------------------- 6 7# Original work of _roi_pooling_slice, forward_cpu and backward_cpu: 8# ----------------------------------------------------------------------------- 9# Copyright 2014 Nervana Systems Inc. 10# Licensed under the Apache License, Version 2.0 (the "License"); 11# you may not use this file except in compliance with the License. 12# You may obtain a copy of the License at 13# 14# https://www.apache.org/licenses/LICENSE-2.0 15# 16# Unless required by applicable law or agreed to in writing, software 17# distributed under the License is distributed on an "AS IS" BASIS, 18# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19# See the License for the specific language governing permissions and 20# limitations under the License. 
# -----------------------------------------------------------------------------

# Original work of forward_gpu and backward_gpu:
# -----------------------------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see fast-rcnn/LICENSE for details]
# Written by Ross Girshick
# -----------------------------------------------------------------------------

import math
import numbers

import numpy
import six

import chainer
from chainer.backends import cuda
from chainer import function
from chainer import utils
from chainer.utils import type_check

from chainer.functions.pooling.roi_pooling_2d import _roi_pooling_slice


def _pair(x):
    """Return ``x`` unchanged if it is iterable, otherwise ``(x, x)``."""
    if isinstance(x, chainer.utils.collections_abc.Iterable):
        return x
    return x, x


def _round_half_away_from_zero(value):
    """Round ``value`` to the nearest integer, ties going away from zero.

    The CUDA kernel in :meth:`ROIMaxPooling2D.forward_gpu` rounds the scaled
    RoI coordinates with C's ``round``, which resolves ties away from zero.
    Python 3's built-in ``round`` resolves ties to the nearest even integer
    instead, so ``round(0.5)`` is ``0`` on the CPU path while the GPU path
    yields ``1``. This helper gives the CPU path the same tie-breaking rule.

    Args:
        value: A real scalar (Python or NumPy float).

    Returns:
        int: The rounded value.

    """
    value = float(value)
    magnitude = math.floor(abs(value))
    if abs(value) - magnitude >= 0.5:
        magnitude += 1
    magnitude = int(magnitude)
    return magnitude if value >= 0 else -magnitude


class ROIMaxPooling2D(function.Function):

    """RoI max pooling over a set of 2d planes.

    Args:
        outsize ((int, int) or int): Output ``(height, width)`` of every
            pooled region. A single integer is used for both.
        spatial_scale (float): Factor the RoI coordinates are multiplied by
            before they are rounded to feature map positions. Integers are
            converted to ``float``.

    Raises:
        TypeError: If an output size is not a positive integer or
            ``spatial_scale`` is not a positive real number.

    """

    def __init__(self, outsize, spatial_scale):
        outh, outw = _pair(outsize)
        if not (isinstance(outh, numbers.Integral) and outh > 0):
            raise TypeError(
                'outsize[0] must be positive integer: {}, {}'
                .format(type(outh), outh))
        if not (isinstance(outw, numbers.Integral) and outw > 0):
            raise TypeError(
                'outsize[1] must be positive integer: {}, {}'
                .format(type(outw), outw))
        if isinstance(spatial_scale, numbers.Integral):
            spatial_scale = float(spatial_scale)
        if not (isinstance(spatial_scale, numbers.Real) and
                spatial_scale > 0):
            raise TypeError(
                'spatial_scale must be a positive float number: {}, {}'
                .format(type(spatial_scale), spatial_scale))
        self.outh, self.outw = outh, outw
        self.spatial_scale = spatial_scale

    def check_type_forward(self, in_types):
        """Validate the ``(x, rois, roi_indices)`` input types.

        ``x`` must be a 4-d floating point array, ``rois`` a 2-d array of
        the same dtype with 4 columns and ``roi_indices`` a 1-d ``int32``
        array with one entry per RoI.
        """
        type_check.expect(in_types.size() == 3)

        x_type, roi_type, roi_index_type = in_types
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim == 4,
            x_type.dtype == roi_type.dtype,
            roi_type.ndim == 2,
            roi_type.shape[1] == 4,
            roi_index_type.dtype == numpy.int32,
            roi_index_type.ndim == 1,
            roi_type.shape[0] == roi_index_type.shape[0],
        )

    def forward_cpu(self, inputs):
        """Pool every RoI into an ``(outh, outw)`` grid with NumPy.

        Args:
            inputs: Tuple ``(bottom_data, bottom_rois, bottom_roi_indices)``.

        Returns:
            tuple: One array of shape ``(n_rois, channels, outh, outw)``.
            Bins that cover no input element keep the value ``-inf``.

        The flat ``h * width + w`` position of every maximum is stored in
        ``self.argmax_data`` (``-1`` for bins that cover no element) for use
        by the backward pass.
        """
        self.retain_inputs((1, 2))
        self._bottom_data_shape = inputs[0].shape

        bottom_data, bottom_rois, bottom_roi_indices = inputs
        channels, height, width = bottom_data.shape[1:]
        n_rois = bottom_rois.shape[0]
        top_data = numpy.full(
            (n_rois, channels, self.outh, self.outw),
            - numpy.inf, dtype=bottom_data.dtype)
        self.argmax_data = numpy.full(
            top_data.shape, -1, dtype=numpy.int32)

        for i_roi in six.moves.range(n_rois):
            batch_index = int(bottom_roi_indices[i_roi])
            # Ties are rounded away from zero to agree with the ``round``
            # call of the CUDA kernel in forward_gpu.
            ymin, xmin, ymax, xmax = [
                _round_half_away_from_zero(coord * self.spatial_scale)
                for coord in bottom_rois[i_roi]]
            # Degenerate RoIs are treated as 1x1.
            roi_height = max(ymax - ymin, 1)
            roi_width = max(xmax - xmin, 1)
            strideh = 1. * roi_height / self.outh
            stridew = 1. * roi_width / self.outw

            for outh in six.moves.range(self.outh):
                sliceh, lenh = _roi_pooling_slice(
                    outh, strideh, height, ymin)
                if sliceh.stop <= sliceh.start:
                    continue
                for outw in six.moves.range(self.outw):
                    slicew, lenw = _roi_pooling_slice(
                        outw, stridew, width, xmin)
                    if slicew.stop <= slicew.start:
                        continue
                    roi_data = bottom_data[batch_index, :, sliceh, slicew]\
                        .reshape(channels, -1)
                    top_data[i_roi, :, outh, outw] =\
                        numpy.max(roi_data, axis=1)

                    # get the max idx respect to feature_maps coordinates
                    max_idx_slice = numpy.unravel_index(
                        numpy.argmax(roi_data, axis=1), (lenh, lenw))
                    max_idx_slice_h = max_idx_slice[0] + sliceh.start
                    max_idx_slice_w = max_idx_slice[1] + slicew.start
                    max_idx_slice = max_idx_slice_h * width + max_idx_slice_w
                    self.argmax_data[i_roi, :, outh, outw] = max_idx_slice
        return top_data,

    def forward_gpu(self, inputs):
        """Pool every RoI into an ``(outh, outw)`` grid with a CUDA kernel.

        Args:
            inputs: Tuple ``(bottom_data, bottom_rois, bottom_roi_indices)``.

        Returns:
            tuple: One array of shape ``(n_rois, channels, outh, outw)``.

        The kernel also fills ``self.argmax_data`` with the flat
        ``h * width + w`` position of every maximum, or ``-1`` when a bin
        covers no input element.
        """
        self.retain_inputs((1, 2))
        self._bottom_data_shape = inputs[0].shape

        bottom_data, bottom_rois, bottom_roi_indices = inputs
        channels, height, width = bottom_data.shape[1:]
        n_rois = bottom_rois.shape[0]
        top_data = cuda.cupy.empty((n_rois, channels, self.outh,
                                    self.outw), dtype=bottom_data.dtype)
        self.argmax_data = cuda.cupy.empty(top_data.shape, numpy.int32)
        cuda.elementwise(
            '''
            raw T bottom_data, raw T bottom_rois, raw int32 bottom_roi_indices,
            T spatial_scale, int32 channels, int32 height, int32 width,
            int32 pooled_height, int32 pooled_width
            ''',
            'T top_data, int32 argmax_data',
            '''
            // pos in output filter
            int pw = i % pooled_width;
            int ph = (i / pooled_width) % pooled_height;
            int c = (i / pooled_width / pooled_height) % channels;
            int n = i / pooled_width / pooled_height / channels;

            int roi_batch_ind = bottom_roi_indices[n];
            int roi_start_h = round(bottom_rois[n * 4 + 0] * spatial_scale);
            int roi_start_w = round(bottom_rois[n * 4 + 1] * spatial_scale);
            int roi_end_h = round(bottom_rois[n * 4 + 2] * spatial_scale);
            int roi_end_w = round(bottom_rois[n * 4 + 3] * spatial_scale);

            // Force malformed ROIs to be 1x1
            int roi_height = max(roi_end_h - roi_start_h , 1);
            int roi_width = max(roi_end_w - roi_start_w, 1);
            T bin_size_h = static_cast<T>(roi_height)
                           / static_cast<T>(pooled_height);
            T bin_size_w = static_cast<T>(roi_width)
                           / static_cast<T>(pooled_width);

            int hstart = static_cast<int>(floor(static_cast<T>(ph)
                                          * bin_size_h));
            int wstart = static_cast<int>(floor(static_cast<T>(pw)
                                          * bin_size_w));
            int hend = static_cast<int>(ceil(static_cast<T>(ph + 1)
                                        * bin_size_h));
            int wend = static_cast<int>(ceil(static_cast<T>(pw + 1)
                                        * bin_size_w));

            // Add roi offsets and clip to input boundaries
            hstart = min(max(hstart + roi_start_h, 0), height);
            hend = min(max(hend + roi_start_h, 0), height);
            wstart = min(max(wstart + roi_start_w, 0), width);
            wend = min(max(wend + roi_start_w, 0), width);

            // Define an empty pooling region to be zero
            T maxval = - (T) (1.0 / 0.0);
            // If nothing is pooled, argmax=-1 causes nothing to be backprop'd
            int maxidx = -1;
            int data_offset = (roi_batch_ind * channels + c) * height * width;
            for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                    int bottom_index = h * width + w;
                    if (bottom_data[data_offset + bottom_index] > maxval) {
                        maxval = bottom_data[data_offset + bottom_index];
                        maxidx = bottom_index;
                    }
                }
            }
            top_data = maxval;
            argmax_data = maxidx;
            ''', 'roi_max_pooling_2d_fwd'
        )(bottom_data, bottom_rois, bottom_roi_indices,
          self.spatial_scale, channels, height, width,
          self.outh, self.outw, top_data, self.argmax_data)

        return top_data,

    def backward_cpu(self, inputs, gy):
        """Scatter the output gradient back to the recorded argmax positions.

        Args:
            inputs: Forward inputs; only the RoIs and RoI indices are read.
            gy: Tuple whose first element is the gradient w.r.t. the output.

        Returns:
            tuple: ``(bottom_diff, None, None)``; no gradient is produced for
            the RoIs or the RoI indices.
        """
        bottom_rois, bottom_roi_indices = inputs[1:]
        channels = self._bottom_data_shape[1]
        width = self._bottom_data_shape[3]
        bottom_diff = numpy.zeros(self._bottom_data_shape, bottom_rois.dtype)

        pooled_height = self.outh
        pooled_width = self.outw
        top_diff = gy[0]

        for i in six.moves.range(top_diff.size):
            # Decompose the flat output index with integer arithmetic only.
            rest, pw = divmod(i, pooled_width)
            rest, ph = divmod(rest, pooled_height)
            n, c = divmod(rest, channels)

            max_idx = int(self.argmax_data[n, c, ph, pw])
            if max_idx == -1:
                # Nothing was pooled into this bin, so nothing flows back.
                continue

            h, w = divmod(max_idx, width)
            roi_batch_ind = int(bottom_roi_indices[n])
            bottom_diff[roi_batch_ind, c, h, w] += top_diff[n, c, ph, pw]
        return bottom_diff, None, None

    def backward_gpu(self, inputs, gy):
        """Scatter the output gradient back with a CUDA kernel.

        The kernel accumulates with ``atomicAdd``, so the call is reported
        through ``utils.nondeterministic``.

        Args:
            inputs: Forward inputs; only the RoIs and RoI indices are read.
            gy: Tuple whose first element is the gradient w.r.t. the output.

        Returns:
            tuple: ``(bottom_diff, None, None)``; no gradient is produced for
            the RoIs or the RoI indices.
        """
        utils.nondeterministic('atomicAdd')
        bottom_rois, bottom_roi_indices = inputs[1:]
        channels, height, width = self._bottom_data_shape[1:]
        bottom_diff = cuda.cupy.zeros(
            self._bottom_data_shape, bottom_rois.dtype)

        cuda.elementwise(
            '''
            raw T top_diff, raw int32 argmax_data,
            raw T bottom_rois, raw int32 bottom_roi_indices, int32 num_rois,
            T spatial_scale, int32 channels, int32 height, int32 width,
            int32 pooled_height, int32 pooled_width
            ''',
            'raw T bottom_diff',
            '''
            int pw = i % pooled_width;
            int ph = (i / pooled_width) % pooled_height;
            int c = (i / pooled_width / pooled_height) % channels;
            int n = i / pooled_width / pooled_height / channels;

            int roi_batch_ind = bottom_roi_indices[n];
            int bottom_diff_offset =
                (roi_batch_ind * channels + c) * height * width;
            int top_diff_offset =
                (n * channels + c) * pooled_height * pooled_width;

            int max_index =
                argmax_data[top_diff_offset + ph * pooled_width + pw];
            if (max_index != -1) {
                atomicAdd(
                    &bottom_diff[bottom_diff_offset + max_index],
                    top_diff[top_diff_offset + ph * pooled_width + pw]);
            }
            ''', 'roi_max_pooling_2d_bwd'
        )(gy[0], self.argmax_data, bottom_rois, bottom_roi_indices,
          bottom_rois.shape[0], self.spatial_scale, channels, height, width,
          self.outh, self.outw, bottom_diff, size=gy[0].size)

        return bottom_diff, None, None


def roi_max_pooling_2d(x, rois, roi_indices, outsize, spatial_scale):
    """Spatial Region of Interest (ROI) max pooling function.

    This function acts similarly to :func:`~chainer.functions.max_pooling_2d`,
    but it computes the maximum of input spatial patch for each channel with
    the region of interest.

    Args:
        x (~chainer.Variable): Input variable. The shape is expected to be
            4 dimensional: (n: batch, c: channel, h: height, w: width).
        rois (~chainer.Variable): Input roi variable. The shape is expected to
            be (n: data size, 4), and each datum is set as below:
            (y_min, x_min, y_max, x_max). Its dtype must match that of ``x``.
        roi_indices (~chainer.Variable): Input roi index variable of dtype
            ``int32``. The shape is expected to be (n: data size, ). Each
            element is the index along the batch axis of ``x`` of the image
            that the corresponding roi is pooled from.
        outsize ((int, int) or int): Expected output size after pooled
            (height, width). ``outsize=o`` and ``outsize=(o, o)``
            are equivalent.
        spatial_scale (float): Factor by which the roi coordinates are
            multiplied to map them onto ``x``.

    Returns:
        ~chainer.Variable: Output variable. The shape is
        (n: data size, c: channel, outsize height, outsize width).

    See the original paper proposing ROIPooling:
    `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_.

    """
    return ROIMaxPooling2D(outsize, spatial_scale)(x, rois, roi_indices)