/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file roi_pooling.cu
 * \brief ROI pooling operator
 * \author Ross Girshick, Kye-Hyeon Kim, Jian Guo
 */
#include "./roi_pooling-inl.h"
#include <mshadow/tensor.h>
#include <mshadow/cuda/reduce.cuh>
#include <cfloat>  // FLT_MAX, used to initialize the running maximum
#include <algorithm>
#include <vector>

namespace mshadow {
namespace cuda {

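// Forward kernel: each thread computes one element of the pooled output,
// i.e. one (n, c, ph, pw) bin, walking the output with a grid-stride loop.
// The flattened input index of each bin's maximum is stored in argmax_data
// so the backward pass can route gradients to the right input element.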
template<typename Dtype>
__global__ void ROIPoolForwardKernel(const int count, const Dtype* bottom_data,
                                     const float spatial_scale, const int batch_size,
                                     const int channels, const int height, const int width,
                                     const int pooled_height, const int pooled_width,
                                     const Dtype* bottom_rois, Dtype* top_data,
                                     index_t* argmax_data) {
  for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x * gridDim.y) {
    // (n, c, ph, pw) is an element in the pooled output
    int pw = index % pooled_width;
    int ph = (index / pooled_width) % pooled_height;
    int c = (index / pooled_width / pooled_height) % channels;
    int n = index / pooled_width / pooled_height / channels;

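    // Each ROI is stored as 5 values: [batch_index, x1, y1, x2, y2].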
    bottom_rois += n * 5;
    int roi_batch_ind = static_cast<int>(bottom_rois[0]);

    if (roi_batch_ind < 0 || roi_batch_ind >= batch_size) {
      top_data[index] = 0;
      argmax_data[index] = -1;
      continue;
    }

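    // Scale ROI corners from image coordinates to feature-map coordinates.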
    int roi_start_w = round(bottom_rois[1] * spatial_scale);
    int roi_start_h = round(bottom_rois[2] * spatial_scale);
    int roi_end_w = round(bottom_rois[3] * spatial_scale);
    int roi_end_h = round(bottom_rois[4] * spatial_scale);

    // Force malformed ROIs to be 1x1
    int roi_width = max(roi_end_w - roi_start_w + 1, 1);
    int roi_height = max(roi_end_h - roi_start_h + 1, 1);
    Dtype bin_size_h = static_cast<Dtype>(roi_height)
                       / static_cast<Dtype>(pooled_height);
    Dtype bin_size_w = static_cast<Dtype>(roi_width)
                       / static_cast<Dtype>(pooled_width);

    int hstart = static_cast<int>(floor(static_cast<Dtype>(ph)
                                        * bin_size_h));
    int wstart = static_cast<int>(floor(static_cast<Dtype>(pw)
                                        * bin_size_w));
    int hend = static_cast<int>(ceil(static_cast<Dtype>(ph + 1)
                                     * bin_size_h));
    int wend = static_cast<int>(ceil(static_cast<Dtype>(pw + 1)
                                     * bin_size_w));

    // Add roi offsets and clip to input boundaries
    hstart = min(max(hstart + roi_start_h, 0), height);
    hend = min(max(hend + roi_start_h, 0), height);
    wstart = min(max(wstart + roi_start_w, 0), width);
    wend = min(max(wend + roi_start_w, 0), width);
    bool is_empty = (hend <= hstart) || (wend <= wstart);

    // Define an empty pooling region to be zero
    Dtype maxval = is_empty ? 0 : -FLT_MAX;
    // If nothing is pooled, argmax = -1 causes nothing to be backprop'd
    index_t maxidx = -1;
    index_t offset_bottom_data = (roi_batch_ind * channels + c) * height * width;
    bottom_data += offset_bottom_data;
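    // Max-pool over the bin, recording the flattened input offset of the max.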
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        index_t bottom_index = h * width + w;
        if (bottom_data[bottom_index] > maxval) {
          maxval = bottom_data[bottom_index];
          maxidx = offset_bottom_data + bottom_index;
        }
      }
    }
    top_data[index] = maxval;
    argmax_data[index] = maxidx;
  }
}

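// Host-side launcher for the forward pass: pulls raw pointers and shape
// information out of the mshadow tensors and launches ROIPoolForwardKernel
// with one thread per element of the pooled output.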
template<typename Dtype>
inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
                           const Tensor<gpu, 4, Dtype> &data,
                           const Tensor<gpu, 2, Dtype> &bbox,
                           const Tensor<gpu, 4, index_t> &max_idx,
                           const float spatial_scale) {
  const Dtype *bottom_data = data.dptr_;
  const Dtype *bottom_rois = bbox.dptr_;
  Dtype *top_data = out.dptr_;
  index_t *argmax_data = max_idx.dptr_;
  const index_t count = out.shape_.Size();
  const int batch_size = data.size(0);
  const int channels = data.size(1);
  const int height = data.size(2);
  const int width = data.size(3);
  const int pooled_height = out.size(2);
  const int pooled_width = out.size(3);
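  // Round the block count up, then fold it into a 2-D grid so the total
  // number of blocks can exceed the per-dimension limit kMaxGridDim;
  // surplus threads exit through the kernel's index < count bound.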
  const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
  dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
  dim3 dimBlock(kMaxThreadsPerBlock);
  CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Forward");
  cudaStream_t stream = Stream<gpu>::GetStream(out.stream_);
  ROIPoolForwardKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
      count, bottom_data, spatial_scale, batch_size, channels, height, width,
      pooled_height, pooled_width, bottom_rois, top_data, argmax_data);
  MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolForwardKernel);
}

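// Backward kernel: each thread takes one output gradient and scatters it to
// the input position recorded in argmax_data. atomicAdd is needed because
// several output bins may select the same input element as their maximum.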
template<typename Dtype>
__global__ void ROIPoolBackwardAccKernel(const int count, const Dtype* top_diff,
                                         const index_t* argmax_data, Dtype* bottom_diff) {
  for (index_t index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x * gridDim.y) {
    index_t max_idx = argmax_data[index];
    if (max_idx >= 0) {
      atomicAdd(&bottom_diff[max_idx], top_diff[index]);
    }
  }
}

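// Host-side launcher for the backward pass; gradients are accumulated into
// in_grad at the positions saved in max_idx during the forward pass.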
template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
                               const Tensor<gpu, 4, Dtype> &out_grad,
                               const Tensor<gpu, 2, Dtype> &bbox,
                               const Tensor<gpu, 4, index_t> &max_idx,
                               const float spatial_scale) {
  const Dtype *top_diff = out_grad.dptr_;
  Dtype *bottom_diff = in_grad.dptr_;
  index_t *argmax_data = max_idx.dptr_;
  const index_t count = out_grad.shape_.Size();
  const int gridSize = (count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
  dim3 dimGrid(kMaxGridDim, (gridSize + kMaxGridDim - 1) / kMaxGridDim);
  dim3 dimBlock(kMaxThreadsPerBlock);
  CheckLaunchParam(dimGrid, dimBlock, "ROIPooling Backward");
  cudaStream_t stream = Stream<gpu>::GetStream(in_grad.stream_);
  ROIPoolBackwardAccKernel<Dtype><<<dimGrid, dimBlock, 0, stream>>>(
      count, top_diff, argmax_data, bottom_diff);
  MSHADOW_CUDA_POST_KERNEL_CHECK(ROIPoolBackwardAccKernel);
}

}  // namespace cuda

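// Wrappers in the mshadow namespace that dispatch to the CUDA implementations
// above; the operator code invokes these through the Tensor<gpu, ...> overloads.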
template<typename Dtype>
inline void ROIPoolForward(const Tensor<gpu, 4, Dtype> &out,
                           const Tensor<gpu, 4, Dtype> &data,
                           const Tensor<gpu, 2, Dtype> &bbox,
                           const Tensor<gpu, 4, index_t> &max_idx,
                           const float spatial_scale) {
  cuda::ROIPoolForward(out, data, bbox, max_idx, spatial_scale);
}

template<typename Dtype>
inline void ROIPoolBackwardAcc(const Tensor<gpu, 4, Dtype> &in_grad,
                               const Tensor<gpu, 4, Dtype> &out_grad,
                               const Tensor<gpu, 2, Dtype> &bbox,
                               const Tensor<gpu, 4, index_t> &max_idx,
                               const float spatial_scale) {
  cuda::ROIPoolBackwardAcc(in_grad, out_grad, bbox, max_idx, spatial_scale);
}

}  // namespace mshadow


namespace mxnet {
namespace op {

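// Instantiate the GPU operator for the requested real dtype (float32,
// float64, or float16) via MSHADOW_REAL_TYPE_SWITCH.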
template<>
Operator* CreateOp<gpu>(ROIPoolingParam param, int dtype) {
  Operator* op = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    op = new ROIPoolingOp<gpu, DType>(param);
  });
  return op;
}

}  // namespace op
}  // namespace mxnet