# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Argsort operator """
import tvm

from tvm import api
from ..sort import argsort, topk
from ..math import identity
from ..transform import strided_slice
from .. import generic
from .. import tag

def _schedule_sort(outs):
    """Shared schedule for the argsort and topk operators.

    Walks the op graph feeding `outs` and schedules every injective op
    it finds; the extern sort ops themselves carry their own IR and need
    no scheduling here.

    Parameters
    ----------
    outs: Array of Tensor
      The computation graph description of argsort
      in the format of an array of tensors.

    Returns
    -------
    s: Schedule
      The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []
    # Imported here (not at module top) to avoid a circular import with
    # the injective schedule module.
    from .injective import schedule_injective_from_existing
    def traverse(op):
        # Only injective ops (e.g. the identity copy, strided_slice) get a
        # schedule; other ops are just traversed for their inputs.
        if tag.is_injective(op.tag):
            schedule_injective_from_existing(s, op.output(0))
        for tensor in op.input_tensors:
            # tensor.op.input_tensors is empty for placeholders — stop there.
            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                traverse(tensor.op)
        scheduled_ops.append(op)
    for out in outs:
        traverse(out.op)
    return s

def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
    """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Copies `data` into `values_out`, then sorts `values_out` in place
    along `axis` with an odd-even transposition sort. If `indices_out`
    is given it is initialised to 0..n-1 and permuted alongside the
    values, yielding the argsort result.

    Parameters
    ----------
    data: Buffer
        Buffer of input data. Read only; values are first copied into
        `values_out` and sorted there.

    values_out : Buffer
        Output buffer of sorted values, with same shape as data.

    axis : Int
        Axis along which to sort the input tensor.

    is_ascend : Boolean
        Whether to sort in ascending or descending order.

    indices_out : Buffer, optional
        If provided, output buffer of the indices of the sorted tensor,
        with same shape as data.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """
    # Flatten the dimensions before/after the sort axis so one "row" to be
    # sorted is addressed as base_idx + t * axis_mul_after.
    axis_mul_before = 1
    axis_mul_after = 1
    shape = data.shape
    if axis < 0:
        axis = len(shape) + axis
    for i, value in enumerate(shape, 0):
        if i < axis:
            axis_mul_before *= value
        elif i > axis:
            axis_mul_after *= value
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    values_out = ib.buffer_ptr(values_out)
    if indices_out is not None:
        indices_out = ib.buffer_ptr(indices_out)
    # One block of max_threads real threads; elements of the sort axis beyond
    # that are covered by virtual threads (see the "vthread" axis below).
    nthread_tx = max_threads
    nthread_bx = shape[axis] // max_threads + 1

    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    # Per-thread scratch slots used for the compare-and-swap below.
    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
    if indices_out is not None:
        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")

    # Stage 1: copy input values (and seed indices with 0..n-1).
    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            base_idx = i * shape[axis] * axis_mul_after + j
            with ib.if_scope(tid < shape[axis]):
                values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
                if indices_out is not None:
                    indices_out[base_idx + tid * axis_mul_after] = \
                        tvm.generic.cast(tid, indices_out.dtype)
    # Storage sync ('shared' scope): make the copied data visible to all
    # threads before the sorting passes start.
    ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                          tvm.convert(['shared']),
                          tvm.expr.Call.Intrinsic, None, 0))
    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    # Stage 2: sort each row with parallel odd-even transposition sort.
    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            current_sort_num = shape[axis]
            base_idx = i * shape[axis] * axis_mul_after + j
            # OddEvenTransposeSort: pass k compares adjacent pairs starting at
            # even (k even) or odd (k odd) positions; n passes guarantee a
            # fully sorted row. Thread t handles the pair at 2*t (+ parity).
            with ib.for_range(0, current_sort_num) as k:
                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                    # Swap when the pair is out of order for the requested
                    # direction; the first tvm.all term guards the row end.
                    if is_ascend:
                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                       values_out[offset] > values_out[offset + axis_mul_after])
                    else:
                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                       values_out[offset] < values_out[offset + axis_mul_after])
                    with ib.if_scope(cond):
                        temp_data[0] = values_out[offset]
                        values_out[offset] = values_out[offset + axis_mul_after]
                        values_out[offset + axis_mul_after] = temp_data[0]
                        if indices_out is not None:
                            temp_index[0] = indices_out[offset]
                            indices_out[offset] = indices_out[offset + axis_mul_after]
                            indices_out[offset + axis_mul_after] = temp_index[0]
            # Sync after every full sweep of rows so no thread starts the next
            # pass before this one's swaps are visible.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()


def sort_nms_ir(data, valid_count, output, axis, is_ascend):
    """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Sorts `data` in place along `axis` and writes the resulting index
    permutation into `output`. Unlike `sort_ir`, only the first
    `valid_count` elements of each row take part in the sort.

    Parameters
    ----------
    data: Buffer
        Buffer of input data. NOTE: sorted in place.

    valid_count : Buffer
        1D Buffer of number of valid number of boxes.

    output : Buffer
        Output buffer of indices of sorted tensor with same shape as data.

    axis : Int
        Axis along which to sort the input tensor.

    is_ascend : Boolean
        Whether to sort in ascending or descending order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
    """

    # size = total element count (thread count); the mul_before/mul_after
    # products flatten the dims around the sort axis, as in sort_ir.
    size = 1
    axis_mul_before = 1
    axis_mul_after = 1
    shape = data.shape
    if axis < 0:
        axis = len(shape) + axis
    for i, value in enumerate(shape, 0):
        size *= value
        if i < axis:
            axis_mul_before *= value
        elif i > axis:
            axis_mul_after *= value
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    valid_count = ib.buffer_ptr(valid_count)
    output = ib.buffer_ptr(output)
    nthread_tx = max_threads
    nthread_bx = size // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    # Scratch slots for the compare-and-swap; dtypes are fixed here because
    # this path is only built with float32 scores / int32 indices (see
    # argsort_gpu's buffer declarations).
    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
    # Wrap the Python bool/int into an IR immediate so it can appear inside
    # the generated conditions below.
    is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)

    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            # Per-row element count comes from valid_count; entries past it
            # keep their initial index and are never compared.
            current_sort_num = valid_count[i * axis_mul_after + j]
            base_idx = i * shape[axis] * axis_mul_after + j
            with ib.if_scope(tid < shape[axis]):
                output[base_idx + tid * axis_mul_after] = tid
            # OddEvenTransposeSort: n passes of alternating even/odd adjacent
            # compare-and-swap; both direction variants are emitted and gated
            # at runtime by the is_ascend immediate.
            with ib.for_range(0, current_sort_num) as k:
                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                    with ib.if_scope(tvm.all(is_ascend == 1, \
                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                             data[offset] > data[offset + axis_mul_after])):
                        temp_data[0] = data[offset]
                        data[offset] = data[offset + axis_mul_after]
                        data[offset + axis_mul_after] = temp_data[0]
                        temp_index[0] = output[offset]
                        output[offset] = output[offset + axis_mul_after]
                        output[offset + axis_mul_after] = temp_index[0]
                    with ib.if_scope(tvm.all(is_ascend == 0, \
                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                             data[offset] < data[offset + axis_mul_after])):
                        temp_data[0] = data[offset]
                        data[offset] = data[offset + axis_mul_after]
                        data[offset + axis_mul_after] = temp_data[0]
                        temp_index[0] = output[offset]
                        output[offset] = output[offset + axis_mul_after]
                        output[offset + axis_mul_after] = temp_index[0]
            # Storage sync ('shared' scope) between sweeps so swaps from this
            # pass are visible before the next one.
            ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                  tvm.convert(['shared']),
                                  tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()

@argsort.register(["cuda", "gpu"])
def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Performs sorting along the given axis and returns an array of indices
    having same shape as an input array that index data in sorted order.

    Parameters
    ----------
    data: tvm.Tensor
        The input array.

    valid_count : tvm.Tensor, optional
        The number of valid elements to be sorted.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices.

    Returns
    -------
    out : tvm.Tensor
        The output of this function.
    """
    if valid_count is not None:
        # sort_nms_ir sorts its data buffer in place, so sort a copy of the
        # input (identity) rather than the input itself.
        sorted_data = identity(data)
        sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
                                          data_alignment=8)
        valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                          "valid_count_buf", data_alignment=4)
        out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
        out = tvm.extern([data.shape],
                         [sorted_data, valid_count],
                         lambda ins, outs: sort_nms_ir(
                             ins[0], ins[1], outs[0], axis, is_ascend),
                         dtype="int32",
                         in_buffers=[sorted_data_buf, valid_count_buf],
                         out_buffers=[out_buf],
                         name="argsort_nms_gpu",
                         tag="argsort_nms_gpu")
    else:
        value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
        indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
        # sort_ir produces (sorted values, indices); argsort only needs the
        # indices, hence the trailing [1].
        out = tvm.extern([data.shape, data.shape],
                         [data],
                         lambda ins, outs: sort_ir(
                             ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
                         out_buffers=[value_buf, indices_buf],
                         name="argsort_gpu",
                         tag="argsort_gpu")[1]
    return out

@generic.schedule_argsort.register(["cuda", "gpu"])
def schedule_argsort(outs):
    """Schedule for argsort operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of argsort
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    return _schedule_sort(outs)

@topk.register(["cuda", "gpu"])
def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result.
    """
    assert ret_type in ["both", "values", "indices"]
    ndim = len(data.shape)
    axis = axis + ndim if axis < 0 else axis
    assert 0 <= axis < ndim
    values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
    indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
    # Full sort first; top-k is then a slice of the sorted result. The
    # indices output is only materialised when it can be needed.
    if ret_type == "values":
        output = tvm.extern([data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend),
                            out_buffers=[values_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    else:
        output = tvm.extern([data.shape, data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
                            out_buffers=[values_buf, indices_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    # k < 1 means "return everything": no slicing needed.
    if k < 1:
        if ret_type == "indices":
            return output[1]
        return output
    # Slice the first k entries along `axis`, full extent elsewhere.
    beg = [0] * ndim
    end = []
    for i in range(ndim):
        if i == axis:
            end.append(k)
        else:
            end.append(data.shape[i])
    if ret_type == "both":
        values_out, indices_out = output
        values_out = strided_slice(values_out, beg, end)
        indices_out = strided_slice(indices_out, beg, end)
        output = [values_out, indices_out]
    elif ret_type == "values":
        output = [strided_slice(output, beg, end)]
    else: # ret_type == "indices"
        indices_out = output[1]
        output = [strided_slice(indices_out, beg, end)]
    return output


@generic.schedule_topk.register(["cuda", "gpu"])
def schedule_topk(outs):
    """Schedule for topk operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of topk
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    return _schedule_sort(outs)