# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument
"""Argsort operator"""
import tvm

from tvm import api
from ..sort import argsort, topk
from ..math import identity
from ..transform import strided_slice
from .. import generic
from .. import tag

def _schedule_sort(outs):
    """Schedule for argsort operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of argsort
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
      The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    scheduled_ops = []
    from .injective import schedule_injective_from_existing
    def traverse(op):
        if tag.is_injective(op.tag):
            schedule_injective_from_existing(s, op.output(0))
        for tensor in op.input_tensors:
            if tensor.op.input_tensors and tensor.op not in scheduled_ops:
                traverse(tensor.op)
        scheduled_ops.append(op)
    for out in outs:
        traverse(out.op)
    return s

def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
    """Low level IR to sort a tensor on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Parameters
    ----------
    data: Buffer
        Buffer of input data.

    values_out : Buffer
        Output buffer of the sorted values, with the same shape as data.

    axis : Int
        Axis along which to sort the input tensor.

    is_ascend : Boolean
        Whether to sort in ascending or descending order.

    indices_out : Buffer, optional
        Output buffer of the indices of the sorted values, with the same shape as data.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
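
    Example
    -------
    The IR emitted below performs an odd-even transposition sort on each
    slice along ``axis``. A minimal pure-Python sketch of the same idea,
    for illustration only (names here are ours, not part of the TVM API)::

        def odd_even_sort(vals):
            n = len(vals)
            for k in range(n):              # n passes guarantee a sorted result
                start = k % 2               # alternate between even and odd pairs
                for i in range(start, n - 1, 2):
                    if vals[i] > vals[i + 1]:   # ascending; flip for descending
                        vals[i], vals[i + 1] = vals[i + 1], vals[i]
            return vals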
    """
    axis_mul_before = 1
    axis_mul_after = 1
    shape = data.shape
    if axis < 0:
        axis = len(shape) + axis
    for i, value in enumerate(shape, 0):
        if i < axis:
            axis_mul_before *= value
        elif i > axis:
            axis_mul_after *= value
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    values_out = ib.buffer_ptr(values_out)
    if indices_out is not None:
        indices_out = ib.buffer_ptr(indices_out)
    nthread_tx = max_threads
    nthread_bx = shape[axis] // max_threads + 1

    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    temp_data = ib.allocate(values_out.dtype, (1,), name="temp_data", scope="local")
    if indices_out is not None:
        temp_index = ib.allocate(indices_out.dtype, (1,), name="temp_index", scope="local")

    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            base_idx = i * shape[axis] * axis_mul_after + j
            with ib.if_scope(tid < shape[axis]):
                values_out[base_idx + tid * axis_mul_after] = data[base_idx + tid * axis_mul_after]
                if indices_out is not None:
                    indices_out[base_idx + tid * axis_mul_after] = \
                        tvm.generic.cast(tid, indices_out.dtype)
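    # barrier: make sure every thread sees the copied values (and indices)
    # before the sorting passes below start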
    ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                          tvm.convert(['shared']),
                          tvm.expr.Call.Intrinsic, None, 0))
    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            current_sort_num = shape[axis]
            base_idx = i * shape[axis] * axis_mul_after + j
            # Odd-even transposition sort: each pass compares and swaps
            # adjacent pairs, alternating between even and odd offsets
            with ib.for_range(0, current_sort_num) as k:
                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                    if is_ascend:
                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                       values_out[offset] > values_out[offset + axis_mul_after])
                    else:
                        cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
                                       values_out[offset] < values_out[offset + axis_mul_after])
                    with ib.if_scope(cond):
                        temp_data[0] = values_out[offset]
                        values_out[offset] = values_out[offset + axis_mul_after]
                        values_out[offset + axis_mul_after] = temp_data[0]
                        if indices_out is not None:
                            temp_index[0] = indices_out[offset]
                            indices_out[offset] = indices_out[offset + axis_mul_after]
                            indices_out[offset + axis_mul_after] = temp_index[0]
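                # barrier between passes so all swaps in this pass finish
                # before the next pass reads values_out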
                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                      tvm.convert(['shared']),
                                      tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()


def sort_nms_ir(data, valid_count, output, axis, is_ascend):
    """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.

    Parameters
    ----------
    data: Buffer
        Buffer of input data. Data will be sorted in place.

    valid_count : Buffer
        1D Buffer of the number of valid boxes.

    output : Buffer
        Output buffer of indices of the sorted tensor, with the same shape as data.

    axis : Int
        Axis along which to sort the input tensor.

    is_ascend : Boolean
        Whether to sort in ascending or descending order.

    Returns
    -------
    stmt : Stmt
        The result IR statement.
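
    Example
    -------
    A rough per-slice sketch of the effect, for illustration only (names
    here are ours): with scores ``s`` and ``n = valid_count`` for the slice,
    only the first ``n`` entries take part in the sort and the remaining
    entries keep their own position index::

        order = sorted(range(n), key=lambda i: s[i], reverse=not is_ascend)
        output[:n] = order                              # indices of the sorted prefix
        output[n:] = list(range(n, len(s)))             # untouched tail keeps identity
        s[:n] = sorted(s[:n], reverse=not is_ascend)    # data is sorted in place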
    """

    size = 1
    axis_mul_before = 1
    axis_mul_after = 1
    shape = data.shape
    if axis < 0:
        axis = len(shape) + axis
    for i, value in enumerate(shape, 0):
        size *= value
        if i < axis:
            axis_mul_before *= value
        elif i > axis:
            axis_mul_after *= value
    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
    ib = tvm.ir_builder.create()
    data = ib.buffer_ptr(data)
    valid_count = ib.buffer_ptr(valid_count)
    output = ib.buffer_ptr(output)
    nthread_tx = max_threads
    nthread_bx = size // max_threads + 1
    tx = tvm.thread_axis("threadIdx.x")
    bx = tvm.thread_axis("vthread")
    ib.scope_attr(tx, "thread_extent", nthread_tx)
    ib.scope_attr(bx, "virtual_thread", nthread_bx)
    tid = bx * nthread_tx + tx
    temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
    temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
    is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)

    idxd = tvm.indexdiv
    idxm = tvm.indexmod

    with ib.for_range(0, axis_mul_before) as i:
        with ib.for_range(0, axis_mul_after) as j:
            current_sort_num = valid_count[i * axis_mul_after + j]
            base_idx = i * shape[axis] * axis_mul_after + j
            with ib.if_scope(tid < shape[axis]):
                output[base_idx + tid * axis_mul_after] = tid
            # Odd-even transposition sort over the first current_sort_num entries
            with ib.for_range(0, current_sort_num) as k:
                with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
                    offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
                    with ib.if_scope(tvm.all(is_ascend == 1, \
                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                             data[offset] > data[offset + axis_mul_after])):
                        temp_data[0] = data[offset]
                        data[offset] = data[offset + axis_mul_after]
                        data[offset + axis_mul_after] = temp_data[0]
                        temp_index[0] = output[offset]
                        output[offset] = output[offset + axis_mul_after]
                        output[offset + axis_mul_after] = temp_index[0]
                    with ib.if_scope(tvm.all(is_ascend == 0, \
                                             2 * tid + idxm(k, 2) + 1 < current_sort_num, \
                                             data[offset] < data[offset + axis_mul_after])):
                        temp_data[0] = data[offset]
                        data[offset] = data[offset + axis_mul_after]
                        data[offset + axis_mul_after] = temp_data[0]
                        temp_index[0] = output[offset]
                        output[offset] = output[offset + axis_mul_after]
                        output[offset + axis_mul_after] = temp_index[0]
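                # barrier between passes so every swap is visible before the
                # next pass reads data/output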
                ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
                                      tvm.convert(['shared']),
                                      tvm.expr.Call.Intrinsic, None, 0))

    return ib.get()

@argsort.register(["cuda", "gpu"])
def argsort_gpu(data, valid_count=None, axis=-1, is_ascend=1, dtype="float32"):
    """Performs sorting along the given axis and returns an array of indices
    with the same shape as the input array, indexing the data in sorted order.

    Parameters
    ----------
    data: tvm.Tensor
        The input array.

    valid_count : tvm.Tensor, optional
        The number of valid elements to be sorted.

    axis : int, optional
        Axis along which to sort the input tensor.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        DType of the output indices.

    Returns
    -------
    out : tvm.Tensor
        The output of this function.
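
    Example
    -------
    A minimal usage sketch (shapes and the target scope are illustrative;
    a CUDA-capable device is assumed)::

        data = tvm.placeholder((1, 5), name="data", dtype="float32")
        with tvm.target.create("cuda"):
            out = argsort_gpu(data, axis=-1, is_ascend=False)
            s = schedule_argsort(out)
        f = tvm.build(s, [data, out], "cuda")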
    """
    if valid_count is not None:
        sorted_data = identity(data)
        sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf",
                                          data_alignment=8)
        valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
                                          "valid_count_buf", data_alignment=4)
        out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
        out = tvm.extern([data.shape],
                         [sorted_data, valid_count],
                         lambda ins, outs: sort_nms_ir(
                             ins[0], ins[1], outs[0], axis, is_ascend),
                         dtype="int32",
                         in_buffers=[sorted_data_buf, valid_count_buf],
                         out_buffers=[out_buf],
                         name="argsort_nms_gpu",
                         tag="argsort_nms_gpu")
    else:
        value_buf = api.decl_buffer(data.shape, data.dtype, "value_buf", data_alignment=8)
        indices_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
        out = tvm.extern([data.shape, data.shape],
                         [data],
                         lambda ins, outs: sort_ir(
                             ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
                         out_buffers=[value_buf, indices_buf],
                         name="argsort_gpu",
                         tag="argsort_gpu")[1]
    return out

@generic.schedule_argsort.register(["cuda", "gpu"])
def schedule_argsort(outs):
    """Schedule for argsort operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of argsort
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
      The computation schedule for the op.
    """
    return _schedule_sort(outs)

@topk.register(["cuda", "gpu"])
def topk_gpu(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
    """Get the top k elements in an input tensor along the given axis.

    Parameters
    ----------
    data : tvm.Tensor
        The input tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.

    axis : int, optional
        Axis along which to sort the input tensor.

    ret_type: str, optional
        The return type [both, values, indices].
        "both": return both top k data and indices.
        "values": return top k data only.
        "indices": return top k indices only.

    is_ascend : boolean, optional
        Whether to sort in ascending or descending order.

    dtype : string, optional
        The data type of the indices output.

    Returns
    -------
    out : tvm.Tensor or List[tvm.Tensor]
        The computed result.
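
    Example
    -------
    A minimal usage sketch (shapes and the target scope are illustrative;
    a CUDA-capable device is assumed)::

        data = tvm.placeholder((1, 5), name="data", dtype="float32")
        with tvm.target.create("cuda"):
            values, indices = topk_gpu(data, k=3, axis=-1, ret_type="both")
            s = schedule_topk([values, indices])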
    """
    assert ret_type in ["both", "values", "indices"]
    ndim = len(data.shape)
    axis = axis + ndim if axis < 0 else axis
    assert 0 <= axis < ndim
    values_buf = api.decl_buffer(data.shape, data.dtype, "values_buf", data_alignment=8)
    indices_buf = api.decl_buffer(data.shape, dtype, "indices_buf", data_alignment=8)
    if ret_type == "values":
        output = tvm.extern([data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend),
                            out_buffers=[values_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    else:
        output = tvm.extern([data.shape, data.shape],
                            [data],
                            lambda ins, outs: sort_ir(
                                ins[0], outs[0], axis, is_ascend, indices_out=outs[1]),
                            out_buffers=[values_buf, indices_buf],
                            name="topk_gpu",
                            tag="topk_gpu")
    if k < 1:
        if ret_type == "indices":
            return output[1]
        return output
    beg = [0] * ndim
    end = []
    for i in range(ndim):
        if i == axis:
            end.append(k)
        else:
            end.append(data.shape[i])
    if ret_type == "both":
        values_out, indices_out = output
        values_out = strided_slice(values_out, beg, end)
        indices_out = strided_slice(indices_out, beg, end)
        output = [values_out, indices_out]
    elif ret_type == "values":
        output = [strided_slice(output, beg, end)]
    else: # ret_type == "indices"
        indices_out = output[1]
        output = [strided_slice(indices_out, beg, end)]
    return output


@generic.schedule_topk.register(["cuda", "gpu"])
def schedule_topk(outs):
    """Schedule for topk operator.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of topk
        in the format of an array of tensors.

    Returns
    -------
    s: Schedule
      The computation schedule for the op.
    """
    return _schedule_sort(outs)