1# Modified work:
2# -----------------------------------------------------------------------------
3# Copyright (c) 2015 Preferred Infrastructure, Inc.
4# Copyright (c) 2015 Preferred Networks, Inc.
5# -----------------------------------------------------------------------------
6
7# Original work of _roi_pooling_slice, forward_cpu and backward_cpu:
8# -----------------------------------------------------------------------------
9# Copyright 2014 Nervana Systems Inc.
10# Licensed under the Apache License, Version 2.0 (the "License");
11# you may not use this file except in compliance with the License.
12# You may obtain a copy of the License at
13#
14#      https://www.apache.org/licenses/LICENSE-2.0
15#
16# Unless required by applicable law or agreed to in writing, software
17# distributed under the License is distributed on an "AS IS" BASIS,
18# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19# See the License for the specific language governing permissions and
20# limitations under the License.
21# -----------------------------------------------------------------------------
22
23# Original work of forward_gpu and backward_gpu:
24# -----------------------------------------------------------------------------
25# Fast R-CNN
26# Copyright (c) 2015 Microsoft
27# Licensed under The MIT License [see fast-rcnn/LICENSE for details]
28# Written by Ross Girshick
29# -----------------------------------------------------------------------------
30
31import numpy
32import six
33
34from chainer.backends import cuda
35from chainer import function_node
36from chainer.utils import type_check
37
38
39def _roi_pooling_slice(size, stride, max_size, roi_offset):
40    start = int(numpy.floor(size * stride))
41    end = int(numpy.ceil((size + 1) * stride))
42
43    start = min(max(start + roi_offset, 0), max_size)
44    end = min(max(end + roi_offset, 0), max_size)
45
46    return slice(start, end), end - start
47
48
class ROIPooling2D(function_node.FunctionNode):

    """Max pooling of 2d feature maps over regions of interest (RoIs).

    The function takes two inputs: a 4-dimensional feature map and a
    2-dimensional RoI array whose rows are unpacked as
    ``(batch_index, x_min, y_min, x_max, y_max)``. RoI coordinates are
    multiplied by ``spatial_scale`` and rounded before use. Every RoI is
    divided into ``outh * outw`` bins and the maximum of each bin is taken
    per channel.

    ``argmax_data`` is filled during the forward pass with the flat index
    (``h * width + w``) of each selected maximum; it is handed to
    :class:`ROIPooling2DGrad` in :meth:`backward`.
    """

    def __init__(self, outh, outw, spatial_scale):
        self.outh = outh
        self.outw = outw
        self.spatial_scale = spatial_scale

    def check_type_forward(self, in_types):
        type_check.expect(in_types.size() == 2)

        features_type, rois_type = in_types
        type_check.expect(
            features_type.dtype.kind == 'f',
            features_type.ndim == 4,
            features_type.dtype == rois_type.dtype,
            rois_type.ndim == 2,
            rois_type.shape[1] == 5,
        )

    def forward_cpu(self, inputs):
        # Only the RoIs are needed again in backward; the feature map is
        # represented there by its shape and by ``argmax_data``.
        self.retain_inputs((1,))
        features, rois = inputs
        self._bottom_data_shape = features.shape

        n_channels, in_h, in_w = features.shape[1:]
        out_shape = (rois.shape[0], n_channels, self.outh, self.outw)
        # Zero-initialised on purpose: bins that end up empty after
        # clipping are skipped below and keep these initial values.
        pooled = numpy.zeros(out_shape, dtype=features.dtype)
        self.argmax_data = numpy.zeros(out_shape, numpy.int32)

        scale = self.spatial_scale
        for roi_index, roi in enumerate(rois):
            batch_index = roi[0]
            x0, y0, x1, y1 = [int(round(v * scale)) for v in roi[1:]]
            # Degenerate RoIs are treated as one cell wide / high.
            bin_h = float(max(y1 - y0 + 1, 1)) / self.outh
            bin_w = float(max(x1 - x0 + 1, 1)) / self.outw

            for out_y in six.moves.range(self.outh):
                rows, n_rows = _roi_pooling_slice(out_y, bin_h, in_h, y0)
                if rows.start >= rows.stop:
                    continue
                for out_x in six.moves.range(self.outw):
                    cols, n_cols = _roi_pooling_slice(
                        out_x, bin_w, in_w, x0)
                    if cols.start >= cols.stop:
                        continue

                    window = features[int(batch_index), :, rows, cols]
                    window = window.reshape(n_channels, -1)
                    pooled[roi_index, :, out_y, out_x] = window.max(axis=1)

                    # Convert the per-window argmax into a flat index over
                    # the whole (height, width) plane of the feature map.
                    local_h, local_w = numpy.unravel_index(
                        window.argmax(axis=1), (n_rows, n_cols))
                    flat = ((local_h + rows.start) * in_w
                            + (local_w + cols.start))
                    self.argmax_data[roi_index, :, out_y, out_x] = flat
        return pooled,

    def forward_gpu(self, inputs):
        self.retain_inputs((1,))
        features, rois = inputs
        self._bottom_data_shape = features.shape

        n_channels, in_h, in_w = features.shape[1:]
        out_shape = (rois.shape[0], n_channels, self.outh, self.outw)
        # The kernel writes every output element, so no zero-fill is needed.
        pooled = cuda.cupy.empty(out_shape, dtype=features.dtype)
        self.argmax_data = cuda.cupy.empty(out_shape, numpy.int32)

        # One kernel invocation per output element (roi, channel, ph, pw).
        kernel = cuda.elementwise(
            '''
            raw T bottom_data, T spatial_scale, int32 channels,
            int32 height, int32 width, int32 pooled_height, int32 pooled_width,
            raw T bottom_rois
            ''',
            'T top_data, int32 argmax_data',
            '''
            // pos in output filter
            int pw = i % pooled_width;
            int ph = (i / pooled_width) % pooled_height;
            int c = (i / pooled_width / pooled_height) % channels;
            int num = i / pooled_width / pooled_height / channels;

            int roi_batch_ind = bottom_rois[num * 5 + 0];
            int roi_start_w = round(bottom_rois[num * 5 + 1] * spatial_scale);
            int roi_start_h = round(bottom_rois[num * 5 + 2] * spatial_scale);
            int roi_end_w = round(bottom_rois[num * 5 + 3] * spatial_scale);
            int roi_end_h = round(bottom_rois[num * 5 + 4] * spatial_scale);

            // Force malformed ROIs to be 1x1
            int roi_width = max(roi_end_w - roi_start_w + 1, 1);
            int roi_height = max(roi_end_h - roi_start_h + 1, 1);
            float bin_size_h = static_cast<float>(roi_height)
                           / static_cast<float>(pooled_height);
            float bin_size_w = static_cast<float>(roi_width)
                           / static_cast<float>(pooled_width);

            int hstart = static_cast<int>(floor(static_cast<float>(ph)
                                          * bin_size_h));
            int wstart = static_cast<int>(floor(static_cast<float>(pw)
                                          * bin_size_w));
            int hend = static_cast<int>(ceil(static_cast<float>(ph + 1)
                                        * bin_size_h));
            int wend = static_cast<int>(ceil(static_cast<float>(pw + 1)
                                        * bin_size_w));

            // Add roi offsets and clip to input boundaries
            hstart = min(max(hstart + roi_start_h, 0), height);
            hend = min(max(hend + roi_start_h, 0), height);
            wstart = min(max(wstart + roi_start_w, 0), width);
            wend = min(max(wend + roi_start_w, 0), width);
            bool is_empty = (hend <= hstart) || (wend <= wstart);

            // Define an empty pooling region to be zero
            float maxval = is_empty ? 0 : -1E+37;
            // If nothing is pooled, argmax=-1 causes nothing to be backprop'd
            int maxidx = -1;
            int data_offset = (roi_batch_ind * channels + c) * height * width;
            for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                    int bottom_index = h * width + w;
                    if (bottom_data[data_offset + bottom_index] > maxval) {
                        maxval = bottom_data[data_offset + bottom_index];
                        maxidx = bottom_index;
                    }
                }
            }
            top_data = maxval;
            argmax_data = maxidx;
            ''', 'roi_pooling_2d_fwd'
        )
        kernel(features, self.spatial_scale, n_channels, in_h, in_w,
               self.outh, self.outw, rois, pooled, self.argmax_data)

        return pooled,

    def backward(self, indexes, grad_outputs):
        (rois,) = self.get_retained_inputs()
        (grad_pooled,) = grad_outputs

        grad_fn = ROIPooling2DGrad(
            self.outh, self.outw, self.spatial_scale,
            self._bottom_data_shape, self.argmax_data)
        return grad_fn.apply((rois, grad_pooled))
201
202
class ROIPooling2DGrad(function_node.FunctionNode):

    """Backward computation of :class:`ROIPooling2D`.

    The inputs are ``(rois, gy)``: the RoI array that was given to the
    forward function and the gradient w.r.t. its pooled output. The outputs
    are the gradient w.r.t. the feature map and ``None`` for the RoIs.

    Args:
        outh (int): Height of the pooled output.
        outw (int): Width of the pooled output.
        spatial_scale (float): Factor the RoI coordinates are multiplied by.
        bottom_data_shape (tuple): Shape of the forward feature map; the
            returned gradient has this shape.
        argmax_data: Array recorded by the forward pass holding, per output
            element, the flat index ``h * width + w`` of the pooled maximum.

    """

    def __init__(self, outh, outw, spatial_scale, bottom_data_shape,
                 argmax_data):
        self.outh, self.outw = outh, outw
        self.spatial_scale = spatial_scale
        self._bottom_data_shape = bottom_data_shape
        self.argmax_data = argmax_data

    def forward_cpu(self, inputs):
        bottom_rois, gtop_data = inputs
        channels, height, width = self._bottom_data_shape[1:]
        n_rois = bottom_rois.shape[0]
        bottom_delta = numpy.zeros(self._bottom_data_shape, bottom_rois.dtype)

        for i_roi in six.moves.range(n_rois):
            idx, xmin, ymin, xmax, ymax = bottom_rois[i_roi]
            idx = int(idx)
            xmin = int(round(xmin * self.spatial_scale))
            xmax = int(round(xmax * self.spatial_scale))
            ymin = int(round(ymin * self.spatial_scale))
            ymax = int(round(ymax * self.spatial_scale))
            # Degenerate RoIs are treated as one cell wide / high.
            roi_width = max(xmax - xmin + 1, 1)
            roi_height = max(ymax - ymin + 1, 1)

            strideh = float(roi_height) / float(self.outh)
            stridew = float(roi_width) / float(self.outw)

            # Only visit pixels that exist in the feature map. An RoI may
            # reach outside of it; for such out-of-range (h, w) the flat
            # index ``h * width + w`` would alias a different, valid pixel,
            # which leads to an IndexError or to a gradient written at a
            # wrapped-around position. The GPU kernel likewise only visits
            # in-bounds pixels.
            w_begin, w_end = max(xmin, 0), min(xmax, width - 1) + 1
            h_begin, h_end = max(ymin, 0), min(ymax, height - 1) + 1

            # iterate all the w, h (from feature map) that fall into this ROIs
            for w in six.moves.range(w_begin, w_end):
                # Candidate output columns that could have pooled column w.
                pwstart = int(numpy.floor(float(w - xmin) / stridew))
                pwend = int(numpy.ceil(float(w - xmin + 1) / stridew))
                pwstart = min(max(pwstart, 0), self.outw)
                pwend = min(max(pwend, 0), self.outw)

                for h in six.moves.range(h_begin, h_end):
                    # Candidate output rows that could have pooled row h.
                    phstart = int(numpy.floor(float(h - ymin) / strideh))
                    phend = int(numpy.ceil(float(h - ymin + 1) / strideh))
                    phstart = min(max(phstart, 0), self.outh)
                    phend = min(max(phend, 0), self.outh)

                    flat_index = h * width + w
                    for ph in six.moves.range(phstart, phend):
                        for pw in six.moves.range(pwstart, pwend):
                            # Channels whose maximum was taken at (h, w).
                            hit = (self.argmax_data[i_roi, :, ph, pw]
                                   == flat_index)
                            bottom_delta[idx, hit, h, w] += \
                                gtop_data[i_roi, hit, ph, pw]
        return bottom_delta, None

    def forward_gpu(self, inputs):
        bottom_rois, gtop_data = inputs
        channels, height, width = self._bottom_data_shape[1:]
        bottom_diff = cuda.cupy.zeros(
            self._bottom_data_shape, bottom_rois.dtype)

        # One kernel invocation per feature-map element; each one gathers
        # the gradient from every output bin that selected it as maximum.
        cuda.elementwise(
            '''
            raw T top_diff, raw int32 argmax_data, int32 num_rois,
            T spatial_scale, int32 channels, int32 height, int32 width,
            int32 pooled_height, int32 pooled_width, raw T bottom_rois
            ''',
            'T bottom_diff',
            '''
            int w = i % width;
            int h = (i / width) % height;
            int c = (i / (width * height)) % channels;
            int num = i / (width * height * channels);

            float gradient = 0;
            // Accumulate gradient over all ROIs that pooled this element
            for (int roi_n = 0; roi_n < num_rois; ++roi_n) {
                // Skip if ROI's batch index doesn't match num
                if (num != static_cast<int>(bottom_rois[roi_n * 5])) {
                    continue;
                }

                int roi_start_w = round(bottom_rois[roi_n * 5 + 1]
                                        * spatial_scale);
                int roi_start_h = round(bottom_rois[roi_n * 5 + 2]
                                        * spatial_scale);
                int roi_end_w = round(bottom_rois[roi_n * 5 + 3]
                                      * spatial_scale);
                int roi_end_h = round(bottom_rois[roi_n * 5 + 4]
                                      * spatial_scale);

                // Skip if ROI doesn't include (h, w)
                const bool in_roi = (w >= roi_start_w && w <= roi_end_w &&
                                     h >= roi_start_h && h <= roi_end_h);
                if (!in_roi) {
                    continue;
                }

                int offset = (roi_n * channels + c) * pooled_height
                             * pooled_width;

                // Compute feasible set of pooled units that could have pooled
                // this bottom unit

                // Force malformed ROIs to be 1x1
                int roi_width = max(roi_end_w - roi_start_w + 1, 1);
                int roi_height = max(roi_end_h - roi_start_h + 1, 1);

                float bin_size_h = static_cast<float>(roi_height)
                               / static_cast<float>(pooled_height);
                float bin_size_w = static_cast<float>(roi_width)
                               / static_cast<float>(pooled_width);

                int phstart = floor(static_cast<float>(h - roi_start_h)
                                    / bin_size_h);
                int phend = ceil(static_cast<float>(h - roi_start_h + 1)
                                 / bin_size_h);
                int pwstart = floor(static_cast<float>(w - roi_start_w)
                                    / bin_size_w);
                int pwend = ceil(static_cast<float>(w - roi_start_w + 1)
                                 / bin_size_w);

                phstart = min(max(phstart, 0), pooled_height);
                phend = min(max(phend, 0), pooled_height);
                pwstart = min(max(pwstart, 0), pooled_width);
                pwend = min(max(pwend, 0), pooled_width);

                for (int ph = phstart; ph < phend; ++ph) {
                    for (int pw = pwstart; pw < pwend; ++pw) {
                        int index_ = ph * pooled_width + pw + offset;
                        if (argmax_data[index_] == (h * width + w)) {
                            gradient += top_diff[index_];
                        }
                    }
                }
            }
            bottom_diff = gradient;
            ''', 'roi_pooling_2d_bwd'
        )(gtop_data, self.argmax_data, bottom_rois.shape[0],
          self.spatial_scale, channels, height, width, self.outh, self.outw,
          bottom_rois, bottom_diff)

        return bottom_diff, None

    def backward(self, indexes, grad_outputs):
        # No trivial way to implement double-backward for this function.
        raise NotImplementedError
345
346
def roi_pooling_2d(x, rois, outh, outw, spatial_scale):
    """Spatial Region of Interest (ROI) pooling function.

    This function behaves like :func:`~chainer.functions.max_pooling_2d`,
    except that the maximum is taken, channel by channel, over the bins of
    each given region of interest instead of over a sliding window.

    Args:
        x (~chainer.Variable): Input variable. The shape is expected to be
            4 dimensional: (n: batch, c: channel, h: height, w: width).
        rois (~chainer.Variable): Input roi variable. The shape is expected
            to be (n: data size, 5), and each datum is laid out as
            (batch_index, x_min, y_min, x_max, y_max).
        outh (int): Height of the output after pooling.
        outw (int): Width of the output after pooling.
        spatial_scale (float): Factor by which the roi coordinates are
            multiplied before pooling.

    Returns:
        ~chainer.Variable: Output variable.

    See the original paper proposing ROIPooling:
    `Fast R-CNN <https://arxiv.org/abs/1504.08083>`_.

    """
    pooling = ROIPooling2D(outh, outw, spatial_scale)
    outputs = pooling.apply((x, rois))
    return outputs[0]
372