/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file bounding_box.cu
 * \brief Bounding box util functions and operators
 * \author Joshua Zhang
 */

#include <cub/cub.cuh>

#include "./bounding_box-inl.cuh"
#include "./bounding_box-inl.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

namespace {

using mshadow::Tensor;
using mshadow::Stream;

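// Sizes and pointers of the per-call temporary buffers (scores, sort scratch, data
// buffer, NMS bit masks and index arrays) carved out of a single workspace allocation.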
template <typename DType>
struct TempWorkspace {
  index_t scores_temp_space;
  DType* scores;
  index_t scratch_space;
  uint8_t* scratch;
  index_t buffer_space;
  DType* buffer;
  index_t nms_scratch_space;
  uint32_t* nms_scratch;
  index_t indices_temp_spaces;
  index_t* indices;
};

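// Integer division rounding up, and rounding x up to the next multiple of alignment.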
inline index_t ceil_div(index_t x, index_t y) {
  return (x + y - 1) / y;
}

inline index_t align(index_t x, index_t alignment) {
  return ceil_div(x, alignment) * alignment;
}

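// Copies the input to `out`, replacing boxes whose score does not exceed `threshold`
// (or, when id_index and background_id are both set, whose class id equals
// background_id) with -1, and writes each box's resulting score to the dense
// `scores` array used for sorting.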
template <typename DType>
__global__ void FilterAndPrepareAuxDataKernel(const DType* data, DType* out, DType* scores,
                                              index_t num_elements_per_batch,
                                              const index_t element_width,
                                              const index_t N,
                                              const float threshold,
                                              const int id_index, const int score_index,
                                              const int background_id) {
  index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  bool first_in_element = (tid % element_width == 0);
  index_t start_of_my_element = tid - (tid % element_width);

  if (tid < N) {
    DType my_score = data[start_of_my_element + score_index];
    bool filtered_out = my_score <= threshold;
    if (id_index != -1 && background_id != -1) {
      DType my_id = data[start_of_my_element + id_index];
      filtered_out = filtered_out || (my_id == background_id);
    }
    if (!filtered_out) {
      out[tid] = data[tid];
    } else {
      out[tid] = -1;
      my_score = -1;
    }

    if (first_in_element) {
      index_t offset = tid / element_width;
      scores[offset] = my_score;
    }
  }
}

template <typename DType>
void FilterAndPrepareAuxData(const Tensor<gpu, 3, DType>& data,
                             Tensor<gpu, 3, DType>* out,
                             const TempWorkspace<DType>& workspace,
                             const BoxNMSParam& param,
                             Stream<gpu>* s) {
  const int n_threads = 512;
  index_t N = data.shape_.Size();
  const auto blocks = ceil_div(N, n_threads);
  FilterAndPrepareAuxDataKernel<<<blocks,
                                  n_threads,
                                  0,
                                  Stream<gpu>::GetStream(s)>>>(
    data.dptr_, out->dptr_, workspace.scores,
    data.shape_[1], data.shape_[2], N,
    param.valid_thresh, param.id_index,
    param.score_index, param.background_id);
}

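// Gathers boxes according to `indices`: element i of the destination is copied from
// source element indices[i]. With check_topk, elements past the first topk per batch
// are set to -1; with check_score, elements whose source score is negative are also
// set to -1.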
template <bool check_topk, bool check_score, typename DType>
__global__ void CompactDataKernel(const index_t* indices, const DType* source,
                                  DType* destination, const index_t topk,
                                  const index_t element_width,
                                  const index_t num_elements_per_batch,
                                  const int score_index,
                                  const index_t N) {
  const index_t tid_start = blockIdx.x * blockDim.x + threadIdx.x;
  for (index_t tid = tid_start; tid < N; tid += blockDim.x * gridDim.x) {
    const index_t my_element = tid / element_width;
    const index_t my_element_in_batch = my_element % num_elements_per_batch;
    if (check_topk && my_element_in_batch >= topk) {
      destination[tid] = -1;
    } else {
      DType ret;
      const index_t source_element = indices[my_element];
      DType score = 0;
      if (check_score) {
        score = source[source_element * element_width + score_index];
      }
      if (score >= 0) {
        ret = source[source_element * element_width + tid % element_width];
      } else {
        ret = -1;
      }
      destination[tid] = ret;
    }
  }
}

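// Host-side launcher for CompactDataKernel. Uses the topk-checking variant only when
// topk > 0; the kernel uses a grid-stride loop, so the grid is capped at max_blocks.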
template <bool check_score, typename DType>
void CompactData(const Tensor<gpu, 1, index_t>& indices,
                 const Tensor<gpu, 3, DType>& source,
                 Tensor<gpu, 3, DType>* destination,
                 const index_t topk,
                 const int score_index,
                 Stream<gpu>* s) {
  const int n_threads = 512;
  const index_t max_blocks = 320;
  index_t N = source.shape_.Size();
  const auto blocks = std::min(ceil_div(N, n_threads), max_blocks);
  if (topk > 0) {
    CompactDataKernel<true, check_score><<<blocks, n_threads, 0,
                                           Stream<gpu>::GetStream(s)>>>(
      indices.dptr_, source.dptr_,
      destination->dptr_, topk,
      source.shape_[2], source.shape_[1],
      score_index, N);
  } else {
    CompactDataKernel<false, check_score><<<blocks, n_threads, 0,
                                            Stream<gpu>::GetStream(s)>>>(
      indices.dptr_, source.dptr_,
      destination->dptr_, topk,
      source.shape_[2], source.shape_[1],
      score_index, N);
  }
}

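// Queries the scratch space needed by SortByKey for sorting all scores and for
// sorting only the topk survivors, and reserves the larger of the two (aligned).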
template <typename DType>
void WorkspaceForSort(const index_t num_elem,
                      const index_t topk,
                      const int alignment,
                      TempWorkspace<DType>* workspace) {
  const index_t sort_scores_temp_space =
    mxnet::op::SortByKeyWorkspaceSize<DType, index_t, gpu>(num_elem, false, false);
  const index_t sort_topk_scores_temp_space =
    mxnet::op::SortByKeyWorkspaceSize<DType, index_t, gpu>(topk, false, false);
  workspace->scratch_space = align(std::max(sort_scores_temp_space, sort_topk_scores_temp_space),
                                   alignment);
}

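// Forward declarations of the NMS kernels used by the NMS functor below;
// their definitions follow after the functor.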
template <int encode, typename DType>
__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result,
                                                const index_t current_start,
                                                const index_t num_elems,
                                                const index_t num_batches,
                                                const index_t num_blocks_per_row_batch,
                                                const index_t num_blocks_per_row,
                                                const index_t topk,
                                                const index_t element_width,
                                                const index_t num_elements_per_batch,
                                                const int coord_index,
                                                const int class_index,
                                                const int score_index,
                                                const float threshold);

template <typename DType>
__global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results,
                                              DType* data,
                                              const index_t score_index,
                                              const index_t element_width,
                                              const index_t num_batches,
                                              const index_t num_elems,
                                              const index_t start_index,
                                              const index_t topk);

template <typename DType>
__global__ void ReduceNMSResultRestKernel(DType* data,
                                          const uint32_t* nms_results,
                                          const index_t score_index,
                                          const index_t element_width,
                                          const index_t num_batches,
                                          const index_t num_elements_per_batch,
                                          const index_t start_index,
                                          const index_t topk,
                                          const index_t num_blocks_per_batch);

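// Greedy NMS over the topk highest scoring boxes, processed in tiles of THRESHOLD
// boxes: for each tile the per-box suppression masks are computed
// (CalculateGreedyNMSResultsKernel), reduced within the tile
// (ReduceNMSResultTriangleKernel) and then applied to the remaining boxes after the
// tile (ReduceNMSResultRestKernel). Suppressed boxes get their score set to -1.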
template <typename DType>
struct NMS {
  static constexpr int THRESHOLD = 512;

  void operator()(Tensor<gpu, 3, DType>* data,
                  Tensor<gpu, 2, uint32_t>* scratch,
                  const index_t topk,
                  const BoxNMSParam& param,
                  Stream<gpu>* s) {
    const int n_threads = 512;
    const index_t num_batches = data->shape_[0];
    const index_t num_elements_per_batch = data->shape_[1];
    const index_t element_width = data->shape_[2];
    for (index_t current_start = 0; current_start < topk; current_start += THRESHOLD) {
      const index_t n_elems = topk - current_start;
      const index_t num_blocks_per_row_batch = ceil_div(n_elems, n_threads);
      const index_t num_blocks_per_row = num_blocks_per_row_batch * num_batches;
      const index_t n_blocks = THRESHOLD / (sizeof(uint32_t) * 8) * num_blocks_per_row;
      if (param.in_format == box_common_enum::kCorner) {
        CalculateGreedyNMSResultsKernel<box_common_enum::kCorner>
          <<<n_blocks, n_threads, 0, Stream<gpu>::GetStream(s)>>>(
            data->dptr_, scratch->dptr_, current_start, n_elems, num_batches,
            num_blocks_per_row_batch, num_blocks_per_row, topk, element_width,
            num_elements_per_batch, param.coord_start,
            param.force_suppress ? -1 : param.id_index,
            param.score_index, param.overlap_thresh);
      } else {
        CalculateGreedyNMSResultsKernel<box_common_enum::kCenter>
          <<<n_blocks, n_threads, 0, Stream<gpu>::GetStream(s)>>>(
            data->dptr_, scratch->dptr_, current_start, n_elems, num_batches,
            num_blocks_per_row_batch, num_blocks_per_row, topk, element_width,
            num_elements_per_batch, param.coord_start,
            param.force_suppress ? -1 : param.id_index,
            param.score_index, param.overlap_thresh);
      }
      ReduceNMSResultTriangleKernel<<<num_batches, THRESHOLD, 0, Stream<gpu>::GetStream(s)>>>(
          scratch->dptr_, data->dptr_, param.score_index,
          element_width, num_batches, num_elements_per_batch,
          current_start, topk);
      const index_t n_rest_elems = n_elems - THRESHOLD;
      const index_t num_rest_blocks_per_batch = ceil_div(n_rest_elems, n_threads);
      const index_t num_rest_blocks = num_rest_blocks_per_batch * num_batches;
      if (n_rest_elems > 0) {
        ReduceNMSResultRestKernel<<<num_rest_blocks, n_threads, 0, Stream<gpu>::GetStream(s)>>>(
            data->dptr_, scratch->dptr_, param.score_index, element_width,
            num_batches, num_elements_per_batch, current_start, topk,
            num_rest_blocks_per_batch);
      }
    }
  }
};

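// Area of a box given as corners (x0, y0, x1, y1) or as center (x, y, w, h),
// depending on the encoding; degenerate boxes have zero area.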
template <int encode, typename DType>
__device__ __forceinline__ DType calculate_area(const DType b0, const DType b1,
                                                const DType b2, const DType b3) {
  DType width = b2;
  DType height = b3;
  if (encode == box_common_enum::kCorner) {
    width -= b0;
    height -= b1;
  }
  if (width < 0 || height < 0) return 0;
  return width * height;
}

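// Intersection area of two boxes in the same encoding as calculate_area. For the
// center encoding the coordinates are kept scaled by 2 to avoid divisions; the
// resulting factor of 4 is divided out at the end.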
template <int encode, typename DType>
__device__ __forceinline__ DType calculate_intersection(const DType a0, const DType a1,
                                                        const DType a2, const DType a3,
                                                        const DType b0, const DType b1,
                                                        const DType b2, const DType b3) {
  DType wx, wy;
  if (encode == box_common_enum::kCorner) {
    const DType left = a0 > b0 ? a0 : b0;
    const DType bottom = a1 > b1 ? a1 : b1;
    const DType right = a2 < b2 ? a2 : b2;
    const DType top = a3 < b3 ? a3 : b3;
    wx = right - left;
    wy = top - bottom;
  } else {
    const DType al = 2 * a0 - a2;
    const DType ar = 2 * a0 + a2;
    const DType bl = 2 * b0 - b2;
    const DType br = 2 * b0 + b2;
    const DType left = bl > al ? bl : al;
    const DType right = br < ar ? br : ar;
    wx = right - left;
    const DType ab = 2 * a1 - a3;
    const DType at = 2 * a1 + a3;
    const DType bb = 2 * b1 - b3;
    const DType bt = 2 * b1 + b3;
    const DType bottom = bb > ab ? bb : ab;
    const DType top = bt < at ? bt : at;
    wy = top - bottom;
    wy = wy / 4;  // To compensate for both wx and wy being 2x too large
  }
  if (wx <= 0 || wy <= 0) {
    return 0;
  } else {
    return (wx * wy);
  }
}

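// For every box handled by this block, computes a 32-bit mask indicating which of the
// 32 "other" boxes (the warp-sized group starting at current_start + my_row * 32
// within the batch) overlap it with IoU above `threshold` while sharing its class
// (unless class checking is disabled via class_index == -1). The mask is stored
// inverted, so a set bit means "not suppressed by that box".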
template <int encode, typename DType>
__launch_bounds__(512)
__global__ void CalculateGreedyNMSResultsKernel(const DType* data, uint32_t* result,
                                                const index_t current_start,
                                                const index_t num_elems,
                                                const index_t num_batches,
                                                const index_t num_blocks_per_row_batch,
                                                const index_t num_blocks_per_row,
                                                const index_t topk,
                                                const index_t element_width,
                                                const index_t num_elements_per_batch,
                                                const int coord_index,
                                                const int class_index,
                                                const int score_index,
                                                const float threshold) {
  constexpr int max_elem_width = 20;
  constexpr int num_other_boxes = sizeof(uint32_t) * 8;
  __shared__ DType other_boxes[max_elem_width * num_other_boxes];
  __shared__ DType other_boxes_areas[num_other_boxes];
  const index_t my_row = blockIdx.x / num_blocks_per_row;
  const index_t my_block_offset_in_row = blockIdx.x % num_blocks_per_row;
  const index_t my_block_offset_in_batch = my_block_offset_in_row % num_blocks_per_row_batch;
  const index_t my_batch = (my_block_offset_in_row) / num_blocks_per_row_batch;
  const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x +
                                      current_start + threadIdx.x;

  // Load other boxes
  const index_t offset = (my_batch * num_elements_per_batch +
                         current_start + my_row * num_other_boxes) *
                         element_width;
  for (int i = threadIdx.x; i < element_width * num_other_boxes; i += blockDim.x) {
    other_boxes[i] = data[offset + i];
  }
  __syncthreads();

  if (threadIdx.x < num_other_boxes) {
    const int other_boxes_offset = element_width * threadIdx.x;
    const DType their_area = calculate_area<encode>(
        other_boxes[other_boxes_offset + coord_index + 0],
        other_boxes[other_boxes_offset + coord_index + 1],
        other_boxes[other_boxes_offset + coord_index + 2],
        other_boxes[other_boxes_offset + coord_index + 3]);
    other_boxes_areas[threadIdx.x] = their_area;
  }
  __syncthreads();

  if (my_element_in_batch >= topk) return;

  DType my_box[4];
  DType my_class = -1;
  DType my_score = -1;
  const index_t my_offset = (my_batch * num_elements_per_batch + my_element_in_batch) *
                            element_width;
  my_score = data[my_offset + score_index];
#pragma unroll
  for (int i = 0; i < 4; ++i) {
    my_box[i] = data[my_offset + coord_index + i];
  }
  if (class_index != -1) {
    my_class = data[my_offset + class_index];
  }
  DType my_area = calculate_area<encode>(my_box[0], my_box[1], my_box[2], my_box[3]);

  uint32_t ret = 0;
  if (my_score != -1) {
#pragma unroll
    for (int i = 0; i < num_other_boxes; ++i) {
      const int other_boxes_offset = element_width * i;
      if ((class_index == -1 || my_class == other_boxes[other_boxes_offset + class_index]) &&
          other_boxes[other_boxes_offset + score_index] != -1) {
        const DType their_area = other_boxes_areas[i];

        const DType intersect = calculate_intersection<encode>(
            my_box[0], my_box[1], my_box[2], my_box[3],
            other_boxes[other_boxes_offset + coord_index + 0],
            other_boxes[other_boxes_offset + coord_index + 1],
            other_boxes[other_boxes_offset + coord_index + 2],
            other_boxes[other_boxes_offset + coord_index + 3]);
        if (intersect > threshold * (my_area + their_area - intersect)) {
          ret = ret | (1u << i);
        }
      }
    }
  }
  result[(my_row * num_batches + my_batch) * topk + my_element_in_batch] = ~ret;
}

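// One block per batch, THRESHOLD threads. Performs the greedy reduction inside the
// current tile: walking the 32-box groups in order, a box stays valid only if no
// earlier valid box suppresses it. The ballot of surviving boxes is written back to
// nms_results and suppressed boxes get their score set to -1.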
template <typename DType>
__launch_bounds__(NMS<DType>::THRESHOLD)
__global__ void ReduceNMSResultTriangleKernel(uint32_t* nms_results,
                                              DType* data,
                                              const index_t score_index,
                                              const index_t element_width,
                                              const index_t num_batches,
                                              const index_t num_elements_per_batch,
                                              const index_t start_index,
                                              const index_t topk) {
  constexpr int n_threads = NMS<DType>::THRESHOLD;
  constexpr int warp_size = 32;
  const index_t my_batch = blockIdx.x;
  const index_t my_element_in_batch = threadIdx.x + start_index;
  const index_t my_element = my_batch * topk + my_element_in_batch;
  const int my_warp = threadIdx.x / warp_size;
  const int my_lane = threadIdx.x % warp_size;

  __shared__ uint32_t current_valid_boxes[n_threads / warp_size];
  const uint32_t full_mask = 0xFFFFFFFF;
  const uint32_t my_lane_mask = 1 << my_lane;
  const uint32_t earlier_threads_mask = (1 << (my_lane + 1)) - 1;
  uint32_t valid = my_lane_mask;
  uint32_t valid_boxes = full_mask;

  uint32_t my_next_mask = my_element_in_batch < topk ?
    nms_results[my_element]:
    full_mask;
#pragma unroll
  for (int i = 0; i < n_threads / warp_size; ++i) {
    uint32_t my_mask = my_next_mask;
    my_next_mask = (((i + 1) < n_threads / warp_size) &&
                    (my_element_in_batch < topk)) ?
      nms_results[(i + 1) * topk * num_batches + my_element]:
      full_mask;
    if (my_warp == i && !__all_sync(full_mask, my_mask == full_mask)) {
      my_mask = my_mask | earlier_threads_mask;
      // Loop over warp_size - 1 because the last
      // thread does not contribute to the mask anyway
#pragma unroll
      for (int j = 0; j < warp_size - 1; ++j) {
        const uint32_t mask = __shfl_sync(full_mask, valid ? my_mask : full_mask, j);
        valid = valid & mask;
      }
      valid_boxes = __ballot_sync(full_mask, valid);
    }
    if (my_lane == 0 && my_warp == i) {
      current_valid_boxes[i] = valid_boxes;
    }
    __syncthreads();
    if ((my_warp > i) && (((~my_mask) & current_valid_boxes[i]) != 0)) {
      valid = 0;
    }
  }
  if (my_lane == 0) {
    nms_results[my_element] = valid_boxes;
  }
  if (valid == 0) {
    data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width +
         score_index] = -1;
  }
}

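// Applies the outcome of the triangle reduction to the boxes that come after the
// current tile: any box suppressed by one of the tile's surviving boxes has its
// score set to -1.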
template <typename DType>
__launch_bounds__(512)
__global__ void ReduceNMSResultRestKernel(DType* data,
                                          const uint32_t* nms_results,
                                          const index_t score_index,
                                          const index_t element_width,
                                          const index_t num_batches,
                                          const index_t num_elements_per_batch,
                                          const index_t start_index,
                                          const index_t topk,
                                          const index_t num_blocks_per_batch) {
  constexpr int num_other_boxes = sizeof(uint32_t) * 8;
  constexpr int num_iterations = NMS<DType>::THRESHOLD / num_other_boxes;
  constexpr int warp_size = 32;
  const index_t my_block_offset_in_batch = blockIdx.x % num_blocks_per_batch;
  const index_t my_batch = blockIdx.x / num_blocks_per_batch;
  const index_t my_element_in_batch = my_block_offset_in_batch * blockDim.x +
                                      start_index + NMS<DType>::THRESHOLD + threadIdx.x;
  const index_t my_element = my_batch * topk + my_element_in_batch;

  if (my_element_in_batch >= topk) return;

  bool valid = true;

#pragma unroll
  for (int i = 0; i < num_iterations; ++i) {
    const uint32_t my_mask = nms_results[i * topk * num_batches + my_element];
    const uint32_t valid_boxes = nms_results[my_batch * topk + i * warp_size + start_index];

    const bool no_hit = (valid_boxes & (~my_mask)) == 0;
    valid = valid && no_hit;
  }

  if (!valid) {
    data[(my_batch * num_elements_per_batch + my_element_in_batch) * element_width +
          score_index] = -1;
  }
}

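// Computes the sizes of all temporary regions, requests a single chunk of temporary
// GPU memory, and slices it into the aligned regions described by TempWorkspace.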
template <typename DType>
TempWorkspace<DType> GetWorkspace(const index_t num_batch,
                                  const index_t num_elem,
                                  const int width_elem,
                                  const index_t topk,
                                  const OpContext& ctx) {
  TempWorkspace<DType> workspace;
  Stream<gpu> *s = ctx.get_stream<gpu>();
  const int alignment = 128;

  // Get the workspace size
  workspace.scores_temp_space = 2 * align(num_batch * num_elem * sizeof(DType), alignment);
  workspace.indices_temp_spaces = 2 * align(num_batch * num_elem * sizeof(index_t), alignment);
  WorkspaceForSort(num_elem, topk, alignment, &workspace);
  // Place for a buffer
  workspace.buffer_space = align(num_batch * num_elem * width_elem * sizeof(DType), alignment);
  workspace.nms_scratch_space = align(NMS<DType>::THRESHOLD / (sizeof(uint32_t) * 8) *
                                      num_batch * topk * sizeof(uint32_t), alignment);

  const index_t workspace_size = workspace.scores_temp_space +
                                 workspace.scratch_space +
                                 workspace.buffer_space +
                                 workspace.nms_scratch_space +
                                 workspace.indices_temp_spaces;

  // Obtain the memory for workspace
  Tensor<gpu, 1, uint8_t> scratch_memory = ctx.requested[box_nms_enum::kTempSpace]
    .get_space_typed<gpu, 1, uint8_t>(mshadow::Shape1(workspace_size), s);

  // Populate workspace pointers
  workspace.scores = reinterpret_cast<DType*>(scratch_memory.dptr_);
  workspace.scratch = reinterpret_cast<uint8_t*>(workspace.scores) +
                                                 workspace.scores_temp_space;
  workspace.buffer = reinterpret_cast<DType*>(workspace.scratch +
                                              workspace.scratch_space);
  workspace.nms_scratch = reinterpret_cast<uint32_t*>(
                            reinterpret_cast<uint8_t*>(workspace.buffer) +
                            workspace.buffer_space);
  workspace.indices = reinterpret_cast<index_t*>(
                            reinterpret_cast<uint8_t*>(workspace.nms_scratch) +
                            workspace.nms_scratch_space);
  return workspace;
}

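// Copies the score of every box into a contiguous array for sorting.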
template <typename DType>
__global__ void ExtractScoresKernel(const DType* data, DType* scores,
                                    const index_t N, const int element_width,
                                    const int score_index) {
  const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < N) {
    scores[tid] = data[tid * element_width + score_index];
  }
}

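// After NMS: re-extracts the (possibly invalidated) scores, sorts each batch by score
// in descending order, and gathers the results into `out` so that surviving boxes come
// first and suppressed entries are filled with -1.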
template <typename DType>
void CompactNMSResults(const Tensor<gpu, 3, DType>& data,
                       Tensor<gpu, 3, DType>* out,
                       Tensor<gpu, 1, index_t>* indices,
                       Tensor<gpu, 1, DType>* scores,
                       Tensor<gpu, 1, index_t>* sorted_indices,
                       Tensor<gpu, 1, DType>* sorted_scores,
                       Tensor<gpu, 1, char>* scratch,
                       const int score_index,
                       const index_t topk,
                       Stream<gpu>* s) {
  using mshadow::Shape1;
  constexpr int n_threads = 512;
  const index_t num_elements = scores->shape_.Size();
  const index_t num_elements_per_batch = data.shape_[1];
  const index_t num_batches = data.shape_[0];
  const int element_width = data.shape_[2];
  const index_t n_blocks = ceil_div(num_elements, n_threads);
  ExtractScoresKernel<<<n_blocks, n_threads, 0, Stream<gpu>::GetStream(s)>>>(
      data.dptr_, scores->dptr_, num_elements, element_width, score_index);
  *indices = mshadow::expr::range<index_t>(0, num_elements);
  for (index_t i = 0; i < num_batches; ++i) {
    // Sort each batch separately
    Tensor<gpu, 1, DType> scores_batch(scores->dptr_ + i * num_elements_per_batch,
                                       Shape1(topk),
                                       s);
    Tensor<gpu, 1, index_t> indices_batch(indices->dptr_ + i * num_elements_per_batch,
                                          Shape1(topk),
                                          s);
    Tensor<gpu, 1, DType> sorted_scores_batch(sorted_scores->dptr_ + i * num_elements_per_batch,
                                              Shape1(topk),
                                              s);
    Tensor<gpu, 1, index_t> sorted_indices_batch(sorted_indices->dptr_ + i * num_elements_per_batch,
                                                 Shape1(topk),
                                                 s);
    mxnet::op::SortByKey(scores_batch, indices_batch, false, scratch,
                         0, 8 * sizeof(DType), &sorted_scores_batch,
                         &sorted_indices_batch);
  }
  CompactData<true>(*sorted_indices, data, out, topk, score_index, s);
}

}  // namespace

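// Forward pass used when the second output is not requested: filter by score and
// background id, sort by score, run the tiled greedy NMS on the topk candidates,
// compact the survivors, and finally convert the box encoding if needed.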
void BoxNMSForwardGPU_notemp(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using mshadow::Shape1;
  using mshadow::Shape2;
  using mshadow::Shape3;
  CHECK_NE(req[0], kAddTo) << "BoxNMS does not support kAddTo";
  CHECK_NE(req[0], kWriteInplace) << "BoxNMS does not support in place computation";
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]";
  const BoxNMSParam& param = nnvm::get<BoxNMSParam>(attrs.parsed);
  Stream<gpu> *s = ctx.get_stream<gpu>();
  mxnet::TShape in_shape = inputs[box_nms_enum::kData].shape_;
  int indim = in_shape.ndim();
  int num_batch = indim <= 2 ? 1 : in_shape.ProdShape(0, indim - 2);
  int num_elem = in_shape[indim - 2];
  int width_elem = in_shape[indim - 1];

  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    Tensor<gpu, 3, DType> data = inputs[box_nms_enum::kData]
     .get_with_shape<gpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);
    Tensor<gpu, 3, DType> out = outputs[box_nms_enum::kOut]
     .get_with_shape<gpu, 3, DType>(Shape3(num_batch, num_elem, width_elem), s);

    // Special case for topk == 0
    if (param.topk == 0) {
      if (req[0] != kNullOp &&
          req[0] != kWriteInplace) {
        out = mshadow::expr::F<mshadow_op::identity>(data);
      }
      return;
    }

    index_t topk = param.topk > 0 ? std::min(param.topk, num_elem) : num_elem;
    const auto& workspace = GetWorkspace<DType>(num_batch, num_elem,
                                                width_elem, topk, ctx);

    FilterAndPrepareAuxData(data, &out, workspace, param, s);
    Tensor<gpu, 1, DType> scores(workspace.scores, Shape1(num_batch * num_elem), s);
    Tensor<gpu, 1, DType> sorted_scores(workspace.scores + scores.MSize(),
                                        Shape1(num_batch * num_elem), s);
    Tensor<gpu, 1, index_t> indices(workspace.indices, Shape1(num_batch * num_elem), s);
    Tensor<gpu, 1, index_t> sorted_indices(workspace.indices + indices.MSize(),
                                           Shape1(num_batch * num_elem), s);
    Tensor<gpu, 1, char> scratch(reinterpret_cast<char*>(workspace.scratch),
                                 Shape1(workspace.scratch_space), s);
    Tensor<gpu, 3, DType> buffer(workspace.buffer,
                                 Shape3(num_batch, num_elem, width_elem), s);
    Tensor<gpu, 2, uint32_t> nms_scratch(workspace.nms_scratch,
                                         Shape2(NMS<DType>::THRESHOLD / (sizeof(uint32_t) * 8),
                                                topk * num_batch),
                                         s);
    indices = mshadow::expr::range<index_t>(0, num_batch * num_elem);
    for (index_t i = 0; i < num_batch; ++i) {
      // Sort each batch separately
      Tensor<gpu, 1, DType> scores_batch(scores.dptr_ + i * num_elem,
                                         Shape1(num_elem),
                                         s);
      Tensor<gpu, 1, index_t> indices_batch(indices.dptr_ + i * num_elem,
                                            Shape1(num_elem),
                                            s);
      Tensor<gpu, 1, DType> sorted_scores_batch(sorted_scores.dptr_ + i * num_elem,
                                                Shape1(num_elem),
                                                s);
      Tensor<gpu, 1, index_t> sorted_indices_batch(sorted_indices.dptr_ + i * num_elem,
                                                   Shape1(num_elem),
                                                   s);
      mxnet::op::SortByKey(scores_batch, indices_batch, false, &scratch, 0,
                           8 * sizeof(DType), &sorted_scores_batch,
                           &sorted_indices_batch);
    }
    CompactData<false>(sorted_indices, out, &buffer, topk, -1, s);
    NMS<DType> nms;
    nms(&buffer, &nms_scratch, topk, param, s);
    CompactNMSResults(buffer, &out, &indices, &scores, &sorted_indices,
                      &sorted_scores, &scratch, param.score_index, topk, s);

    // convert encoding
    if (param.in_format != param.out_format) {
      if (box_common_enum::kCenter == param.out_format) {
        mxnet::op::mxnet_op::Kernel<corner_to_center, gpu>::Launch(s, num_batch * num_elem,
          out.dptr_ + param.coord_start, width_elem);
      } else {
        mxnet::op::mxnet_op::Kernel<center_to_corner, gpu>::Launch(s, num_batch * num_elem,
          out.dptr_ + param.coord_start, width_elem);
      }
    }
  });
}

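// Uses the specialized GPU path when the second output ("temp") is not requested
// (req[1] == kNullOp) and falls back to the generic BoxNMSForward otherwise.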
void BoxNMSForwardGPU(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 2U) << "BoxNMS output: [output, temp]";
  if (req[1] == kNullOp) {
    BoxNMSForwardGPU_notemp(attrs, ctx, inputs, req, outputs);
    return;
  }
  BoxNMSForward<gpu>(attrs, ctx, inputs, req, outputs);
}


NNVM_REGISTER_OP(_contrib_box_nms)
.set_attr<FCompute>("FCompute<gpu>", BoxNMSForwardGPU);

NNVM_REGISTER_OP(_backward_contrib_box_nms)
.set_attr<FCompute>("FCompute<gpu>", BoxNMSBackward<gpu>);

NNVM_REGISTER_OP(_contrib_box_iou)
.set_attr<FCompute>("FCompute<gpu>", BoxOverlapForward<gpu>);

NNVM_REGISTER_OP(_backward_contrib_box_iou)
.set_attr<FCompute>("FCompute<gpu>", BoxOverlapBackward<gpu>);

NNVM_REGISTER_OP(_contrib_bipartite_matching)
.set_attr<FCompute>("FCompute<gpu>", BipartiteMatchingForward<gpu>);

NNVM_REGISTER_OP(_backward_contrib_bipartite_matching)
.set_attr<FCompute>("FCompute<gpu>", BipartiteMatchingBackward<gpu>);

NNVM_REGISTER_OP(_contrib_box_encode)
.set_attr<FCompute>("FCompute<gpu>", BoxEncodeForward<gpu>);

NNVM_REGISTER_OP(_contrib_box_decode)
.set_attr<FCompute>("FCompute<gpu>", BoxDecodeForward<gpu>);

}  // namespace op
}  // namespace mxnet