1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
18"""Non-maximum suppression operator"""
19import math
20import tvm
21
22from tvm import api
23from tvm.generic import cast
24from tvm.intrin import if_then_else, log, power
25from topi.vision import non_max_suppression, get_valid_counts
26from .sort import argsort
27from .. import tag
28
29
30def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
31    """Low level IR to Prepare get valid count of bounding boxes
32    given a score threshold. Also moves valid boxes to the
33    top of input data.
34
35    Parameters
36    ----------
37    data: Buffer
38        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
39
40    flag : Buffer
41        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
42
43    idx : Buffer
44        2D Buffer of valid data indices with shape [batch_size, num_anchors].
45
46    score_threshold : float32
47        Lower limit of score for valid bounding boxes.
48
49    id_index : optional, int
50        index of the class categories, -1 to disable.
51
52    score_index: optional, int
53        Index of the scores/confidence of boxes.
54
55    Returns
56    -------
57    stmt : Stmt
58        The result IR statement.
59    """
60    batch_size = data.shape[0]
61    num_anchors = data.shape[1]
62    box_data_length = data.shape[2]
63
64    ib = tvm.ir_builder.create()
65
66    data = ib.buffer_ptr(data)
67    flag = ib.buffer_ptr(flag)
68    idx = ib.buffer_ptr(idx)
69    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
70    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
71    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
72
73    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
74    nthread_tx = max_threads
75    nthread_bx = batch_size * num_anchors // max_threads + 1
76    tx = tvm.thread_axis("threadIdx.x")
77    bx = tvm.thread_axis("blockIdx.x")
78    ib.scope_attr(tx, "thread_extent", nthread_tx)
79    ib.scope_attr(bx, "thread_extent", nthread_bx)
80    tid = bx * max_threads + tx
81
82    with ib.if_scope(tid < batch_size * num_anchors):
83        with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
84            tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
85            flag[tid] = 1
86            idx[tid] = 1
87        with ib.else_scope():
88            flag[tid] = 0
89            idx[tid] = 0
90
91    return ib.get()
92
93def get_valid_counts_upsweep(data, idx_in, idx, partial):
94    """Low level IR of first step of scan: unsweep.
95
96    Parameters
97    ----------
98    data: Buffer
99        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
100
101    idx_in : Buffer
102        2D Buffer of valid data indices with shape [batch_size, num_anchors].
103
104    idx : Buffer
105        2D Buffer of valid data indices with shape [batch_size, num_anchors].
106
107    partial : Buffer
108        2D Buffer of valid data indices with shape [batch_size, new_range].
109
110    Returns
111    -------
112    stmt : Stmt
113        The result IR statement.
114    """
115    batch_size = data.shape[0]
116    num_anchors = data.shape[1]
117    ib = tvm.ir_builder.create()
118    data = ib.buffer_ptr(data)
119    idx_in = ib.buffer_ptr(idx_in)
120    idx = ib.buffer_ptr(idx)
121    partial = ib.buffer_ptr(partial)
122    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
123    elem_per_thread = num_anchors // max_threads + 1
124    nthread_tx = max_threads
125    nthread_bx = batch_size
126    tx = tvm.thread_axis("threadIdx.x")
127    bx = tvm.thread_axis("blockIdx.x")
128    ib.scope_attr(tx, "thread_extent", nthread_tx)
129    ib.scope_attr(bx, "thread_extent", nthread_bx)
130    new_range = num_anchors // elem_per_thread + 1
131    # Scan: Upsweep:
132    with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
133        with ib.for_range(0, elem_per_thread) as i:
134            with ib.if_scope(bx * num_anchors + \
135                             tx * elem_per_thread + i < batch_size * num_anchors):
136                with ib.if_scope(i == 0):
137                    partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
138                    idx[bx * num_anchors + tx * elem_per_thread] = \
139                    idx_in[bx * num_anchors + tx * elem_per_thread]
140                with ib.else_scope():
141                    partial[bx * new_range + tx] += \
142                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
143                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
144                    idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
145                    idx_in[bx * num_anchors + tx * elem_per_thread + i]
146            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
147                                  tvm.convert(['shared']),
148                                  tvm.expr.Call.Intrinsic, None, 0))
149    return ib.get()
150
151def get_valid_counts_scan(data, partial_in, partial):
152    """Low level IR to do scan.
153
154    Parameters
155    ----------
156    data: Buffer
157        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
158
159    idx_in : Buffer
160        2D Buffer of valid data indices with shape [batch_size, num_anchors].
161
162    idx : Buffer
163        2D Buffer of valid data indices with shape [batch_size, num_anchors].
164
165    partial : Buffer
166        2D Buffer of valid data indices with shape [batch_size, new_range].
167
168    Returns
169    -------
170    stmt : Stmt
171        The result IR statement.
172    """
173    batch_size = data.shape[0]
174    num_anchors = data.shape[1]
175    ib = tvm.ir_builder.create()
176    partial_in = ib.buffer_ptr(partial_in)
177    partial = ib.buffer_ptr(partial)
178    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
179    elem_per_thread = num_anchors // max_threads + 1
180    nthread_tx = max_threads
181    nthread_bx = batch_size
182    tx = tvm.thread_axis("threadIdx.x")
183    bx = tvm.thread_axis("blockIdx.x")
184    ib.scope_attr(tx, "thread_extent", nthread_tx)
185    ib.scope_attr(bx, "thread_extent", nthread_bx)
186    var = tvm.make.node("FloatImm", dtype="float32", value=2)
187    new_range = num_anchors // elem_per_thread + 1
188    iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
189    # Scan: Kogge-Stone adder
190    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
191        with ib.for_range(0, iteration) as k:
192            with ib.if_scope(k == 0):
193                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
194                    partial[bx * new_range + tx] = \
195                    partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
196                with ib.else_scope():
197                    partial[bx * new_range] = partial_in[bx * new_range]
198            with ib.else_scope():
199                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
200                                         tx < tvm.min(new_range, num_anchors))):
201                    partial[bx * new_range + tx] += \
202                    partial[bx * new_range + tx - cast(power(var, k), "int32")]
203            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
204                                  tvm.convert(['shared']),
205                                  tvm.expr.Call.Intrinsic, None, 0))
206    return ib.get()
207
208def get_valid_counts_downsweep(data, idx_in, partial, idx):
209    """Low level IR to do downsweep of scan.
210
211    Parameters
212    ----------
213    data: Buffer
214        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
215
216    idx_in : Buffer
217        2D Buffer of valid data indices with shape [batch_size, num_anchors].
218
219    partial : Buffer
220        2D Buffer of valid data indices with shape [batch_size, new_range].
221
222    idx : Buffer
223        2D Buffer of valid data indices with shape [batch_size, num_anchors].
224
225    Returns
226    -------
227    stmt : Stmt
228        The result IR statement.
229    """
230    batch_size = data.shape[0]
231    num_anchors = data.shape[1]
232    ib = tvm.ir_builder.create()
233    idx_in = ib.buffer_ptr(idx_in)
234    idx = ib.buffer_ptr(idx)
235    partial = ib.buffer_ptr(partial)
236    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
237    elem_per_thread = num_anchors // max_threads + 1
238    nthread_tx = max_threads
239    nthread_bx = batch_size * num_anchors // max_threads + 1
240    tx = tvm.thread_axis("threadIdx.x")
241    bx = tvm.thread_axis("blockIdx.x")
242    ib.scope_attr(tx, "thread_extent", nthread_tx)
243    ib.scope_attr(bx, "thread_extent", nthread_bx)
244    tid = bx * max_threads + tx
245    new_range = num_anchors // elem_per_thread + 1
246    idxd = tvm.indexdiv
247    idxm = tvm.indexmod
248    # Scan: Downsweep:
249    with ib. if_scope(tid < batch_size * num_anchors):
250        i = idxd(tid, num_anchors) # number of batches
251        j = idxm(tid, num_anchors) # number of anchors
252        with ib.if_scope(j < elem_per_thread):
253            idx[tid] = idx_in[tid]
254        with ib.else_scope():
255            idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
256
257    return ib.get()
258
259def get_valid_counts_ir(data, flag, idx, valid_count, out):
260    """Low level IR to get valid count of bounding boxes
261    given a score threshold. Also moves valid boxes to the
262    top of input data.
263
264    Parameters
265    ----------
266    data : Buffer
267        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].
268
269    flag : Buffer
270        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
271
272    idx : Buffer
273        2D Buffer of valid data indices with shape [batch_size, num_anchors].
274
275    valid_count : Buffer
276        1-D buffer for valid number of boxes.
277
278    out : Buffer
279        Rearranged data buffer.
280
281    Returns
282    -------
283    stmt : Stmt
284        The result IR statement.
285    """
286    batch_size = data.shape[0]
287    num_anchors = data.shape[1]
288    elem_length = data.shape[2]
289    size = batch_size * num_anchors * elem_length
290
291    ib = tvm.ir_builder.create()
292
293    data = ib.buffer_ptr(data)
294    flag = ib.buffer_ptr(flag)
295    idx = ib.buffer_ptr(idx)
296    valid_count = ib.buffer_ptr(valid_count)
297    out = ib.buffer_ptr(out)
298
299    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
300    nthread_tx = max_threads
301    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
302    tx = tvm.thread_axis("threadIdx.x")
303    bx = tvm.thread_axis("blockIdx.x")
304    ib.scope_attr(tx, "thread_extent", nthread_tx)
305    ib.scope_attr(bx, "thread_extent", nthread_bx)
306    tid = bx * max_threads + tx
307
308    idxd = tvm.indexdiv
309    idxm = tvm.indexmod
310
311    with ib.if_scope(tid < batch_size * num_anchors):
312        i = idxd(tid, num_anchors)
313        j = idxm(tid, num_anchors)
314        base_idx = i * num_anchors * elem_length
315        with ib.if_scope(flag[tid] > 0):
316            with ib.for_range(0, elem_length) as k:
317                with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
318                    out[base_idx + (idx[tid] - 1) * elem_length + k] =\
319                    data[base_idx + j * elem_length + k]
320        with ib.if_scope(j == 0):
321            valid_count[i] = idx[tid + num_anchors - 1]
322        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
323            with ib.for_range(0, elem_length) as l:
324                with ib.if_scope(tid * elem_length + l < size):
325                    out[tid * elem_length + l] = -1.0
326    return ib.get()
327
328
329@get_valid_counts.register(["cuda", "gpu"])
330def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
331    """Get valid count of bounding boxes given a score threshold.
332    Also moves valid boxes to the top of input data.
333
334    Parameters
335    ----------
336    data : tvm.Tensor
337        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].
338
339    score_threshold : optional, float
340        Lower limit of score for valid bounding boxes.
341
342    id_index : optional, int
343        index of the class categories, -1 to disable.
344
345    score_index: optional, int
346        Index of the scores/confidence of boxes.
347
348    Returns
349    -------
350    valid_count : tvm.Tensor
351        1-D tensor for valid number of boxes.
352
353    out_tensor : tvm.Tensor
354        Rearranged data tensor.
355    """
356    batch_size = data.shape[0]
357    num_anchors = data.shape[1]
358    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
359    elem_per_thread = num_anchors // max_threads + 1
360    new_range = num_anchors // elem_per_thread + 1
361    temp_flag_buf = api.decl_buffer(
362        (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
363    temp_idx_buf = api.decl_buffer(
364        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
365    temp_partial_buf = api.decl_buffer(
366        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
367    data_buf = api.decl_buffer(
368        data.shape, data.dtype, "data_buf", data_alignment=8)
369
370    temp_flag, temp_idx = \
371        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
372                   lambda ins, outs: get_valid_counts_pre(
373                       ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
374                   dtype=["int32", "int32"],
375                   out_buffers=[temp_flag_buf, temp_idx_buf],
376                   name="get_valid_counts_phase_one")
377    temp_idx_new, temp_partial = \
378        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
379                   lambda ins, outs: get_valid_counts_upsweep(
380                       ins[0], ins[1], outs[0], outs[1]),
381                   dtype=["int32", "int32"],
382                   out_buffers=[temp_idx_buf, temp_partial_buf],
383                   name="get_valid_counts_phase_two")
384    temp_partial_new = \
385        tvm.extern([(batch_size, new_range)], [data, temp_partial],
386                   lambda ins, outs: get_valid_counts_scan(
387                       ins[0], ins[1], outs[0]),
388                   dtype=["int32"],
389                   out_buffers=[temp_partial_buf],
390                   name="get_valid_counts_phase_three")
391    temp_idx_final = \
392        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
393                   lambda ins, outs: get_valid_counts_downsweep(
394                       ins[0], ins[1], ins[2], outs[0]),
395                   dtype=["int32"],
396                   out_buffers=[temp_idx_buf],
397                   name="get_valid_counts_phase_four")
398    valid_count, out_tensor = \
399	tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
400            lambda ins, outs: get_valid_counts_ir(
401                ins[0], ins[1], ins[2], outs[0], outs[1]),
402            dtype=["int32", data.dtype],
403            in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
404            name="get_valid_counts_phase_five",
405            tag="get_valid_counts_gpu")
406
407    return [valid_count, out_tensor]
408
409
410def nms_ir(data, sorted_index, valid_count, out, box_indices,
411           max_output_size, iou_threshold, force_suppress,
412           top_k, coord_start, id_index, score_index):
413    """Low level IR routing for transform location in multibox_detection operator.
414
415    Parameters
416    ----------
417    data : Buffer
418        Buffer of output boxes with class and score.
419
420    sort_index : Buffer
421        Buffer of output box indexes sorted by score.
422
423    valid_count : Buffer
424        Buffer of number of valid output boxes.
425
426    out : Buffer
427        Output buffer.
428
429    max_output_size : int
430        Max number of output valid boxes for each instance.
431        By default all valid boxes are returned.
432
433    iou_threshold : float
434        Overlapping(IoU) threshold to suppress object with smaller score.
435
436    force_suppress : boolean
437        Whether to suppress all detections regardless of class_id.
438
439    top_k : int
440        Keep maximum top k detections before nms, -1 for no limit.
441
442    coord_start : int
443        Start index of the consecutive 4 coordinates.
444
445    id_index : int
446        index of the class categories, -1 to disable.
447
448    score_index : optional, int
449        Index of the scores/confidence of boxes.
450
451    Returns
452    -------
453    stmt : Stmt
454        The result IR statement.
455    """
456    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
457        """Calculate overlap of two boxes.
458        """
459        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
460                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
461        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
462                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
463        i = w * h
464        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
465            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
466            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
467            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
468        return tvm.expr.Select(u <= 0.0, 0.0, i / u)
469
470    batch_size = data.shape[0]
471    num_anchors = data.shape[1]
472    box_data_length = data.shape[2]
473
474    ib = tvm.ir_builder.create()
475
476    data = ib.buffer_ptr(data)
477    sorted_index = ib.buffer_ptr(sorted_index)
478    valid_count = ib.buffer_ptr(valid_count)
479    out = ib.buffer_ptr(out)
480    box_indices = ib.buffer_ptr(box_indices)
481    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")
482
483    max_threads = int(
484        tvm.target.current_target(allow_none=False).max_num_threads)
485    nthread_tx = max_threads
486    nthread_bx = num_anchors // max_threads + 1
487    tx = tvm.thread_axis("threadIdx.x")
488    bx = tvm.thread_axis("blockIdx.x")
489    ib.scope_attr(tx, "thread_extent", nthread_tx)
490    ib.scope_attr(bx, "thread_extent", nthread_bx)
491    j = bx * max_threads + tx
492
493    iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
494    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
495    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
496    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
497    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
498    force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)
499
500    with ib.for_range(0, batch_size, for_type="unroll") as i:
501        base_idx = i * num_anchors * box_data_length
502        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
503            # Reorder output
504            nkeep = if_then_else( \
505                    tvm.all(top_k > 0, top_k < valid_count[i]),
506                    top_k, valid_count[i])
507            with ib.if_scope(j < nkeep):
508                with ib.for_range(0, box_data_length) as k:
509                    out[(base_idx + j * box_data_length + k)] = \
510                    data[(base_idx + sorted_index[i * num_anchors + j] \
511                    * box_data_length + k)]
512                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
513            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
514                with ib.if_scope(j < valid_count[i] - nkeep):
515                    with ib.for_range(0, box_data_length) as k:
516                        out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
517                    box_indices[i * num_anchors + (j + nkeep)] = -1
518            # Apply nms
519            with ib.for_range(0, valid_count[i]) as k:
520                offset_k = k * box_data_length
521                with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
522                    tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
523                    with ib.if_scope(j < valid_count[i]):
524                        offset_j = j * box_data_length
525                        with ib.if_scope(tvm.all(j > k, \
526                            out[base_idx + offset_j + score_index] > 0, \
527                                                 tvm.any(id_index < 0, \
528                                                    out[base_idx + offset_j + id_index] >= 0), \
529						 tvm.any(force_suppress > 0, id_index < 0, \
530                                                         out[base_idx + offset_k + id_index] == \
531                                                         out[base_idx + offset_j + id_index]))):
532                            iou = calculate_overlap(out, base_idx + offset_j + coord_start,
533                                                    base_idx + offset_k + coord_start)
534                            with ib.if_scope(iou >= iou_threshold):
535                                out[base_idx + offset_j + score_index] = -1.0
536                                with ib.if_scope(id_index >= 0):
537                                    out[base_idx + offset_j + id_index] = -1.0
538                                box_indices[i * num_anchors + j] = -1
539        with ib.else_scope():
540            with ib.if_scope(j < valid_count[i]):
541                offset_j = j * box_data_length
542                with ib.for_range(0, box_data_length) as k:
543                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
544                box_indices[i * num_anchors + j] = j
545        # Set invalid entry to be -1
546        with ib.if_scope(j < num_anchors - valid_count[i]):
547            with ib.for_range(0, box_data_length) as k:
548                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
549            box_indices[i * num_anchors + j + valid_count[i]] = -1
550        # Only return max_output_size number of valid boxes
551        num_valid_boxes[0] = 0
552        with ib.if_scope(max_output_size > 0):
553            with ib.if_scope(j < valid_count[i]):
554                offset_j = j * box_data_length
555                with ib.if_scope(out[base_idx + offset_j] >= 0):
556                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
557                        with ib.for_range(0, box_data_length) as k:
558                            out[base_idx + offset_j + k] = -1.0
559                        box_indices[i * num_anchors + j] = -1
560                    with ib.else_scope():
561                        num_valid_boxes[0] += 1
562
563    return ib.get()
564
565
566def invalid_to_bottom_pre(data, flag, idx):
567    """Low level IR to rearrange nms output to move all valid entries to top.
568
569    Parameters
570    ----------
571    data: Buffer
572        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
573
574    flag : Buffer
575        1D Buffer of flag indicating valid data with [num_anchors].
576
577    idx : Buffer
578        1D Buffer of valid data indices with [num_anchors].
579
580    Returns
581    -------
582    stmt : Stmt
583        The result IR statement.
584    """
585    batch_size = data.shape[0]
586    num_anchors = data.shape[1]
587    elem_length = data.shape[2]
588
589    ib = tvm.ir_builder.create()
590
591    data = ib.buffer_ptr(data)
592    flag = ib.buffer_ptr(flag)
593    idx = ib.buffer_ptr(idx)
594
595    max_threads = int(math.sqrt(
596        tvm.target.current_target(allow_none=False).max_num_threads))
597    nthread_tx = max_threads
598    nthread_bx = num_anchors // max_threads + 1
599    tx = tvm.thread_axis("threadIdx.x")
600    bx = tvm.thread_axis("blockIdx.x")
601    ib.scope_attr(tx, "thread_extent", nthread_tx)
602    ib.scope_attr(bx, "thread_extent", nthread_bx)
603    j = bx * max_threads + tx
604
605    with ib.for_range(0, batch_size, for_type="unroll") as i:
606        base_idx = i * num_anchors * elem_length
607        with ib.if_scope(j < num_anchors):
608            with ib.if_scope(data[base_idx + j * elem_length] >= 0):
609                flag[i * num_anchors + j] = 1
610                idx[i * num_anchors + j] = 1
611            with ib.else_scope():
612                flag[i * num_anchors + j] = 0
613                idx[i * num_anchors + j] = 0
614
615    with ib.if_scope(j < batch_size):
616        with ib.for_range(0, num_anchors) as k:
617            with ib.if_scope(k > 0):
618                idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
619    return ib.get()
620
621
622def invalid_to_bottom_ir(data, flag, idx, out):
623    """Low level IR to rearrange nms output to move all valid entries to top.
624
625    Parameters
626    ----------
627    data: Buffer
628        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.
629
630    flag : Buffer
631        1D Buffer of flag indicating valid data with [num_anchors].
632
633    idx : Buffer
634        1D Buffer of valid data indices with [num_anchors].
635
636    out : Buffer
637        3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length].
638
639    Returns
640    -------
641    stmt : Stmt
642        The result IR statement.
643    """
644    batch_size = data.shape[0]
645    num_anchors = data.shape[1]
646    elem_length = data.shape[2]
647
648    ib = tvm.ir_builder.create()
649
650    data = ib.buffer_ptr(data)
651    flag = ib.buffer_ptr(flag)
652    idx = ib.buffer_ptr(idx)
653    out = ib.buffer_ptr(out)
654
655    max_threads = int(math.sqrt(
656        tvm.target.current_target(allow_none=False).max_num_threads))
657    nthread_tx = max_threads
658    nthread_bx = num_anchors // max_threads + 1
659    tx = tvm.thread_axis("threadIdx.x")
660    bx = tvm.thread_axis("blockIdx.x")
661    ib.scope_attr(tx, "thread_extent", nthread_tx)
662    ib.scope_attr(bx, "thread_extent", nthread_bx)
663    j = bx * max_threads + tx
664
665    with ib.for_range(0, batch_size, for_type="unroll") as i:
666        base_idx = i * num_anchors * elem_length
667        with ib.if_scope(j < num_anchors):
668            with ib.for_range(0, elem_length) as k:
669                out[base_idx + j * elem_length + k] = -1.0
670            with ib.if_scope(flag[i * num_anchors + j] > 0):
671                with ib.for_range(0, elem_length) as k:
672                    out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
673                    = data[base_idx + j * elem_length + k]
674    return ib.get()
675
676
677@non_max_suppression.register(["cuda", "gpu"])
678def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
679                            iou_threshold=0.5, force_suppress=False, top_k=-1,
680                            coord_start=2, score_index=1, id_index=0,
681                            return_indices=True, invalid_to_bottom=False):
682    """Non-maximum suppression operator for object detection.
683
684    Parameters
685    ----------
686    data : tvm.Tensor
687        3-D tensor with shape [batch_size, num_anchors, elem_length].
688        The last dimension should be in format of
689        [class_id, score, box_left, box_top, box_right, box_bottom].
690
691    valid_count : tvm.Tensor
692        1-D tensor for valid number of boxes.
693
694    max_output_size : optional, int
695        Max number of output valid boxes for each instance.
696        By default all valid boxes are returned.
697
698    iou_threshold : optional, float
699        Non-maximum suppression threshold.
700
701    force_suppress : optional, boolean
702        Whether to suppress all detections regardless of class_id.
703
704    top_k : optional, int
705        Keep maximum top k detections before nms, -1 for no limit.
706
707    coord_start : required, int
708        Start index of the consecutive 4 coordinates.
709
710    score_index : optional, int
711        Index of the scores/confidence of boxes.
712
713    id_index : optional, int
714        index of the class categories, -1 to disable.
715
716    return_indices : boolean
717        Whether to return box indices in input data.
718
719    invalid_to_bottom : optional, boolean
720        Whether to move all valid bounding boxes to the top.
721
722    Returns
723    -------
724    out : tvm.Tensor
725        3-D tensor with shape [batch_size, num_anchors, elem_length].
726
727    Example
728    --------
729    .. code-block:: python
730
731        # An example to use nms
732        dshape = (1, 5, 6)
733        data = tvm.placeholder(dshape, name="data")
734        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
735        iou_threshold = 0.7
736        force_suppress = True
737        top_k = -1
738        out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold,
739                                 force_suppress=force_supress, top_k=top_k, return_indices=False)
740        np_data = np.random.uniform(dshape)
741        np_valid_count = np.array([4])
742        s = topi.generic.schedule_nms(out)
743        f = tvm.build(s, [data, valid_count, out], "cuda")
744        ctx = tvm.gpu(0)
745        tvm_data = tvm.nd.array(np_data, ctx)
746        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
747        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
748        f(tvm_data, tvm_valid_count, tvm_out)
749    """
750    batch_size = data.shape[0]
751    num_anchors = data.shape[1]
752
753    valid_count_dtype = "int32"
754    valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
755                                      "valid_count_buf", data_alignment=4)
756    score_axis = score_index
757    score_shape = (batch_size, num_anchors)
758    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
759    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
760
761    sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
762                                      "sort_tensor_buf", data_alignment=8)
763
764    data_buf = api.decl_buffer(
765        data.shape, data.dtype, "data_buf", data_alignment=8)
766
767    out_buf = api.decl_buffer(
768        data.shape, data.dtype, "out_buf", data_alignment=8)
769
770    out, box_indices = \
771        tvm.extern([data.shape, score_shape],
772                   [data, sort_tensor, valid_count],
773                   lambda ins, outs: nms_ir(
774                       ins[0], ins[1], ins[2], outs[0], outs[1],
775                       max_output_size, iou_threshold, force_suppress,
776                       top_k, coord_start, id_index, score_index),
777                   dtype=[data.dtype, "int32"],
778                   in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
779                   name="nms",
780                   tag="nms")
781
782    if return_indices:
783        return box_indices
784
785    if invalid_to_bottom:
786        output_buf = api.decl_buffer(
787            data.shape, data.dtype, "output_buf", data_alignment=8)
788        temp_flag_buf = api.decl_buffer(
789            score_shape, valid_count_dtype, "temp_flag", data_alignment=8)
790        temp_idx_buf = api.decl_buffer(
791            score_shape, valid_count_dtype, "temp_idx", data_alignment=8)
792        temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out],
793                                         lambda ins, outs: invalid_to_bottom_pre(
794                                             ins[0], outs[0], outs[1]),
795                                         dtype=["int32", "int32"],
796                                         in_buffers=[out_buf],
797                                         out_buffers=[temp_flag_buf, temp_idx_buf],
798                                         name="invalid_to_bottom_phase_one")
799
800        output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
801                            lambda ins, outs: invalid_to_bottom_ir(
802                                ins[0], ins[1], ins[2], outs[0]),
803                            dtype=[data.dtype],
804                            in_buffers=[out_buf, temp_flag_buf, temp_idx_buf],
805                            out_buffers=[output_buf],
806                            name="invalid_to_bottom",
807                            tag="invalid_to_bottom")
808        return output
809
810    return out
811