# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Non-maximum suppression operator"""
import math
import tvm

from tvm import api
from tvm.generic import cast
from tvm.intrin import if_then_else, log, power
from topi.vision import non_max_suppression, get_valid_counts
from .sort import argsort
from .. import tag


def get_valid_counts_pre(data, flag, idx, score_threshold, id_index, score_index):
    """Low level IR for the first phase of get_valid_counts: mark each box
    as valid (1) or invalid (0).

    A box is valid when its score exceeds ``score_threshold`` and, if
    ``id_index >= 0``, its class id is non-negative.  ``flag`` and ``idx``
    receive identical 0/1 markers; ``idx`` is subsequently turned into a
    running index by the upsweep/scan/downsweep phases.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    score_threshold : float32
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    score_index: optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    # Wrap the Python scalars as IR immediates so they can appear in IR exprs.
    score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)

    # One GPU thread per (batch, anchor) element.
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    with ib.if_scope(tid < batch_size * num_anchors):
        # tid indexes rows of [batch * anchors]; each row starts at
        # tid * box_data_length in the flattened data buffer.
        with ib.if_scope(tvm.all(data[tid * box_data_length + score_index] > score_threshold, \
                                 tvm.any(id_index < 0, data[tid * box_data_length + id_index] >= 0))):
            flag[tid] = 1
            idx[tid] = 1
        with ib.else_scope():
            flag[tid] = 0
            idx[tid] = 0

    return ib.get()

def get_valid_counts_upsweep(data, idx_in, idx, partial):
    """Low level IR of first step of scan: upsweep.

    Each thread owns ``elem_per_thread`` consecutive anchors of one batch
    row: it computes the local prefix sum of its segment in ``idx`` and the
    segment total in ``partial``.  The segment totals are scanned by
    :func:`get_valid_counts_scan` and folded back in
    :func:`get_valid_counts_downsweep`.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    idx_in : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    partial : Buffer
        2D Buffer of per-segment totals with shape [batch_size, new_range].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    idx_in = ib.buffer_ptr(idx_in)
    idx = ib.buffer_ptr(idx)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    # One block per batch row; threads cover segments of elem_per_thread anchors.
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    new_range = num_anchors // elem_per_thread + 1
    # Scan: Upsweep:
    with ib.if_scope(tvm.all(bx < batch_size, tx < new_range)):
        with ib.for_range(0, elem_per_thread) as i:
            with ib.if_scope(bx * num_anchors + \
                             tx * elem_per_thread + i < batch_size * num_anchors):
                with ib.if_scope(i == 0):
                    # First element of the segment seeds both the running
                    # prefix (idx) and the segment total (partial).
                    partial[bx * new_range + tx] = idx_in[bx * num_anchors + tx * elem_per_thread]
                    idx[bx * num_anchors + tx * elem_per_thread] = \
                        idx_in[bx * num_anchors + tx * elem_per_thread]
                with ib.else_scope():
                    partial[bx * new_range + tx] += \
                        idx_in[bx * num_anchors + tx * elem_per_thread + i]
                    idx[bx * num_anchors + tx * elem_per_thread + i] = \
                        idx[bx * num_anchors + tx * elem_per_thread + i - 1] + \
                        idx_in[bx * num_anchors + tx * elem_per_thread + i]
            # Barrier so all segment sums are visible before the next phase.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()

def get_valid_counts_scan(data, partial_in, partial):
    """Low level IR to do scan.

    Runs a Kogge-Stone inclusive prefix sum over the per-segment totals
    produced by :func:`get_valid_counts_upsweep`, one batch row per block.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    partial_in : Buffer
        2D Buffer of per-segment totals with shape [batch_size, new_range].

    partial : Buffer
        2D Buffer of scanned (inclusive prefix-summed) totals with shape
        [batch_size, new_range].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    partial_in = ib.buffer_ptr(partial_in)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    nthread_tx = max_threads
    nthread_bx = batch_size
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    # Base 2 for the power() strides below; log2(new_range) iterations.
    var = tvm.make.node("FloatImm", dtype="float32", value=2)
    new_range = num_anchors // elem_per_thread + 1
    iteration = cast(log(cast(new_range, "float32")) / math.log(2), "int32")
    # Scan: Kogge-Stone adder
    with ib.if_scope(tvm.all(bx < batch_size, tx < tvm.min(new_range, num_anchors))):
        with ib.for_range(0, iteration) as k:
            with ib.if_scope(k == 0):
                with ib.if_scope(tvm.all(tx > 0, tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] = \
                        partial_in[bx * new_range + tx] + partial_in[bx * new_range + tx - 1]
                with ib.else_scope():
                    partial[bx * new_range] = partial_in[bx * new_range]
            with ib.else_scope():
                # Stride doubles every iteration: add element 2**k positions back.
                with ib.if_scope(tvm.all(tx >= cast(power(var, k), "int32"), \
                                         tx < tvm.min(new_range, num_anchors))):
                    partial[bx * new_range + tx] += \
                        partial[bx * new_range + tx - cast(power(var, k), "int32")]
            # Barrier between scan steps so each step reads completed values.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))
    return ib.get()

def get_valid_counts_downsweep(data, idx_in, partial, idx):
    """Low level IR to do downsweep of scan.

    Adds each segment's scanned offset (from the preceding segment's total
    in ``partial``) onto the local prefix sums in ``idx_in``, producing the
    full running index in ``idx``.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    idx_in : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    partial : Buffer
        2D Buffer of scanned per-segment totals with shape [batch_size, new_range].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    ib = tvm.ir_builder.create()
    idx_in = ib.buffer_ptr(idx_in)
    idx = ib.buffer_ptr(idx)
    partial = ib.buffer_ptr(partial)
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    # One thread per (batch, anchor) element.
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx
    new_range = num_anchors // elem_per_thread + 1
    idxd = tvm.indexdiv
    idxm = tvm.indexmod
    # Scan: Downsweep:
    with ib.if_scope(tid < batch_size * num_anchors):
        i = idxd(tid, num_anchors)  # batch index
        j = idxm(tid, num_anchors)  # anchor index within the batch row
        with ib.if_scope(j < elem_per_thread):
            # First segment has no preceding segment total to add.
            idx[tid] = idx_in[tid]
        with ib.else_scope():
            idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]

    return ib.get()

def get_valid_counts_ir(data, flag, idx, valid_count, out):
    """Low level IR to get valid count of bounding boxes
    given a score threshold. Also moves valid boxes to the
    top of input data.

    Uses the running index ``idx`` (an inclusive prefix sum of the validity
    flags) to compact valid boxes to the front of each batch row of ``out``
    and to pad the remainder with -1.

    Parameters
    ----------
    data : Buffer
        Input data. 3-D Buffer with shape [batch_size, num_anchors, elem_length].

    flag : Buffer
        2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].

    idx : Buffer
        2D Buffer of valid data indices with shape [batch_size, num_anchors].

    valid_count : Buffer
        1-D buffer for valid number of boxes.

    out : Buffer
        Rearranged data buffer.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]
    size = batch_size * num_anchors * elem_length

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)

    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    tid = bx * max_threads + tx

    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    with ib.if_scope(tid < batch_size * num_anchors):
        i = idxd(tid, num_anchors)  # batch index
        j = idxm(tid, num_anchors)  # anchor index within the batch row
        base_idx = i * num_anchors * elem_length
        with ib.if_scope(flag[tid] > 0):
            # idx[tid] - 1 is the compacted destination slot of this box.
            with ib.for_range(0, elem_length) as k:
                with ib.if_scope(base_idx + (idx[tid] - 1) * elem_length + k < size):
                    out[base_idx + (idx[tid] - 1) * elem_length + k] =\
                        data[base_idx + j * elem_length + k]
        with ib.if_scope(j == 0):
            # Last prefix-sum entry of the row == number of valid boxes.
            valid_count[i] = idx[tid + num_anchors - 1]
        with ib.if_scope(j >= idx[i * num_anchors + num_anchors - 1]):
            # Slots past the last valid box are filled with -1.
            with ib.for_range(0, elem_length) as l:
                with ib.if_scope(tid * elem_length + l < size):
                    out[tid * elem_length + l] = -1.0
    return ib.get()


@get_valid_counts.register(["cuda", "gpu"])
def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
    """Get valid count of bounding boxes given a score threshold.
    Also moves valid boxes to the top of input data.

    Implemented in five extern phases: (1) mark valid boxes, (2) upsweep,
    (3) Kogge-Stone scan, (4) downsweep (phases 2-4 compute a prefix sum of
    the validity flags), (5) compact valid boxes using the prefix sum.

    Parameters
    ----------
    data : tvm.Tensor
        Input data. 3-D tensor with shape [batch_size, num_anchors, elem_length].

    score_threshold : optional, float
        Lower limit of score for valid bounding boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    score_index: optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    out_tensor : tvm.Tensor
        Rearranged data tensor.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    elem_per_thread = num_anchors // max_threads + 1
    new_range = num_anchors // elem_per_thread + 1
    temp_flag_buf = api.decl_buffer(
        (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8)
    temp_idx_buf = api.decl_buffer(
        (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8)
    temp_partial_buf = api.decl_buffer(
        (batch_size, new_range), "int32", "temp_partial", data_alignment=8)
    data_buf = api.decl_buffer(
        data.shape, data.dtype, "data_buf", data_alignment=8)

    temp_flag, temp_idx = \
        tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data],
                   lambda ins, outs: get_valid_counts_pre(
                       ins[0], outs[0], outs[1], score_threshold, id_index, score_index),
                   dtype=["int32", "int32"],
                   out_buffers=[temp_flag_buf, temp_idx_buf],
                   name="get_valid_counts_phase_one")
    temp_idx_new, temp_partial = \
        tvm.extern([(batch_size, num_anchors,), (batch_size, new_range)], [data, temp_idx],
                   lambda ins, outs: get_valid_counts_upsweep(
                       ins[0], ins[1], outs[0], outs[1]),
                   dtype=["int32", "int32"],
                   out_buffers=[temp_idx_buf, temp_partial_buf],
                   name="get_valid_counts_phase_two")
    temp_partial_new = \
        tvm.extern([(batch_size, new_range)], [data, temp_partial],
                   lambda ins, outs: get_valid_counts_scan(
                       ins[0], ins[1], outs[0]),
                   dtype=["int32"],
                   out_buffers=[temp_partial_buf],
                   name="get_valid_counts_phase_three")
    temp_idx_final = \
        tvm.extern([(batch_size, num_anchors)], [data, temp_idx_new, temp_partial_new],
                   lambda ins, outs: get_valid_counts_downsweep(
                       ins[0], ins[1], ins[2], outs[0]),
                   dtype=["int32"],
                   out_buffers=[temp_idx_buf],
                   name="get_valid_counts_phase_four")
    valid_count, out_tensor = \
        tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx_final],
                   lambda ins, outs: get_valid_counts_ir(
                       ins[0], ins[1], ins[2], outs[0], outs[1]),
                   dtype=["int32", data.dtype],
                   in_buffers=[data_buf, temp_flag_buf, temp_idx_buf],
                   name="get_valid_counts_phase_five",
                   tag="get_valid_counts_gpu")

    return [valid_count, out_tensor]


def nms_ir(data, sorted_index, valid_count, out, box_indices,
           max_output_size, iou_threshold, force_suppress,
           top_k, coord_start, id_index, score_index):
    """Low level IR routine for the non-maximum suppression operator.

    Parameters
    ----------
    data : Buffer
        Buffer of output boxes with class and score.

    sorted_index : Buffer
        Buffer of output box indexes sorted by score.

    valid_count : Buffer
        Buffer of number of valid output boxes.

    out : Buffer
        Output buffer.

    box_indices : Buffer
        Output buffer of surviving box indices into the input data;
        suppressed/invalid slots are set to -1.

    max_output_size : int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : float
        Overlapping(IoU) threshold to suppress object with smaller score.

    force_suppress : boolean
        Whether to suppress all detections regardless of class_id.

    top_k : int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : int
        Start index of the consecutive 4 coordinates.

    id_index : int
        index of the class categories, -1 to disable.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
        """Calculate intersection-over-union of two boxes stored as
        [left, top, right, bottom] starting at the given flat offsets.
        """
        w = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 2], out_tensor[box_b_idx + 2])
                    - tvm.max(out_tensor[box_a_idx], out_tensor[box_b_idx]))
        h = tvm.max(0.0, tvm.min(out_tensor[box_a_idx + 3], out_tensor[box_b_idx + 3])
                    - tvm.max(out_tensor[box_a_idx + 1], out_tensor[box_b_idx + 1]))
        i = w * h
        u = (out_tensor[box_a_idx + 2] - out_tensor[box_a_idx]) * \
            (out_tensor[box_a_idx + 3] - out_tensor[box_a_idx + 1]) + \
            (out_tensor[box_b_idx + 2] - out_tensor[box_b_idx]) * \
            (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i
        # Guard against degenerate boxes: zero IoU when the union is empty.
        return tvm.expr.Select(u <= 0.0, 0.0, i / u)

    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    box_data_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    sorted_index = ib.buffer_ptr(sorted_index)
    valid_count = ib.buffer_ptr(valid_count)
    out = ib.buffer_ptr(out)
    box_indices = ib.buffer_ptr(box_indices)
    num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local")

    # One thread per anchor; the batch dimension is an unrolled serial loop.
    max_threads = int(
        tvm.target.current_target(allow_none=False).max_num_threads)
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    j = bx * max_threads + tx

    # Wrap Python scalars as IR immediates for use inside IR expressions.
    iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold)
    top_k = tvm.make.node("IntImm", dtype="int32", value=top_k)
    coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start)
    id_index = tvm.make.node("IntImm", dtype="int32", value=id_index)
    score_index = tvm.make.node("IntImm", dtype="int32", value=score_index)
    force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0)

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        base_idx = i * num_anchors * box_data_length
        with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)):
            # Reorder output: copy the top nkeep boxes in score order.
            nkeep = if_then_else( \
                tvm.all(top_k > 0, top_k < valid_count[i]),
                top_k, valid_count[i])
            with ib.if_scope(j < nkeep):
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + j * box_data_length + k)] = \
                        data[(base_idx + sorted_index[i * num_anchors + j] \
                        * box_data_length + k)]
                box_indices[i * num_anchors + j] = sorted_index[i * num_anchors + j]
            with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])):
                # Boxes cut off by top_k are marked invalid.
                with ib.if_scope(j < valid_count[i] - nkeep):
                    with ib.for_range(0, box_data_length) as k:
                        out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0
                    box_indices[i * num_anchors + (j + nkeep)] = -1
            # Apply nms: for each kept box k (serial), every thread j > k
            # suppresses its own box when IoU with k exceeds the threshold.
            with ib.for_range(0, valid_count[i]) as k:
                offset_k = k * box_data_length
                with ib.if_scope(tvm.all(out[base_idx + offset_k + score_index] > 0, \
                    tvm.any(id_index < 0, out[base_idx + offset_k + id_index] >= 0))):
                    with ib.if_scope(j < valid_count[i]):
                        offset_j = j * box_data_length
                        # Same-class check is skipped when force_suppress is
                        # set or class ids are disabled (id_index < 0).
                        with ib.if_scope(tvm.all(j > k, \
                            out[base_idx + offset_j + score_index] > 0, \
                            tvm.any(id_index < 0, \
                                    out[base_idx + offset_j + id_index] >= 0), \
                            tvm.any(force_suppress > 0, id_index < 0, \
                                    out[base_idx + offset_k + id_index] == \
                                    out[base_idx + offset_j + id_index]))):
                            iou = calculate_overlap(out, base_idx + offset_j + coord_start,
                                                    base_idx + offset_k + coord_start)
                            with ib.if_scope(iou >= iou_threshold):
                                out[base_idx + offset_j + score_index] = -1.0
                                with ib.if_scope(id_index >= 0):
                                    out[base_idx + offset_j + id_index] = -1.0
                                box_indices[i * num_anchors + j] = -1
        with ib.else_scope():
            # NMS disabled (iou_threshold <= 0) or nothing valid: pass through.
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                with ib.for_range(0, box_data_length) as k:
                    out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k]
                box_indices[i * num_anchors + j] = j
        # Set invalid entry to be -1
        with ib.if_scope(j < num_anchors - valid_count[i]):
            with ib.for_range(0, box_data_length) as k:
                out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0
            box_indices[i * num_anchors + j + valid_count[i]] = -1
        # Only return max_output_size number of valid boxes
        # NOTE(review): num_valid_boxes is in "local" scope, so the counter is
        # per-thread rather than shared across threads — confirm intent.
        num_valid_boxes[0] = 0
        with ib.if_scope(max_output_size > 0):
            with ib.if_scope(j < valid_count[i]):
                offset_j = j * box_data_length
                with ib.if_scope(out[base_idx + offset_j] >= 0):
                    with ib.if_scope(num_valid_boxes[0] == max_output_size):
                        with ib.for_range(0, box_data_length) as k:
                            out[base_idx + offset_j + k] = -1.0
                        box_indices[i * num_anchors + j] = -1
                    with ib.else_scope():
                        num_valid_boxes[0] += 1

    return ib.get()


def invalid_to_bottom_pre(data, flag, idx):
    """Low level IR to rearrange nms output to move all valid entries to top.

    Marks each anchor as valid (first element of its record >= 0) and builds
    a per-row serial prefix sum of the marks in ``idx``; the prefix sum is
    consumed by :func:`invalid_to_bottom_ir`.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        1D Buffer of flag indicating valid data with [num_anchors].

    idx : Buffer
        1D Buffer of valid data indices with [num_anchors].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)

    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    j = bx * max_threads + tx

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        base_idx = i * num_anchors * elem_length
        with ib.if_scope(j < num_anchors):
            # A record is valid when its first element (class id slot) >= 0.
            with ib.if_scope(data[base_idx + j * elem_length] >= 0):
                flag[i * num_anchors + j] = 1
                idx[i * num_anchors + j] = 1
            with ib.else_scope():
                flag[i * num_anchors + j] = 0
                idx[i * num_anchors + j] = 0

    with ib.if_scope(j < batch_size):
        # Serial prefix sum over anchors; thread j handles batch row j.
        with ib.for_range(0, num_anchors) as k:
            with ib.if_scope(k > 0):
                idx[j * num_anchors + k] += idx[j * num_anchors + k - 1]
    return ib.get()


def invalid_to_bottom_ir(data, flag, idx, out):
    """Low level IR to rearrange nms output to move all valid entries to top.

    Fills ``out`` with -1, then scatters each valid record to its compacted
    slot ``idx - 1`` within the batch row.

    Parameters
    ----------
    data: Buffer
        3D Buffer with shape [batch_size, num_anchors, elem_length], output of nms.

    flag : Buffer
        1D Buffer of flag indicating valid data with [num_anchors].

    idx : Buffer
        1D Buffer of valid data indices with [num_anchors].

    out : Buffer
        3D Buffer of rearranged nms output with shape [batch_size, num_anchors, elem_length].

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]
    elem_length = data.shape[2]

    ib = tvm.ir_builder.create()

    data = ib.buffer_ptr(data)
    flag = ib.buffer_ptr(flag)
    idx = ib.buffer_ptr(idx)
    out = ib.buffer_ptr(out)

    max_threads = int(math.sqrt(
        tvm.target.current_target(allow_none=False).max_num_threads))
    nthread_tx = max_threads
    nthread_bx = num_anchors // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("blockIdx.x")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "thread_extent", nthread_bx)
    j = bx * max_threads + tx

    with ib.for_range(0, batch_size, for_type="unroll") as i:
        base_idx = i * num_anchors * elem_length
        with ib.if_scope(j < num_anchors):
            with ib.for_range(0, elem_length) as k:
                out[base_idx + j * elem_length + k] = -1.0
            with ib.if_scope(flag[i * num_anchors + j] > 0):
                # idx holds an inclusive prefix sum, so idx - 1 is the
                # destination slot of this valid record.
                with ib.for_range(0, elem_length) as k:
                    out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \
                        = data[base_idx + j * elem_length + k]
    return ib.get()


@non_max_suppression.register(["cuda", "gpu"])
def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
                            iou_threshold=0.5, force_suppress=False, top_k=-1,
                            coord_start=2, score_index=1, id_index=0,
                            return_indices=True, invalid_to_bottom=False):
    """Non-maximum suppression operator for object detection.

    Parameters
    ----------
    data : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, elem_length].
        The last dimension should be in format of
        [class_id, score, box_left, box_top, box_right, box_bottom].

    valid_count : tvm.Tensor
        1-D tensor for valid number of boxes.

    max_output_size : optional, int
        Max number of output valid boxes for each instance.
        By default all valid boxes are returned.

    iou_threshold : optional, float
        Non-maximum suppression threshold.

    force_suppress : optional, boolean
        Whether to suppress all detections regardless of class_id.

    top_k : optional, int
        Keep maximum top k detections before nms, -1 for no limit.

    coord_start : required, int
        Start index of the consecutive 4 coordinates.

    score_index : optional, int
        Index of the scores/confidence of boxes.

    id_index : optional, int
        index of the class categories, -1 to disable.

    return_indices : boolean
        Whether to return box indices in input data.

    invalid_to_bottom : optional, boolean
        Whether to move all valid bounding boxes to the top.

    Returns
    -------
    out : tvm.Tensor
        3-D tensor with shape [batch_size, num_anchors, elem_length].

    Example
    --------
    .. code-block:: python

        # An example to use nms
        dshape = (1, 5, 6)
        data = tvm.placeholder(dshape, name="data")
        valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
        iou_threshold = 0.7
        force_suppress = True
        top_k = -1
        out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold,
                                  force_suppress=force_suppress, top_k=top_k, return_indices=False)
        np_data = np.random.uniform(size=dshape)
        np_valid_count = np.array([4])
        s = topi.generic.schedule_nms(out)
        f = tvm.build(s, [data, valid_count, out], "cuda")
        ctx = tvm.gpu(0)
        tvm_data = tvm.nd.array(np_data, ctx)
        tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
        tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
        f(tvm_data, tvm_valid_count, tvm_out)
    """
    batch_size = data.shape[0]
    num_anchors = data.shape[1]

    valid_count_dtype = "int32"
    valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
                                      "valid_count_buf", data_alignment=4)
    score_axis = score_index
    score_shape = (batch_size, num_anchors)
    # Extract scores and sort anchors by score (descending) per batch row.
    score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis],
                               tag=tag.ELEMWISE)
    sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)

    sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
                                      "sort_tensor_buf", data_alignment=8)

    data_buf = api.decl_buffer(
        data.shape, data.dtype, "data_buf", data_alignment=8)

    out_buf = api.decl_buffer(
        data.shape, data.dtype, "out_buf", data_alignment=8)

    out, box_indices = \
        tvm.extern([data.shape, score_shape],
                   [data, sort_tensor, valid_count],
                   lambda ins, outs: nms_ir(
                       ins[0], ins[1], ins[2], outs[0], outs[1],
                       max_output_size, iou_threshold, force_suppress,
                       top_k, coord_start, id_index, score_index),
                   dtype=[data.dtype, "int32"],
                   in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
                   name="nms",
                   tag="nms")

    if return_indices:
        return box_indices

    if invalid_to_bottom:
        # Optional compaction pass: move surviving boxes to the top of each row.
        output_buf = api.decl_buffer(
            data.shape, data.dtype, "output_buf", data_alignment=8)
        temp_flag_buf = api.decl_buffer(
            score_shape, valid_count_dtype, "temp_flag", data_alignment=8)
        temp_idx_buf = api.decl_buffer(
            score_shape, valid_count_dtype, "temp_idx", data_alignment=8)
        temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out],
                                         lambda ins, outs: invalid_to_bottom_pre(
                                             ins[0], outs[0], outs[1]),
                                         dtype=["int32", "int32"],
                                         in_buffers=[out_buf],
                                         out_buffers=[temp_flag_buf, temp_idx_buf],
                                         name="invalid_to_bottom_phase_one")

        output = tvm.extern([data.shape], [out, temp_flag, temp_idx],
                            lambda ins, outs: invalid_to_bottom_ir(
                                ins[0], ins[1], ins[2], outs[0]),
                            dtype=[data.dtype],
                            in_buffers=[out_buf, temp_flag_buf, temp_idx_buf],
                            out_buffers=[output_buf],
                            name="invalid_to_bottom",
                            tag="invalid_to_bottom")
        return output

    return out