/*!
 * Copyright (c) 2017 Microsoft
 * Licensed under The Apache-2.0 License [see LICENSE for details]
 * \file multi_proposal.cu
 * \brief MultiProposal Operator
 * \author Shaoqing Ren, Xizhou Zhu, Jian Guo
*/
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mshadow/tensor.h>
#include <mshadow/cuda/reduce.cuh>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>

#include <map>
#include <vector>
#include <string>
#include <utility>
#include <ctime>
#include <iostream>

#include "../operator_common.h"
#include "../mshadow_op.h"
#include "./multi_proposal-inl.h"

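// DIVUP(m, n) is ceiling division for non-negative integers, e.g. DIVUP(65, 64) == 2.
// The NMS code below uses it to size the bitmask: one 64-bit word per 64 boxes.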
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#define FRCNN_CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
  } while (0)

namespace mshadow {
namespace cuda {
namespace multi_proposal {

// scores are (b, 2 * anchor, h, w)
// workspace_proposals are (b, h * w * anchor, 5)
// w defines "x" and h defines "y"
// count should be the total number of anchors, b * h * w * anchors
template<typename Dtype>
__global__ void ProposalGridKernel(const int count,
                                   const int num_anchors,
                                   const int height,
                                   const int width,
                                   const int feature_stride,
                                   const Dtype* scores,
                                   Dtype* workspace_proposals) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % width;
    int h = (index / num_anchors / width) % height;
    int b = index / num_anchors / width / height;

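    // The first num_anchors rows of workspace_proposals hold the base anchors
    // copied from the host. Shift them by this cell's (w, h) offset in image
    // pixels and attach the foreground score (channels
    // [num_anchors, 2 * num_anchors) of the score blob).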
    workspace_proposals[index * 5 + 0] = workspace_proposals[a * 5 + 0] + w * feature_stride;
    workspace_proposals[index * 5 + 1] = workspace_proposals[a * 5 + 1] + h * feature_stride;
    workspace_proposals[index * 5 + 2] = workspace_proposals[a * 5 + 2] + w * feature_stride;
    workspace_proposals[index * 5 + 3] = workspace_proposals[a * 5 + 3] + h * feature_stride;
    workspace_proposals[index * 5 + 4] =
        scores[((b * (2 * num_anchors) + a + num_anchors) * height + h) * width + w];
  }
}

// boxes are (b, h * w * anchor, 5)
// deltas are (b, 4 * anchor, h, w)
// out_pred_boxes are (b, h * w * anchor, 5)
// count should be the total number of anchors, b * h * w * anchors
// in-place write: boxes and out_pred_boxes are the same location
template<typename Dtype>
__global__ void BBoxPredKernel(const int count,
                               const int num_anchors,
                               const int feat_height,
                               const int feat_width,
                               const int feature_stride,
                               const Dtype* im_infos,
                               const Dtype* boxes,
                               const Dtype* deltas,
                               Dtype* out_pred_boxes) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % feat_width;
    int h = (index / num_anchors / feat_width) % feat_height;
    int b = index / num_anchors / feat_width / feat_height;

    float im_height = im_infos[b * 3];
    float im_width = im_infos[b * 3 + 1];
    int real_height = static_cast<int>(im_height / feature_stride);
    int real_width = static_cast<int>(im_width / feature_stride);

    float width = boxes[index * 5 + 2] - boxes[index * 5 + 0] + 1.0f;
    float height = boxes[index * 5 + 3] - boxes[index * 5 + 1] + 1.0f;
    float ctr_x = boxes[index * 5 + 0] + 0.5f * (width - 1.0f);
    float ctr_y = boxes[index * 5 + 1] + 0.5f * (height - 1.0f);

    int ba = (b * num_anchors + a);
    float dx = deltas[((ba * 4) * feat_height + h) * feat_width + w];
    float dy = deltas[((ba * 4 + 1) * feat_height + h) * feat_width + w];
    float dw = deltas[((ba * 4 + 2) * feat_height + h) * feat_width + w];
    float dh = deltas[((ba * 4 + 3) * feat_height + h) * feat_width + w];

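    // Standard Faster R-CNN box decoding: (dx, dy) shift the box center in
    // units of the anchor width/height, and (dw, dh) rescale the size in log
    // space. The result is clipped to the image boundary below.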
    float pred_ctr_x = dx * width + ctr_x;
    float pred_ctr_y = dy * height + ctr_y;
    float pred_w = exp(dw) * width;
    float pred_h = exp(dh) * height;

    float pred_x1 = pred_ctr_x - 0.5f * (pred_w - 1.0f);
    float pred_y1 = pred_ctr_y - 0.5f * (pred_h - 1.0f);
    float pred_x2 = pred_ctr_x + 0.5f * (pred_w - 1.0f);
    float pred_y2 = pred_ctr_y + 0.5f * (pred_h - 1.0f);

    pred_x1 = max(min(pred_x1, im_width - 1.0f), 0.0f);
    pred_y1 = max(min(pred_y1, im_height - 1.0f), 0.0f);
    pred_x2 = max(min(pred_x2, im_width - 1.0f), 0.0f);
    pred_y2 = max(min(pred_y2, im_height - 1.0f), 0.0f);

    out_pred_boxes[index * 5 + 0] = pred_x1;
    out_pred_boxes[index * 5 + 1] = pred_y1;
    out_pred_boxes[index * 5 + 2] = pred_x2;
    out_pred_boxes[index * 5 + 3] = pred_y2;

    if (h >= real_height || w >= real_width) {
      out_pred_boxes[index * 5 + 4] = -1.0f;
    }
  }
}

// boxes are (b, h * w * anchor, 5)
// deltas are (b, 4 * anchor, h, w)
// out_pred_boxes are (b, h * w * anchor, 5)
// count should be the total number of anchors, b * h * w * anchors
// in-place write: boxes and out_pred_boxes are the same location
template<typename Dtype>
__global__ void IoUPredKernel(const int count,
                              const int num_anchors,
                              const int feat_height,
                              const int feat_width,
                              const int feature_stride,
                              const Dtype* im_infos,
                              const Dtype* boxes,
                              const Dtype* deltas,
                              Dtype* out_pred_boxes) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % feat_width;
    int h = (index / num_anchors / feat_width) % feat_height;
    int b = index / num_anchors / feat_width / feat_height;

    float im_height = im_infos[b * 3];
    float im_width = im_infos[b * 3 + 1];
    int real_height = static_cast<int>(im_height / feature_stride);
    int real_width = static_cast<int>(im_width / feature_stride);

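    // In the iou_loss variant, the deltas are direct offsets to the four box
    // corners rather than a center/log-size encoding.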
    float x1 = boxes[index * 5 + 0];
    float y1 = boxes[index * 5 + 1];
    float x2 = boxes[index * 5 + 2];
    float y2 = boxes[index * 5 + 3];

    int ba = (b * num_anchors + a);
    float dx1 = deltas[((ba * 4) * feat_height + h) * feat_width + w];
    float dy1 = deltas[((ba * 4 + 1) * feat_height + h) * feat_width + w];
    float dx2 = deltas[((ba * 4 + 2) * feat_height + h) * feat_width + w];
    float dy2 = deltas[((ba * 4 + 3) * feat_height + h) * feat_width + w];

    float pred_x1 = max(min(x1 + dx1, im_width - 1.0f), 0.0f);
    float pred_y1 = max(min(y1 + dy1, im_height - 1.0f), 0.0f);
    float pred_x2 = max(min(x2 + dx2, im_width - 1.0f), 0.0f);
    float pred_y2 = max(min(y2 + dy2, im_height - 1.0f), 0.0f);

    out_pred_boxes[index * 5 + 0] = pred_x1;
    out_pred_boxes[index * 5 + 1] = pred_y1;
    out_pred_boxes[index * 5 + 2] = pred_x2;
    out_pred_boxes[index * 5 + 3] = pred_y2;

    if (h >= real_height || w >= real_width) {
      out_pred_boxes[index * 5 + 4] = -1.0f;
    }
  }
}

// filter out boxes whose width or height is less than rpn_min_size
// (scaled by the image scale factor)
// filter: expand the box by min_size / 2 and set its score to -1
// dets (b, n, 5)
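// im_infos (b, 3): [image height, image width, scale]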
template<typename Dtype>
__global__ void FilterBoxKernel(const int count,
                                const int count_anchors,
                                const float original_min_size,
                                const Dtype* im_infos,
                                Dtype* dets) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int b = index / count_anchors;
    float iw = dets[index * 5 + 2] - dets[index * 5 + 0] + 1.0f;
    float ih = dets[index * 5 + 3] - dets[index * 5 + 1] + 1.0f;
    float min_size = original_min_size * im_infos[b * 3 + 2];
    if (iw < min_size || ih < min_size) {
      dets[index * 5 + 0] -= min_size / 2;
      dets[index * 5 + 1] -= min_size / 2;
      dets[index * 5 + 2] += min_size / 2;
      dets[index * 5 + 3] += min_size / 2;
      dets[index * 5 + 4] = -1.0f;
    }
  }
}

// copy score and init order
// dets (n, 5); score (n, ); order (n, )
// count should be n (total anchors or proposals)
template<typename Dtype>
__global__ void CopyScoreKernel(const int count,
                                const Dtype* dets,
                                Dtype* score,
                                int* order) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    score[index] = dets[index * 5 + 4];
    order[index] = index;
  }
}

// reorder proposals according to order and keep the top_n proposals
// prev_dets (n, 5); order (n, ); dets (top_n, 5)
// count should be the number of output anchors (top_n)
template<typename Dtype>
__global__ void ReorderProposalsKernel(const int count,
                                       const Dtype* prev_dets,
                                       const int* order,
                                       Dtype* dets) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    const int order_i = order[index];
    for (int j = 0; j < 5; j++) {
      dets[index * 5 + j] = prev_dets[order_i * 5 + j];
    }
  }
}

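// Intersection-over-union of two corner-encoded boxes, using the legacy
// "+ 1" pixel convention for box widths and heights used throughout this file.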
__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
  float interS = width * height;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return interS / (Sa + Sb - interS);
}

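// Bitmask NMS: boxes are tiled into chunks of 64. Block (col, row) compares
// each box of row-chunk `row` against the 64 boxes of column-chunk `col`,
// cached in shared memory, and records suppressions as one bit per pair in a
// uint64_t. The greedy selection over this mask runs on the host in _nms below.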
__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, uint64_t *dev_mask) {
  const int threadsPerBlock = sizeof(uint64_t) * 8;
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    uint64_t t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

void _nms(mshadow::Stream<gpu> *s,
          const mshadow::Tensor<gpu, 2>& boxes,
          const float nms_overlap_thresh,
          const int rpn_post_nms_top_n,
          int *keep,
          int *num_out) {
  const int threadsPerBlock = sizeof(uint64_t) * 8;
  const int boxes_num = boxes.size(0);
  const int boxes_dim = boxes.size(1);

  float* boxes_dev = boxes.dptr_;
  uint64_t* mask_dev = nullptr;

  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
  FRCNN_CUDA_CHECK(cudaMalloc(&mask_dev,
                              boxes_num * col_blocks * sizeof(uint64_t)));

  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
              DIVUP(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);
  FRCNN_CUDA_CHECK(cudaPeekAtLastError());
  std::vector<uint64_t> mask_host(boxes_num * col_blocks);

  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
  FRCNN_CUDA_CHECK(cudaMemcpyAsync(&mask_host[0],
                                   mask_dev,
                                   sizeof(uint64_t) * boxes_num * col_blocks,
                                   cudaMemcpyDeviceToHost, stream));
  FRCNN_CUDA_CHECK(cudaStreamSynchronize(stream));

  std::vector<uint64_t> remv(col_blocks);
  memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);

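  // Greedy scan over the bitmask: a box survives iff no earlier kept box
  // suppressed it (its bit in remv is clear); each kept box then ORs its mask
  // row into remv to suppress later overlapping boxes.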
  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep[num_to_keep++] = i;
      if (num_to_keep >= rpn_post_nms_top_n) break;
      uint64_t *p = &mask_host[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }
  *num_out = num_to_keep;

  FRCNN_CUDA_CHECK(cudaFree(mask_dev));
}

// copy proposals to output
// dets (top_n, 5); keep (top_n, ); out (top_n, 5); score (top_n, )
// count should be top_n (total anchors or proposals)
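// rows past out_size are padded by cycling through the kept proposals, so
// every image emits exactly count rows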
template<typename Dtype>
__global__ void PrepareOutput(const int count,
                              const Dtype* dets,
                              const int* keep,
                              const int out_size,
                              const int image_index,
                              Dtype* out,
                              Dtype* score) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    out[index * 5] = image_index;
    if (index < out_size) {
      int keep_i = keep[index];
      for (int j = 0; j < 4; ++j) {
        out[index * 5 + j + 1] = dets[keep_i * 5 + j];
      }
      score[index] = dets[keep_i * 5 + 4];
    } else {
      int keep_i = keep[index % out_size];
      for (int j = 0; j < 4; ++j) {
        out[index * 5 + j + 1] = dets[keep_i * 5 + j];
      }
      score[index] = dets[keep_i * 5 + 4];
    }
  }
}
}  // namespace multi_proposal
}  // namespace cuda
}  // namespace mshadow

namespace mxnet {
namespace op {

template<typename xpu>
class MultiProposalGPUOp : public Operator {
 public:
  explicit MultiProposalGPUOp(MultiProposalParam param) {
    this->param_ = param;
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    using namespace mshadow::cuda;
    using namespace mshadow::cuda::multi_proposal;
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(out_data.size(), 2);
    CHECK_GT(req.size(), 1);
    CHECK_EQ(req[proposal::kOut], kWriteTo);
    /*CHECK_EQ(in_data[proposal::kClsProb].shape_[0], 1)
      << "Sorry, multiple images each device is not implemented.";*/

    Stream<xpu> *s = ctx.get_stream<xpu>();

    Tensor<xpu, 4> scores = in_data[proposal::kClsProb].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<xpu, 4, real_t>(s);
    Tensor<xpu, 2> im_info = in_data[proposal::kImInfo].get<xpu, 2, real_t>(s);

    Tensor<xpu, 2> out = out_data[proposal::kOut].get<xpu, 2, real_t>(s);
    Tensor<xpu, 2> out_score = out_data[proposal::kScore].get<xpu, 2, real_t>(s);

    int num_images = scores.size(0);
    int num_anchors = scores.size(1) / 2;
    int height = scores.size(2);
    int width = scores.size(3);
    int count_anchors = num_anchors * height * width;  // anchors per image
    int count = num_images * count_anchors;
    // a non-positive rpn_pre_nms_top_n means keep all anchors
    int rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n
                                                           : count_anchors;
    rpn_pre_nms_top_n = std::min(rpn_pre_nms_top_n, count_anchors);
    int rpn_post_nms_top_n = std::min(param_.rpn_post_nms_top_n, rpn_pre_nms_top_n);

    // Generate the anchor set from the base anchor
    std::vector<float> base_anchor(4);
    base_anchor[0] = 0.0;
    base_anchor[1] = 0.0;
    base_anchor[2] = param_.feature_stride - 1.0;
    base_anchor[3] = param_.feature_stride - 1.0;
    CHECK_EQ(num_anchors, param_.ratios.ndim() * param_.scales.ndim());
    std::vector<float> anchors;
    utils::GenerateAnchors(base_anchor,
                           param_.ratios,
                           param_.scales,
                           &anchors);

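    // Note: this relies on utils::GenerateAnchors (multi_proposal-inl.h)
    // packing each anchor as 5 floats (x1, y1, x2, y2, score placeholder),
    // so that copying anchors.size() floats below fills the first num_anchors
    // rows of the stride-5 workspace that ProposalGridKernel reads.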
    // Copy generated anchors to GPU
    float* workspace_proposals_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr,
                                sizeof(float) * num_images * count_anchors * 5));
    Tensor<xpu, 3> workspace_proposals(workspace_proposals_ptr,
                                       Shape3(num_images, count_anchors, 5));

    cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);

    FRCNN_CUDA_CHECK(cudaMemcpyAsync(workspace_proposals.dptr_, &anchors[0],
                                     sizeof(float) * anchors.size(),
                                     cudaMemcpyHostToDevice, stream));

    // Copy proposals to a mesh grid
    dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
    dim3 dimBlock(kMaxThreadsPerBlock);
    CheckLaunchParam(dimGrid, dimBlock, "ProposalGrid");
    ProposalGridKernel<<<dimGrid, dimBlock>>>(
      count, num_anchors, height, width, param_.feature_stride,
      scores.dptr_, workspace_proposals.dptr_);
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    // Transform anchors and bbox_deltas into bboxes
    CheckLaunchParam(dimGrid, dimBlock, "BBoxPred");
    if (param_.iou_loss) {
      IoUPredKernel<<<dimGrid, dimBlock>>>(
        count, num_anchors, height, width, param_.feature_stride, im_info.dptr_,
        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
    } else {
      BBoxPredKernel<<<dimGrid, dimBlock>>>(
        count, num_anchors, height, width, param_.feature_stride, im_info.dptr_,
        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
    }
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    // filter boxes smaller than rpn_min_size
    CheckLaunchParam(dimGrid, dimBlock, "FilterBox");
    FilterBoxKernel<<<dimGrid, dimBlock>>>(
      count, count_anchors, param_.rpn_min_size, im_info.dptr_, workspace_proposals.dptr_);
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    dimGrid = dim3((count_anchors + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
    dimBlock = dim3(kMaxThreadsPerBlock);
    // Copy scores to contiguous memory
    float* score_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count_anchors));
    Tensor<xpu, 1> score(score_ptr, Shape1(count_anchors));
    int* order_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count_anchors));
    Tensor<xpu, 1, int> order(order_ptr, Shape1(count_anchors));

    float* workspace_ordered_proposals_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr,
        sizeof(float) * rpn_pre_nms_top_n * 5));
    Tensor<xpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr,
        Shape2(rpn_pre_nms_top_n, 5));

    int* keep = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * rpn_pre_nms_top_n));

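    // Per-image pipeline: copy scores out -> argsort descending -> gather the
    // rpn_pre_nms_top_n best proposals -> NMS -> write rpn_post_nms_top_n rows.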
    for (int b = 0; b < num_images; b++) {
      CheckLaunchParam(dimGrid, dimBlock, "CopyScore");
      CopyScoreKernel<<<dimGrid, dimBlock>>>(
          count_anchors, workspace_proposals.dptr_ + b * count_anchors * 5,
          score.dptr_, order.dptr_);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());

      // argsort score, save order
      thrust::stable_sort_by_key(thrust::device,
          score.dptr_,
          score.dptr_ + score.size(0),
          order.dptr_,
          thrust::greater<real_t>());
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());

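      // After the descending stable sort, order[i] is the index of the i-th
      // highest-scoring proposal within this image (an argsort of the scores).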
      // Reorder proposals according to order
      dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
      CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
      ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
          rpn_pre_nms_top_n, workspace_proposals.dptr_ + b * count_anchors * 5,
          order.dptr_, workspace_ordered_proposals.dptr_);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());

      // perform nms
      std::vector<int> _keep(workspace_ordered_proposals.size(0));
      int out_size = 0;
      _nms(s, workspace_ordered_proposals,
           param_.threshold,
           rpn_post_nms_top_n,
           &_keep[0],
           &out_size);

      // copy nms result to gpu
      FRCNN_CUDA_CHECK(cudaMemcpyAsync(keep, &_keep[0], sizeof(int) * _keep.size(),
                                       cudaMemcpyHostToDevice, stream));

      // copy results after nms
      dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
      CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput");
      PrepareOutput<<<dimGrid, dimBlock>>>(
          param_.rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size, b,
          out.dptr_ + b * param_.rpn_post_nms_top_n * 5,
          out_score.dptr_ + b * param_.rpn_post_nms_top_n);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
    }
    // free temporary memory
    FRCNN_CUDA_CHECK(cudaFree(keep));
    FRCNN_CUDA_CHECK(cudaFree(workspace_ordered_proposals_ptr));
    FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr));
    FRCNN_CUDA_CHECK(cudaFree(score_ptr));
    FRCNN_CUDA_CHECK(cudaFree(order_ptr));
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_grad.size(), 3);

    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> gscores = in_grad[proposal::kClsProb].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> gbbox = in_grad[proposal::kBBoxPred].get<xpu, 4, real_t>(s);
    Tensor<xpu, 2> ginfo = in_grad[proposal::kImInfo].get<xpu, 2, real_t>(s);

    // cannot assume the gradient buffers are zero-initialized
    Assign(gscores, req[proposal::kClsProb], 0);
    Assign(gbbox, req[proposal::kBBoxPred], 0);
    Assign(ginfo, req[proposal::kImInfo], 0);
  }

 private:
  MultiProposalParam param_;
};  // class MultiProposalGPUOp

template<>
Operator* CreateOp<gpu>(MultiProposalParam param) {
  return new MultiProposalGPUOp<gpu>(param);
}
}  // namespace op
}  // namespace mxnet