/*!
 * Copyright (c) 2017 Microsoft
 * Licensed under The Apache-2.0 License [see LICENSE for details]
 * \file multi_proposal.cu
 * \brief MultiProposal Operator
 * \author Shaoqing Ren, Xizhou Zhu, Jian Guo
 */
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mshadow/tensor.h>
#include <mshadow/cuda/reduce.cuh>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <thrust/functional.h>

#include <cstdint>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <ctime>
#include <iostream>

#include "../operator_common.h"
#include "../mshadow_op.h"
#include "./multi_proposal-inl.h"

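// Integer ceiling division: DIVUP(m, n) == ceil(m / n) for positive operands;
// used to size launch grids and the per-box suppression bitmask in NMS.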
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

#define FRCNN_CUDA_CHECK(condition) \
  /* Code block avoids redefinition of cudaError_t error */ \
  do { \
    cudaError_t error = condition; \
    CHECK_EQ(error, cudaSuccess) << " " << cudaGetErrorString(error); \
  } while (0)

namespace mshadow {
namespace cuda {
namespace multi_proposal {

// scores are (b, 2 * anchor, h, w)
// workspace_proposals are (b, h * w * anchor, 5)
// w defines "x" and h defines "y"
// count should be the total number of anchors, b * h * w * anchors
template<typename Dtype>
__global__ void ProposalGridKernel(const int count,
                                   const int num_anchors,
                                   const int height,
                                   const int width,
                                   const int feature_stride,
                                   const Dtype* scores,
                                   Dtype* workspace_proposals) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % width;
    int h = (index / num_anchors / width) % height;
    int b = index / num_anchors / width / height;
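    // index enumerates (b, h, w, a) with the anchor id varying fastest. The first
    // num_anchors rows of workspace_proposals hold the base anchors copied from
    // the host; each row below shifts one of them by its grid cell offset times
    // feature_stride. By the usual RPN convention, the foreground score sits in
    // the second half of the 2 * num_anchors score channels.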

    workspace_proposals[index * 5 + 0] = workspace_proposals[a * 5 + 0] + w * feature_stride;
    workspace_proposals[index * 5 + 1] = workspace_proposals[a * 5 + 1] + h * feature_stride;
    workspace_proposals[index * 5 + 2] = workspace_proposals[a * 5 + 2] + w * feature_stride;
    workspace_proposals[index * 5 + 3] = workspace_proposals[a * 5 + 3] + h * feature_stride;
    workspace_proposals[index * 5 + 4] =
      scores[((b * (2 * num_anchors) + a + num_anchors) * height + h) * width + w];
  }
}

// boxes are (b, h * w * anchor, 5)
// deltas are (b, 4 * anchor, h, w)
// out_pred_boxes are (b, h * w * anchor, 5)
// count should be the total number of anchors, b * h * w * anchors
// in-place write: boxes and out_pred_boxes are the same location
template<typename Dtype>
__global__ void BBoxPredKernel(const int count,
                               const int num_anchors,
                               const int feat_height,
                               const int feat_width,
                               const int feature_stride,
                               const Dtype* im_infos,
                               const Dtype* boxes,
                               const Dtype* deltas,
                               Dtype* out_pred_boxes) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % feat_width;
    int h = (index / num_anchors / feat_width) % feat_height;
    int b = index / num_anchors / feat_width / feat_height;

    float im_height = im_infos[b * 3];
    float im_width = im_infos[b * 3 + 1];
    int real_height = static_cast<int>(im_height / feature_stride);
    int real_width = static_cast<int>(im_width / feature_stride);

    float width = boxes[index * 5 + 2] - boxes[index * 5 + 0] + 1.0f;
    float height = boxes[index * 5 + 3] - boxes[index * 5 + 1] + 1.0f;
    float ctr_x = boxes[index * 5 + 0] + 0.5f * (width - 1.0f);
    float ctr_y = boxes[index * 5 + 1] + 0.5f * (height - 1.0f);

    int ba = (b * num_anchors + a);
    float dx = deltas[((ba * 4) * feat_height + h) * feat_width + w];
    float dy = deltas[((ba * 4 + 1) * feat_height + h) * feat_width + w];
    float dw = deltas[((ba * 4 + 2) * feat_height + h) * feat_width + w];
    float dh = deltas[((ba * 4 + 3) * feat_height + h) * feat_width + w];

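    // Inverse of the Fast R-CNN box encoding: (dx, dy) shift the anchor center by
    // a fraction of its size, while exp(dw) and exp(dh) rescale width and height.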
    float pred_ctr_x = dx * width + ctr_x;
    float pred_ctr_y = dy * height + ctr_y;
    float pred_w = exp(dw) * width;
    float pred_h = exp(dh) * height;

    float pred_x1 = pred_ctr_x - 0.5f * (pred_w - 1.0f);
    float pred_y1 = pred_ctr_y - 0.5f * (pred_h - 1.0f);
    float pred_x2 = pred_ctr_x + 0.5f * (pred_w - 1.0f);
    float pred_y2 = pred_ctr_y + 0.5f * (pred_h - 1.0f);

    pred_x1 = max(min(pred_x1, im_width - 1.0f), 0.0f);
    pred_y1 = max(min(pred_y1, im_height - 1.0f), 0.0f);
    pred_x2 = max(min(pred_x2, im_width - 1.0f), 0.0f);
    pred_y2 = max(min(pred_y2, im_height - 1.0f), 0.0f);

    out_pred_boxes[index * 5 + 0] = pred_x1;
    out_pred_boxes[index * 5 + 1] = pred_y1;
    out_pred_boxes[index * 5 + 2] = pred_x2;
    out_pred_boxes[index * 5 + 3] = pred_y2;

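    // Feature-map cells beyond the real (unpadded) image extent get score -1 so
    // they are pushed to the back of the descending score ordering.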
    if (h >= real_height || w >= real_width) {
      out_pred_boxes[index * 5 + 4] = -1.0f;
    }
  }
}

// boxes are (b, h * w * anchor, 5)
// deltas are (b, 4 * anchor, h, w)
// out_pred_boxes are (b, h * w * anchor, 5)
// count should be the total number of anchors, b * h * w * anchors
// in-place write: boxes and out_pred_boxes are the same location
template<typename Dtype>
__global__ void IoUPredKernel(const int count,
                              const int num_anchors,
                              const int feat_height,
                              const int feat_width,
                              const int feature_stride,
                              const Dtype* im_infos,
                              const Dtype* boxes,
                              const Dtype* deltas,
                              Dtype* out_pred_boxes) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int a = index % num_anchors;
    int w = (index / num_anchors) % feat_width;
    int h = (index / num_anchors / feat_width) % feat_height;
    int b = index / num_anchors / feat_width / feat_height;

    float im_height = im_infos[b * 3];
    float im_width = im_infos[b * 3 + 1];
    int real_height = static_cast<int>(im_height / feature_stride);
    int real_width = static_cast<int>(im_width / feature_stride);

    float x1 = boxes[index * 5 + 0];
    float y1 = boxes[index * 5 + 1];
    float x2 = boxes[index * 5 + 2];
    float y2 = boxes[index * 5 + 3];

    int ba = (b * num_anchors + a);
    float dx1 = deltas[((ba * 4) * feat_height + h) * feat_width + w];
    float dy1 = deltas[((ba * 4 + 1) * feat_height + h) * feat_width + w];
    float dx2 = deltas[((ba * 4 + 2) * feat_height + h) * feat_width + w];
    float dy2 = deltas[((ba * 4 + 3) * feat_height + h) * feat_width + w];

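    // In iou_loss mode the deltas are direct offsets on the four box corners
    // (no exp() size encoding); the results are then clipped to the image.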
    float pred_x1 = max(min(x1 + dx1, im_width - 1.0f), 0.0f);
    float pred_y1 = max(min(y1 + dy1, im_height - 1.0f), 0.0f);
    float pred_x2 = max(min(x2 + dx2, im_width - 1.0f), 0.0f);
    float pred_y2 = max(min(y2 + dy2, im_height - 1.0f), 0.0f);

    out_pred_boxes[index * 5 + 0] = pred_x1;
    out_pred_boxes[index * 5 + 1] = pred_y1;
    out_pred_boxes[index * 5 + 2] = pred_x2;
    out_pred_boxes[index * 5 + 3] = pred_y2;

    if (h >= real_height || w >= real_width) {
      out_pred_boxes[index * 5 + 4] = -1.0f;
    }
  }
}

// filter boxes whose width or height is smaller than the scaled rpn_min_size
// filter: set score to -1 instead of removing the box
// dets (b, n, 5)
template<typename Dtype>
__global__ void FilterBoxKernel(const int count,
                                const int count_anchors,
                                const float original_min_size,
                                const Dtype* im_infos,
                                Dtype* dets) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    int b = index / count_anchors;
    float iw = dets[index * 5 + 2] - dets[index * 5 + 0] + 1.0f;
    float ih = dets[index * 5 + 3] - dets[index * 5 + 1] + 1.0f;
    float min_size = original_min_size * im_infos[b * 3 + 2];
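    // rpn_min_size is given in original-image pixels; im_infos[b * 3 + 2] is the
    // resize scale, so min_size becomes the threshold in resized-image pixels.
    // Undersized boxes are inflated and score-marked -1 rather than removed.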
    if (iw < min_size || ih < min_size) {
      dets[index * 5 + 0] -= min_size / 2;
      dets[index * 5 + 1] -= min_size / 2;
      dets[index * 5 + 2] += min_size / 2;
      dets[index * 5 + 3] += min_size / 2;
      dets[index * 5 + 4] = -1.0f;
    }
  }
}

// copy score and init order
// dets (n, 5); score (n, ); order (n, )
// count should be n (total anchors or proposals)
template<typename Dtype>
__global__ void CopyScoreKernel(const int count,
                                const Dtype* dets,
                                Dtype* score,
                                int* order) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    score[index] = dets[index * 5 + 4];
    order[index] = index;
  }
}

// reorder proposals according to order and keep the top_n proposals
// prev_dets (n, 5); order (n, ); dets (n, 5)
// count should be the number of kept proposals (top_n)
template<typename Dtype>
__global__ void ReorderProposalsKernel(const int count,
                                       const Dtype* prev_dets,
                                       const int* order,
                                       Dtype* dets) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    const int order_i = order[index];
    for (int j = 0; j < 5; j++) {
      dets[index * 5 + j] = prev_dets[order_i * 5 + j];
    }
  }
}

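// Intersection-over-union of two boxes in [x1, y1, x2, y2] layout, using the
// legacy +1 pixel width/height convention.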
__device__ inline float devIoU(float const * const a, float const * const b) {
  float left = max(a[0], b[0]), right = min(a[2], b[2]);
  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
  float interS = width * height;
  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
  return interS / (Sa + Sb - interS);
}

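// Bitmask NMS: boxes are tiled into 64-wide chunks (one bit per uint64_t lane).
// Block (col_start, row_start) compares every box of its row chunk against every
// box of its column chunk and packs the "overlap above threshold" bits for each
// row box into one word of dev_mask, shaped (n_boxes, DIVUP(n_boxes, 64)).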
__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
                           const float *dev_boxes, uint64_t *dev_mask) {
  const int threadsPerBlock = sizeof(uint64_t) * 8;
  const int row_start = blockIdx.y;
  const int col_start = blockIdx.x;

  // if (row_start > col_start) return;

  const int row_size =
        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
  const int col_size =
        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);

  __shared__ float block_boxes[threadsPerBlock * 5];
  if (threadIdx.x < col_size) {
    block_boxes[threadIdx.x * 5 + 0] =
      dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
    block_boxes[threadIdx.x * 5 + 1] =
      dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
    block_boxes[threadIdx.x * 5 + 2] =
      dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
    block_boxes[threadIdx.x * 5 + 3] =
      dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
    block_boxes[threadIdx.x * 5 + 4] =
      dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
  }
  __syncthreads();

  if (threadIdx.x < row_size) {
    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
    const float *cur_box = dev_boxes + cur_box_idx * 5;
    int i = 0;
    uint64_t t = 0;
    int start = 0;
    if (row_start == col_start) {
      start = threadIdx.x + 1;
    }
    for (i = start; i < col_size; i++) {
      if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
        t |= 1ULL << i;
      }
    }
    const int col_blocks = DIVUP(n_boxes, threadsPerBlock);
    dev_mask[cur_box_idx * col_blocks + col_start] = t;
  }
}

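// Host-side greedy pass over the suppression mask: walk boxes in descending
// score order, keep a box iff no already-kept box suppresses it, then OR its
// mask row into remv so everything it suppresses is skipped. Stops once
// rpn_post_nms_top_n boxes have been kept.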
void _nms(mshadow::Stream<gpu> *s,
          const mshadow::Tensor<gpu, 2>& boxes,
          const float nms_overlap_thresh,
          const int rpn_post_nms_top_n,
          int *keep,
          int *num_out) {
  const int threadsPerBlock = sizeof(uint64_t) * 8;
  const int boxes_num = boxes.size(0);
  const int boxes_dim = boxes.size(1);

  float* boxes_dev = boxes.dptr_;
  uint64_t* mask_dev = nullptr;

  const int col_blocks = DIVUP(boxes_num, threadsPerBlock);
  FRCNN_CUDA_CHECK(cudaMalloc(&mask_dev,
                              boxes_num * col_blocks * sizeof(uint64_t)));

  dim3 blocks(DIVUP(boxes_num, threadsPerBlock),
              DIVUP(boxes_num, threadsPerBlock));
  dim3 threads(threadsPerBlock);
  nms_kernel<<<blocks, threads>>>(boxes_num,
                                  nms_overlap_thresh,
                                  boxes_dev,
                                  mask_dev);
  FRCNN_CUDA_CHECK(cudaPeekAtLastError());
  std::vector<uint64_t> mask_host(boxes_num * col_blocks);

  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
  FRCNN_CUDA_CHECK(cudaMemcpyAsync(&mask_host[0],
                                   mask_dev,
                                   sizeof(uint64_t) * boxes_num * col_blocks,
                                   cudaMemcpyDeviceToHost, stream));
  FRCNN_CUDA_CHECK(cudaStreamSynchronize(stream));

  std::vector<uint64_t> remv(col_blocks);
  memset(&remv[0], 0, sizeof(uint64_t) * col_blocks);

  int num_to_keep = 0;
  for (int i = 0; i < boxes_num; i++) {
    int nblock = i / threadsPerBlock;
    int inblock = i % threadsPerBlock;

    if (!(remv[nblock] & (1ULL << inblock))) {
      keep[num_to_keep++] = i;
      if (num_to_keep >= rpn_post_nms_top_n) break;
      uint64_t *p = &mask_host[0] + i * col_blocks;
      for (int j = nblock; j < col_blocks; j++) {
        remv[j] |= p[j];
      }
    }
  }
  *num_out = num_to_keep;

  FRCNN_CUDA_CHECK(cudaFree(mask_dev));
}

// copy proposals to output
// dets (top_n, 5); keep (top_n, ); out (top_n, )
// count should be top_n (the padded number of output proposals)
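// When NMS keeps fewer than top_n boxes (out_size < count), the tail of the
// output is padded by cycling through the kept boxes (keep[index % out_size]).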
template<typename Dtype>
__global__ void PrepareOutput(const int count,
                              const Dtype* dets,
                              const int* keep,
                              const int out_size,
                              const int image_index,
                              Dtype* out,
                              Dtype* score) {
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < count;
       index += blockDim.x * gridDim.x) {
    out[index * 5] = image_index;
    if (index < out_size) {
      int keep_i = keep[index];
      for (int j = 0; j < 4; ++j) {
        out[index * 5 + j + 1] = dets[keep_i * 5 + j];
      }
      score[index] = dets[keep_i * 5 + 4];
    } else {
      int keep_i = keep[index % out_size];
      for (int j = 0; j < 4; ++j) {
        out[index * 5 + j + 1] = dets[keep_i * 5 + j];
      }
      score[index] = dets[keep_i * 5 + 4];
    }
  }
}
}  // namespace multi_proposal
}  // namespace cuda
}  // namespace mshadow

namespace mxnet {
namespace op {

template<typename xpu>
class MultiProposalGPUOp : public Operator {
 public:
  explicit MultiProposalGPUOp(MultiProposalParam param) {
    this->param_ = param;
  }

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    using namespace mshadow::cuda;
    using namespace mshadow::cuda::multi_proposal;
    CHECK_EQ(in_data.size(), 3);
    CHECK_EQ(out_data.size(), 2);
    CHECK_GT(req.size(), 1);
    CHECK_EQ(req[proposal::kOut], kWriteTo);
    /*CHECK_EQ(in_data[proposal::kClsProb].shape_[0], 1)
      << "Sorry, multiple images each device is not implemented.";*/

    Stream<xpu> *s = ctx.get_stream<xpu>();

    Tensor<xpu, 4> scores = in_data[proposal::kClsProb].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> bbox_deltas = in_data[proposal::kBBoxPred].get<xpu, 4, real_t>(s);
    Tensor<xpu, 2> im_info = in_data[proposal::kImInfo].get<xpu, 2, real_t>(s);

    Tensor<xpu, 2> out = out_data[proposal::kOut].get<xpu, 2, real_t>(s);
    Tensor<xpu, 2> out_score = out_data[proposal::kScore].get<xpu, 2, real_t>(s);

    int num_images = scores.size(0);
    int num_anchors = scores.size(1) / 2;
    int height = scores.size(2);
    int width = scores.size(3);
    int count_anchors = num_anchors * height * width;  // anchors per image
    int count = num_images * count_anchors;
    // rpn_pre_nms_top_n <= 0 means keep all anchors
    int rpn_pre_nms_top_n = (param_.rpn_pre_nms_top_n > 0) ? param_.rpn_pre_nms_top_n
                                                           : count_anchors;
    rpn_pre_nms_top_n = std::min(rpn_pre_nms_top_n, count_anchors);
    int rpn_post_nms_top_n = std::min(param_.rpn_post_nms_top_n, rpn_pre_nms_top_n);
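    // rpn_pre_nms_top_n is clamped to the number of anchors actually generated,
    // and rpn_post_nms_top_n can never exceed what survives the pre-NMS cut.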

    // Generate the anchors for one cell from the base anchor
    std::vector<float> base_anchor(4);
    base_anchor[0] = 0.0;
    base_anchor[1] = 0.0;
    base_anchor[2] = param_.feature_stride - 1.0;
    base_anchor[3] = param_.feature_stride - 1.0;
    CHECK_EQ(num_anchors, param_.ratios.ndim() * param_.scales.ndim());
    std::vector<float> anchors;
    utils::GenerateAnchors(base_anchor,
                           param_.ratios,
                           param_.scales,
                           &anchors);

    // Copy generated anchors to GPU
    float* workspace_proposals_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_proposals_ptr,
                                sizeof(float) * num_images * count_anchors * 5));
    Tensor<xpu, 3> workspace_proposals(workspace_proposals_ptr,
                                       Shape3(num_images, count_anchors, 5));

    cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);

    FRCNN_CUDA_CHECK(cudaMemcpyAsync(workspace_proposals.dptr_, &anchors[0],
                                     sizeof(float) * anchors.size(),
                                     cudaMemcpyHostToDevice, stream));

    // Broadcast the base anchors across the feature-map grid
    dim3 dimGrid((count + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
    dim3 dimBlock(kMaxThreadsPerBlock);
    CheckLaunchParam(dimGrid, dimBlock, "ProposalGrid");
    ProposalGridKernel<<<dimGrid, dimBlock>>>(
      count, num_anchors, height, width, param_.feature_stride,
      scores.dptr_, workspace_proposals.dptr_);
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    // Transform anchors and bbox_deltas into bboxes
    CheckLaunchParam(dimGrid, dimBlock, "BBoxPred");
    if (param_.iou_loss) {
      IoUPredKernel<<<dimGrid, dimBlock>>>(
        count, num_anchors, height, width, param_.feature_stride, im_info.dptr_,
        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
    } else {
      BBoxPredKernel<<<dimGrid, dimBlock>>>(
        count, num_anchors, height, width, param_.feature_stride, im_info.dptr_,
        workspace_proposals.dptr_, bbox_deltas.dptr_, workspace_proposals.dptr_);
    }
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    // filter boxes whose sides are shorter than the scaled rpn_min_size
    CheckLaunchParam(dimGrid, dimBlock, "FilterBox");
    FilterBoxKernel<<<dimGrid, dimBlock>>>(
      count, count_anchors, param_.rpn_min_size, im_info.dptr_, workspace_proposals.dptr_);
    FRCNN_CUDA_CHECK(cudaPeekAtLastError());

    dimGrid = dim3((count_anchors + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock);
    dimBlock = dim3(kMaxThreadsPerBlock);
    // Copy scores to contiguous memory and initialize the order indices
    float* score_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&score_ptr, sizeof(float) * count_anchors));
    Tensor<xpu, 1> score(score_ptr, Shape1(count_anchors));
    int* order_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&order_ptr, sizeof(int) * count_anchors));
    Tensor<xpu, 1, int> order(order_ptr, Shape1(count_anchors));

    float* workspace_ordered_proposals_ptr = nullptr;
    FRCNN_CUDA_CHECK(cudaMalloc(&workspace_ordered_proposals_ptr,
                                sizeof(float) * rpn_pre_nms_top_n * 5));
    Tensor<xpu, 2> workspace_ordered_proposals(workspace_ordered_proposals_ptr,
                                               Shape2(rpn_pre_nms_top_n, 5));

    int* keep;
    FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * rpn_pre_nms_top_n));

    for (int b = 0; b < num_images; b++) {
      CheckLaunchParam(dimGrid, dimBlock, "CopyScore");
      CopyScoreKernel<<<dimGrid, dimBlock>>>(
        count_anchors, workspace_proposals.dptr_ + b * count_anchors * 5,
        score.dptr_, order.dptr_);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());

      // argsort scores in descending order, saving the permutation
      thrust::stable_sort_by_key(thrust::device,
                                 score.dptr_,
                                 score.dptr_ + score.size(0),
                                 order.dptr_,
                                 thrust::greater<real_t>());
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
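      // order now lists anchor indices from highest to lowest score, so its first
      // rpn_pre_nms_top_n entries select the best proposals for this image.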

      // Reorder proposals according to order
      dimGrid.x = (rpn_pre_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
      CheckLaunchParam(dimGrid, dimBlock, "ReorderProposals");
      ReorderProposalsKernel<<<dimGrid, dimBlock>>>(
        rpn_pre_nms_top_n, workspace_proposals.dptr_ + b * count_anchors * 5,
        order.dptr_, workspace_ordered_proposals.dptr_);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());

      // perform nms
      std::vector<int> _keep(workspace_ordered_proposals.size(0));
      int out_size = 0;
      _nms(s, workspace_ordered_proposals,
           param_.threshold,
           rpn_post_nms_top_n,
           &_keep[0],
           &out_size);

      // copy nms result to gpu
      FRCNN_CUDA_CHECK(cudaMemcpyAsync(keep, &_keep[0], sizeof(int) * _keep.size(),
                                       cudaMemcpyHostToDevice, stream));

      // copy results after nms
      dimGrid.x = (param_.rpn_post_nms_top_n + kMaxThreadsPerBlock - 1) / kMaxThreadsPerBlock;
      CheckLaunchParam(dimGrid, dimBlock, "PrepareOutput");
      PrepareOutput<<<dimGrid, dimBlock>>>(
        param_.rpn_post_nms_top_n, workspace_ordered_proposals.dptr_, keep, out_size, b,
        out.dptr_ + b * param_.rpn_post_nms_top_n * 5,
        out_score.dptr_ + b * param_.rpn_post_nms_top_n);
      FRCNN_CUDA_CHECK(cudaPeekAtLastError());
    }
    // free temporary memory
    FRCNN_CUDA_CHECK(cudaFree(keep));
    FRCNN_CUDA_CHECK(cudaFree(workspace_ordered_proposals_ptr));
    FRCNN_CUDA_CHECK(cudaFree(workspace_proposals_ptr));
    FRCNN_CUDA_CHECK(cudaFree(score_ptr));
    FRCNN_CUDA_CHECK(cudaFree(order_ptr));
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_states) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(in_grad.size(), 3);

    Stream<xpu> *s = ctx.get_stream<xpu>();
    Tensor<xpu, 4> gscores = in_grad[proposal::kClsProb].get<xpu, 4, real_t>(s);
    Tensor<xpu, 4> gbbox = in_grad[proposal::kBBoxPred].get<xpu, 4, real_t>(s);
    Tensor<xpu, 2> ginfo = in_grad[proposal::kImInfo].get<xpu, 2, real_t>(s);

    // cannot assume the gradient buffers are zeroed, so write zeros explicitly
    Assign(gscores, req[proposal::kClsProb], 0);
    Assign(gbbox, req[proposal::kBBoxPred], 0);
    Assign(ginfo, req[proposal::kImInfo], 0);
  }

 private:
  MultiProposalParam param_;
};  // class MultiProposalGPUOp

template<>
Operator* CreateOp<gpu>(MultiProposalParam param) {
  return new MultiProposalGPUOp<gpu>(param);
}
}  // namespace op
}  // namespace mxnet
