# Modified work:
# -----------------------------------------------------------------------------
# Copyright (c) 2015 Preferred Infrastructure, Inc.
# Copyright (c) 2015 Preferred Networks, Inc.
# -----------------------------------------------------------------------------

# Original work of _roi_pooling_slice, forward_cpu and backward_cpu:
# -----------------------------------------------------------------------------
# Copyright 2014 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# -----------------------------------------------------------------------------

# Original work of forward_gpu and backward_gpu:
# -----------------------------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see fast-rcnn/LICENSE for details]
# Written by Ross Girshick
# -----------------------------------------------------------------------------

import numpy
import six

from chainer.backends import cuda
from chainer import function_node
from chainer.utils import type_check


def _round_half_away_from_zero(value):
    """Round ``value`` to the nearest int, breaking ties away from zero.

    The CUDA kernels in this module call C ``round()``, which rounds halfway
    cases away from zero.  The Python 3 builtin ``round()`` rounds them to
    the nearest even integer instead, so the CPU paths use this helper to
    select the same RoI bounds as the GPU paths.

    Args:
        value: A real scalar (Python or NumPy).

    Returns:
        int: The rounded value.

    """
    # Convert to a Python float first so that adding 0.5 is not performed in
    # a narrower NumPy floating point type.
    value = float(value)
    magnitude = int(numpy.floor(abs(value) + 0.5))
    return magnitude if value >= 0 else -magnitude


def _scaled_roi_bounds(roi, spatial_scale):
    """Unpack one RoI row and scale its corners by ``spatial_scale``.

    Args:
        roi: A sequence ``(batch_index, x_min, y_min, x_max, y_max)``.
        spatial_scale (float): Factor the four coordinates are multiplied by
            before rounding.

    Returns:
        tuple of int: ``(batch_index, xmin, ymin, xmax, ymax)``.  The
        coordinates are *not* clipped to any feature map size.

    """
    idx, xmin, ymin, xmax, ymax = roi
    return (
        int(idx),
        _round_half_away_from_zero(xmin * spatial_scale),
        _round_half_away_from_zero(ymin * spatial_scale),
        _round_half_away_from_zero(xmax * spatial_scale),
        _round_half_away_from_zero(ymax * spatial_scale),
    )


def _roi_pooling_slice(size, stride, max_size, roi_offset):
    """Return the input slice pooled by one output bin along one axis.

    Args:
        size (int): Index of the output bin along the axis.
        stride (float): Extent of one bin, in input pixels.
        max_size (int): Size of the input along the axis; the slice is
            clipped to ``[0, max_size]``.
        roi_offset (int): Start coordinate of the RoI along the axis.

    Returns:
        tuple: ``(slice(start, end), end - start)``.  After clipping the
        slice can be empty, in which case the length is not positive.

    """
    start = int(numpy.floor(size * stride))
    end = int(numpy.ceil((size + 1) * stride))

    start = min(max(start + roi_offset, 0), max_size)
    end = min(max(end + roi_offset, 0), max_size)

    return slice(start, end), end - start


class ROIPooling2D(function_node.FunctionNode):

    """RoI pooling over a set of 2d planes.

    For every RoI the (scaled and rounded) region is split into
    ``outh x outw`` bins and the maximum of each bin is taken per channel.
    Bins that do not overlap the input produce ``0``.

    Args:
        outh (int): Height of the pooled output.
        outw (int): Width of the pooled output.
        spatial_scale (float): Factor the RoI coordinates are multiplied by.

    """

    def __init__(self, outh, outw, spatial_scale):
        self.outh, self.outw = outh, outw
        self.spatial_scale = spatial_scale

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)

        x_type, roi_type = in_types
        type_check.expect(
            x_type.dtype.kind == 'f',
            x_type.ndim == 4,
            x_type.dtype == roi_type.dtype,
            roi_type.ndim == 2,
            roi_type.shape[1] == 5,
        )

    def forward_cpu(self, inputs):
        # Only the RoIs are needed by backward; the shape of the feature map
        # is kept separately and the argmax records where each maximum was.
        self.retain_inputs((1,))
        self._bottom_data_shape = inputs[0].shape

        bottom_data, bottom_rois = inputs
        channels, height, width = bottom_data.shape[1:]
        n_rois = bottom_rois.shape[0]
        # `numpy.zeros` needs to be used because the arrays can be
        # returned without having some of its values updated.
        top_data = numpy.zeros((n_rois, channels, self.outh, self.outw),
                               dtype=bottom_data.dtype)
        # -1 marks bins in which nothing was pooled.  It can never equal a
        # flat pixel index, so nothing is back-propagated through such bins
        # (the GPU kernel uses the same sentinel).
        self.argmax_data = numpy.full(top_data.shape, -1, dtype=numpy.int32)

        for i_roi in six.moves.range(n_rois):
            idx, xmin, ymin, xmax, ymax = _scaled_roi_bounds(
                bottom_rois[i_roi], self.spatial_scale)
            # Force malformed ROIs to be 1x1
            roi_width = max(xmax - xmin + 1, 1)
            roi_height = max(ymax - ymin + 1, 1)
            strideh = 1. * roi_height / self.outh
            stridew = 1. * roi_width / self.outw

            for outh in six.moves.range(self.outh):
                sliceh, lenh = _roi_pooling_slice(
                    outh, strideh, height, ymin)
                if sliceh.stop <= sliceh.start:
                    continue
                for outw in six.moves.range(self.outw):
                    slicew, lenw = _roi_pooling_slice(
                        outw, stridew, width, xmin)
                    if slicew.stop <= slicew.start:
                        continue
                    roi_data = bottom_data[idx, :, sliceh, slicew]\
                        .reshape(channels, -1)
                    top_data[i_roi, :, outh, outw] =\
                        numpy.max(roi_data, axis=1)

                    # get the max idx respect to feature_maps coordinates
                    max_idx_slice = numpy.unravel_index(
                        numpy.argmax(roi_data, axis=1), (lenh, lenw))
                    max_idx_slice_h = max_idx_slice[0] + sliceh.start
                    max_idx_slice_w = max_idx_slice[1] + slicew.start
                    max_idx_slice = max_idx_slice_h * width + max_idx_slice_w
                    self.argmax_data[i_roi, :, outh, outw] = max_idx_slice
        return top_data,

    def forward_gpu(self, inputs):
        self.retain_inputs((1,))
        self._bottom_data_shape = inputs[0].shape

        bottom_data, bottom_rois = inputs
        channels, height, width = bottom_data.shape[1:]
        n_rois = bottom_rois.shape[0]
        # Every output element is written by the kernel below, so the
        # buffers do not need to be initialized.
        top_data = cuda.cupy.empty((n_rois, channels, self.outh,
                                    self.outw), dtype=bottom_data.dtype)
        self.argmax_data = cuda.cupy.empty(top_data.shape, numpy.int32)
        cuda.elementwise(
            '''
            raw T bottom_data, T spatial_scale, int32 channels,
            int32 height, int32 width, int32 pooled_height, int32 pooled_width,
            raw T bottom_rois
            ''',
            'T top_data, int32 argmax_data',
            '''
            // pos in output filter
            int pw = i % pooled_width;
            int ph = (i / pooled_width) % pooled_height;
            int c = (i / pooled_width / pooled_height) % channels;
            int num = i / pooled_width / pooled_height / channels;

            int roi_batch_ind = bottom_rois[num * 5 + 0];
            int roi_start_w = round(bottom_rois[num * 5 + 1] * spatial_scale);
            int roi_start_h = round(bottom_rois[num * 5 + 2] * spatial_scale);
            int roi_end_w = round(bottom_rois[num * 5 + 3] * spatial_scale);
            int roi_end_h = round(bottom_rois[num * 5 + 4] * spatial_scale);

            // Force malformed ROIs to be 1x1
            int roi_width = max(roi_end_w - roi_start_w + 1, 1);
            int roi_height = max(roi_end_h - roi_start_h + 1, 1);
            float bin_size_h = static_cast<float>(roi_height)
                           / static_cast<float>(pooled_height);
            float bin_size_w = static_cast<float>(roi_width)
                           / static_cast<float>(pooled_width);

            int hstart = static_cast<int>(floor(static_cast<float>(ph)
                                          * bin_size_h));
            int wstart = static_cast<int>(floor(static_cast<float>(pw)
                                          * bin_size_w));
            int hend = static_cast<int>(ceil(static_cast<float>(ph + 1)
                                        * bin_size_h));
            int wend = static_cast<int>(ceil(static_cast<float>(pw + 1)
                                        * bin_size_w));

            // Add roi offsets and clip to input boundaries
            hstart = min(max(hstart + roi_start_h, 0), height);
            hend = min(max(hend + roi_start_h, 0), height);
            wstart = min(max(wstart + roi_start_w, 0), width);
            wend = min(max(wend + roi_start_w, 0), width);
            bool is_empty = (hend <= hstart) || (wend <= wstart);

            // Define an empty pooling region to be zero
            float maxval = is_empty ? 0 : -1E+37;
            // If nothing is pooled, argmax=-1 causes nothing to be backprop'd
            int maxidx = -1;
            int data_offset = (roi_batch_ind * channels + c) * height * width;
            for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                    int bottom_index = h * width + w;
                    if (bottom_data[data_offset + bottom_index] > maxval) {
                        maxval = bottom_data[data_offset + bottom_index];
                        maxidx = bottom_index;
                    }
                }
            }
            top_data = maxval;
            argmax_data = maxidx;
            ''', 'roi_pooling_2d_fwd'
        )(bottom_data, self.spatial_scale, channels, height, width,
          self.outh, self.outw, bottom_rois, top_data,
          self.argmax_data)

        return top_data,

    def backward(self, indexes, grad_outputs):
        bottom_rois, = self.get_retained_inputs()
        gtop_data, = grad_outputs

        f = ROIPooling2DGrad(self.outh, self.outw, self.spatial_scale,
                             self._bottom_data_shape, self.argmax_data)
        return f.apply((bottom_rois, gtop_data))


class ROIPooling2DGrad(function_node.FunctionNode):

    """Gradient of :class:`ROIPooling2D` with respect to the feature map.

    Each output gradient is routed to the input pixel recorded in
    ``argmax_data``.  No gradient is produced for the RoIs (``None``).

    Args:
        outh (int): Height of the pooled output.
        outw (int): Width of the pooled output.
        spatial_scale (float): Factor the RoI coordinates are multiplied by.
        bottom_data_shape (tuple): Shape of the forward feature map.
        argmax_data: Flat (``h * width + w``) index of the maximum of every
            pooled bin, as recorded by the forward pass.

    """

    def __init__(self, outh, outw, spatial_scale, bottom_data_shape,
                 argmax_data):
        self.outh, self.outw = outh, outw
        self.spatial_scale = spatial_scale
        self._bottom_data_shape = bottom_data_shape
        self.argmax_data = argmax_data

    def forward_cpu(self, inputs):
        bottom_rois, gtop_data = inputs
        channels, height, width = self._bottom_data_shape[1:]
        n_rois = bottom_rois.shape[0]
        bottom_delta = numpy.zeros(self._bottom_data_shape, bottom_rois.dtype)

        for i_roi in six.moves.range(n_rois):
            idx, xmin, ymin, xmax, ymax = _scaled_roi_bounds(
                bottom_rois[i_roi], self.spatial_scale)
            # Force malformed ROIs to be 1x1
            roi_width = max(xmax - xmin + 1, 1)
            roi_height = max(ymax - ymin + 1, 1)

            strideh = float(roi_height) / float(self.outh)
            stridew = float(roi_width) / float(self.outw)

            # Iterate all the w, h (from feature map) that fall into this
            # ROI.  The ranges are clipped to the feature map: an ROI may
            # extend outside of it, and for such pixels `h * width + w`
            # aliases the flat index of another pixel, which would either
            # raise IndexError or add the gradient at a wrapped-around
            # location.
            for w in six.moves.range(max(xmin, 0), min(xmax + 1, width)):
                # Pooled columns that could have pooled input column `w`.
                pwstart = int(numpy.floor(float(w - xmin) / stridew))
                pwend = int(numpy.ceil(float(w - xmin + 1) / stridew))
                pwstart = min(max(pwstart, 0), self.outw)
                pwend = min(max(pwend, 0), self.outw)

                for h in six.moves.range(max(ymin, 0),
                                         min(ymax + 1, height)):
                    # Pooled rows that could have pooled input row `h`.
                    phstart = int(numpy.floor(float(h - ymin) / strideh))
                    phend = int(numpy.ceil(float(h - ymin + 1) / strideh))
                    phstart = min(max(phstart, 0), self.outh)
                    phend = min(max(phend, 0), self.outh)

                    flat_index = h * width + w
                    for ph in six.moves.range(phstart, phend):
                        for pw in six.moves.range(pwstart, pwend):
                            # Channels whose maximum in this bin was taken
                            # from pixel (h, w) receive the bin's gradient.
                            hit = (self.argmax_data[i_roi, :, ph, pw]
                                   == flat_index)
                            bottom_delta[idx, :, h, w] += numpy.where(
                                hit, gtop_data[i_roi, :, ph, pw], 0)
        return bottom_delta, None

    def forward_gpu(self, inputs):
        bottom_rois, gtop_data = inputs
        channels, height, width = self._bottom_data_shape[1:]
        bottom_diff = cuda.cupy.zeros(
            self._bottom_data_shape, bottom_rois.dtype)

        cuda.elementwise(
            '''
            raw T top_diff, raw int32 argmax_data, int32 num_rois,
            T spatial_scale, int32 channels, int32 height, int32 width,
            int32 pooled_height, int32 pooled_width, raw T bottom_rois
            ''',
            'T bottom_diff',
            '''
            int w = i % width;
            int h = (i / width) % height;
            int c = (i / (width * height)) % channels;
            int num = i / (width * height * channels);

            float gradient = 0;
            // Accumulate gradient over all ROIs that pooled this element
            for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
                // Skip if ROI's batch index doesn't match num
                if (num != static_cast<int>(bottom_rois[roi_n * 5])) {
                    continue;
                }

                int roi_start_w = round(bottom_rois[roi_n * 5 + 1]
                                        * spatial_scale);
                int roi_start_h = round(bottom_rois[roi_n * 5 + 2]
                                        * spatial_scale);
                int roi_end_w = round(bottom_rois[roi_n * 5 + 3]
                                      * spatial_scale);
                int roi_end_h = round(bottom_rois[roi_n * 5 + 4]
                                      * spatial_scale);

                // Skip if ROI doesn't include (h, w)
                const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
                                     h >= roi_start_h && h <= roi_end_h);
                if (!in_roi) {
                    continue;
                }

                int offset = (roi_n * channels + c) * pooled_height
                             * pooled_width;

                // Compute feasible set of pooled units that could have pooled
                // this bottom unit

                // Force malformed ROIs to be 1x1
                int roi_width = max(roi_end_w - roi_start_w + 1, 1);
                int roi_height = max(roi_end_h - roi_start_h + 1, 1);

                float bin_size_h = static_cast<float>(roi_height)
                               / static_cast<float>(pooled_height);
                float bin_size_w = static_cast<float>(roi_width)
                               / static_cast<float>(pooled_width);

                int phstart = floor(static_cast<float>(h - roi_start_h)
                                    / bin_size_h);
                int phend = ceil(static_cast<float>(h - roi_start_h + 1)
                                 / bin_size_h);
                int pwstart = floor(static_cast<float>(w - roi_start_w)
                                    / bin_size_w);
                int pwend = ceil(static_cast<float>(w - roi_start_w + 1)
                                 / bin_size_w);

                phstart = min(max(phstart, 0), pooled_height);
                phend = min(max(phend, 0), pooled_height);
                pwstart = min(max(pwstart, 0), pooled_width);
                pwend = min(max(pwend, 0), pooled_width);

                for (int ph = phstart; ph < phend; ++ph) {
                    for (int pw = pwstart; pw < pwend; ++pw) {
                        int index_ = ph * pooled_width + pw + offset;
                        if (argmax_data[index_] == (h * width + w)) {
                            gradient += top_diff[index_];
                        }
                    }
                }
            }
            bottom_diff = gradient;
            ''', 'roi_pooling_2d_bwd'
        )(gtop_data, self.argmax_data, bottom_rois.shape[0],
          self.spatial_scale, channels, height, width, self.outh, self.outw,
          bottom_rois, bottom_diff)

        return bottom_diff, None

    def backward(self, indexes, grad_outputs):
        # No trivial way to implement double-backward for this function.
        raise NotImplementedError


def roi_pooling_2d(x, rois, outh, outw, spatial_scale):
    """Spatial Region of Interest (ROI) pooling function.

    This function acts similarly to :func:`~chainer.functions.max_pooling_2d`,
    but it computes the maximum of input spatial patch for each channel with
    the region of interest.

    Args:
        x (~chainer.Variable): Input variable. The shape is expected to be
            4 dimensional: (n: batch, c: channel, h: height, w: width).
        rois (~chainer.Variable): Input roi variable. The shape is expected to
            be (n: data size, 5), and each datum is set as below:
            (batch_index, x_min, y_min, x_max, y_max).
        outh (int): Height of output image after pooled.
        outw (int): Width of output image after pooled.
        spatial_scale (float): Scale by which the roi coordinates are
            multiplied to map them onto ``x``.

    Returns:
        ~chainer.Variable: Output variable. Its shape is
        (number of rois, c, ``outh``, ``outw``). Bins that do not overlap
        the input are set to zero.

    See the original paper proposing ROIPooling:
    `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_.

    """
    return ROIPooling2D(outh, outw, spatial_scale).apply((x, rois))[0]