# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Test code for broadcasting operators."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
import tvm.topi.testing
from tvm.contrib.nvcc import have_fp16

import tvm.testing


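# Each verify_* helper below follows the same pattern: declare te.placeholder
# inputs, build the TOPI op, compile it for each target under test, and
# compare the device result against a NumPy reference.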
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.expand_dims(A, axis, num_newaxis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = data_npy.reshape(out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


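# reinterpret is a bitwise reinterpretation (NumPy .view), so results are
# compared exactly with assert_equal rather than with a tolerance.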
def verify_reinterpret(in_shape, in_dtype, out_dtype, generator):
    A = te.placeholder(shape=in_shape, name="A", dtype=in_dtype)
    B = topi.reinterpret(A, out_dtype)

    def check_device(device, ctx):
        if in_dtype == "float16" and device == "cuda" and not have_fp16(ctx.compute_version):
            print("Skip because %s does not have fp16 support" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_elemwise_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="reinterpret")
        data_npy = generator(in_shape).astype(in_dtype)
        out_npy = data_npy.view(B.dtype)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(in_shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        np.testing.assert_equal(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_transpose(in_shape, axes):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="transpose")
        data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
        out_npy = data_npy.transpose(axes)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_reshape(src_shape, dst_shape):
    A = te.placeholder(shape=src_shape, name="A")
    B = topi.reshape(A, dst_shape)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="reshape")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npy = np.reshape(data_npy, newshape=dst_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_squeeze(src_shape, axis):
    A = te.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)

        foo = tvm.build(s, [A, B], device, name="squeeze")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npy = np.squeeze(data_npy, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd_shape = out_npy.shape
        out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


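# concatenate has dedicated schedules for x86 and Arm CPU targets; all other
# targets fall back to the generic injective schedule.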
def verify_concatenate(shapes, axis):
    def get_concat_schedule(target):
        schedule_map = {
            "cpu": topi.x86.schedule_concatenate,
            "arm_cpu": topi.arm_cpu.schedule_concatenate,
        }
        if isinstance(target, str):
            target = tvm.target.Target(target)
        for key in target.keys:
            if key in schedule_map:
                return schedule_map[key]
        return tvm.topi.testing.get_injective_schedule(target)

    tensor_l = []
    for i, shape in enumerate(shapes):
        tensor_l.append(te.placeholder(shape, name="A" + str(i)))
    out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = get_concat_schedule(device)(out_tensor)

        foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
        data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
        out_npy = np.concatenate(data_npys, axis=axis)
        data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=out_tensor.dtype)
        foo(*(data_nds + [out_nd]))
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_stack(shapes, axis):
    tensor_l = []
    for i, shape in enumerate(shapes):
        tensor_l.append(te.placeholder(shape, name="A" + str(i)))
    out_tensor = topi.stack(tensor_l, axis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(out_tensor)

        foo = tvm.build(s, tensor_l + [out_tensor], device, name="stack")
        data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
        out_npy = np.stack(data_npys, axis=axis)
        data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=out_tensor.dtype)
        foo(*(data_nds + [out_nd]))
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_split(src_shape, indices_or_sections, axis):
    A = te.placeholder(shape=src_shape, name="A")
    tensor_l = topi.split(A, indices_or_sections, axis=axis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(tensor_l)

        foo = tvm.build(s, [A] + list(tensor_l), device, name="split")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npys = np.split(data_npy, indices_or_sections, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nds = [
            tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys
        ]
        foo(*([data_nd] + out_nds))
        for out_nd, out_npy in zip(out_nds, out_npys):
            tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_expand_like(in_shape, out_shape, axis):
    A = te.placeholder(shape=in_shape, name="A")
    B = te.placeholder(shape=out_shape, name="B")
    C = topi.expand_like(A, B, axis)
    s = te.create_schedule([C.op])

    def check_device(device):
        print("Running on target: %s" % device)

        ctx = tvm.context(device, 0)
        f = tvm.build(s, [A, B, C], device, name="expand_like")
        input = np.random.uniform(size=in_shape).astype(A.dtype)
        tvm_input = tvm.nd.array(input, ctx)

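        # Build the NumPy reference: insert each requested axis with
        # expand_dims, then tile along it by concatenation until the shape
        # matches out_shape.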
        odim = len(out_shape)
        real_axis = [x if x >= 0 else x + odim for x in axis]
        real_axis = sorted(real_axis)
        for x in real_axis:
            input = np.expand_dims(input, x).astype(A.dtype)
        for x in real_axis:
            input = np.concatenate([input] * out_shape[x], axis=x).astype(A.dtype)
        assert input.shape == out_shape

        tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
        out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
        f(tvm_input, tvm_shape_like, out)
        tvm.testing.assert_allclose(out.asnumpy(), input)

    for device in ["llvm"]:
        check_device(device)


def verify_flip(in_shape, axis):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.flip(A, axis) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)

        foo = tvm.build(s, [A, B], device, name="reverse")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.flip(x_np, axis) + 1
        data_nd = tvm.nd.array(x_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)


@tvm.testing.uses_gpu
def test_reverse_sequence():
    def verify_reverse_sequence(in_data, seq_lengths, batch_axis, seq_axis, ref_res):
        seq_lengths = np.array(seq_lengths).astype("int32")
        A = te.placeholder(shape=in_data.shape, name="A", dtype=str(in_data.dtype))
        B = te.placeholder(shape=seq_lengths.shape, name="B", dtype=str(seq_lengths.dtype))
        C = topi.reverse_sequence(A, B, seq_axis, batch_axis)

        def check_device(device, ctx):
            print("Running on target: %s" % device)
            with tvm.target.Target(device):
                s = tvm.topi.testing.get_injective_schedule(device)(C)

            foo = tvm.build(s, [A, B, C], device, name="reverse_sequence")

            data_nd = tvm.nd.array(in_data, ctx)
            seq_lengths_nd = tvm.nd.array(seq_lengths, ctx)
            out_nd = tvm.nd.empty(in_data.shape, ctx=ctx, dtype=A.dtype)
            foo(data_nd, seq_lengths_nd, out_nd)
            tvm.testing.assert_allclose(out_nd.asnumpy(), ref_res)

        for device, ctx in tvm.testing.enabled_targets():
            check_device(device, ctx)

    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
    result = [[0, 5, 10, 15], [4, 1, 6, 11], [8, 9, 2, 7], [12, 13, 14, 3]]
    verify_reverse_sequence(indata, [1, 2, 3, 4], 1, 0, np.array(result))
    verify_reverse_sequence(indata, [1, 2, 3, 4], -1, 0, np.array(result))
    verify_reverse_sequence(
        indata.astype("float32"), [1, 2, 3, 4], 1, 0, np.array(result).astype("float32")
    )

    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
    result = [[0, 1, 2, 3], [5, 4, 6, 7], [10, 9, 8, 11], [15, 14, 13, 12]]
    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, 1, np.array(result))
    verify_reverse_sequence(indata, [1, 2, 3, 4], 0, -1, np.array(result))
    verify_reverse_sequence(
        indata.astype("float32"), [1, 2, 3, 4], 0, 1, np.array(result).astype("float32")
    )

    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
    result = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [15, 14, 13, 12]]
    verify_reverse_sequence(indata, [-1, 0, 1, 5], 0, 1, np.array(result))

    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
    result = [
        [
            [[18, 19, 20], [21, 22, 23], [24, 25, 26]],
            [[9, 10, 11], [12, 13, 14], [15, 16, 17]],
            [[0, 1, 2], [3, 4, 5], [6, 7, 8]],
        ],
        [
            [[45, 46, 47], [48, 49, 50], [51, 52, 53]],
            [[36, 37, 38], [39, 40, 41], [42, 43, 44]],
            [[27, 28, 29], [30, 31, 32], [33, 34, 35]],
        ],
    ]
    verify_reverse_sequence(indata, [3, 3], 0, 1, np.array(result))

    indata = np.array(np.arange(0, 54)).reshape([2, 3, 3, 3]).astype("int32")
    result = [
        [
            [[9, 10, 11], [21, 22, 23], [15, 16, 17]],
            [[0, 1, 2], [12, 13, 14], [6, 7, 8]],
            [[18, 19, 20], [3, 4, 5], [24, 25, 26]],
        ],
        [
            [[36, 37, 38], [48, 49, 50], [42, 43, 44]],
            [[27, 28, 29], [39, 40, 41], [33, 34, 35]],
            [[45, 46, 47], [30, 31, 32], [51, 52, 53]],
        ],
    ]
    verify_reverse_sequence(indata, [2, 3, 2], 2, 1, np.array(result))

    indata = np.array(np.arange(0, 16)).reshape([4, 4]).astype("int32")
    result = []
    with pytest.raises(Exception) as execinfo:
        verify_reverse_sequence(indata, [2, 3, 2, 4, 5], 1, 0, np.array(result))

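    # The misspelling "reverse_sequnece" below is quoted verbatim from the
    # error message raised by the op, so the assertion must match it exactly.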
    assert (
        "For reverse_sequnece seq_lengths size should match with dimension of batch axis,"
        " but got dimension of batch_axis = 4, and seq_length size = 5" in execinfo.value.args[0]
    )


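# mode selects how out-of-range indices are handled ("clip", "wrap", or
# "fast"); "fast" assumes in-range indices, so the NumPy reference uses
# mode="raise" in that case.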
def verify_take(src_shape, indices_src, axis=None, mode="clip"):
    src_dtype = "float32"
    indices_dtype = "int32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = te.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = te.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
    if axis is None:
        out_tensor = topi.take(a=A, indices=indices, mode=mode)
    else:
        out_tensor = topi.take(a=A, indices=indices, axis=axis, mode=mode)

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)

        foo = tvm.build(s, [A] + [indices] + [out_tensor], device, name="take")
        shape_size = 1
        for i in range(len(src_shape)):
            shape_size = shape_size * src_shape[i]
        data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))

        np_mode = "raise" if mode == "fast" else mode
        if axis is None:
            out_npys = np.take(data_npy, indices_src, mode=np_mode)
        else:
            out_npys = np.take(data_npy, indices_src, axis=axis, mode=np_mode)
        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        foo(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)


def verify_strided_slice(in_shape, begin, end, strides=None):
    A = te.placeholder(shape=in_shape, name="A")
    strides = [1, 1, 1] if strides is None else strides
    B = topi.strided_slice(A, begin, end, strides) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)

        foo = tvm.build(s, [A, B], device, name="stride_slice")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
        data_nd = tvm.nd.array(x_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)


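# Unlike strided_slice above, strided_set receives begin/end/strides as
# runtime int32 tensors, so they are declared as placeholders here.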
def verify_strided_set(in_shape, v_shape, begin, end, strides=None):
    A = te.placeholder(shape=in_shape, name="A")
    V = te.placeholder(shape=v_shape, name="V")
    b = te.placeholder(shape=(len(begin),), name="b", dtype="int32")
    e = te.placeholder(shape=(len(end),), name="e", dtype="int32")
    if strides is not None:
        st = te.placeholder(shape=(len(strides),), name="st", dtype="int32")
        B = topi.strided_set(A, V, b, e, st) + 1
    else:
        B = topi.strided_set(A, V, b, e) + 1

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not tvm.testing.device_enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)

        if strides is not None:
            foo = tvm.build(s, [A, V, b, e, st, B], device, name="stride_set")
            s_np = np.asarray(strides).astype("int32")
            s_nd = tvm.nd.array(s_np, ctx)
        else:
            foo = tvm.build(s, [A, V, b, e, B], device, name="stride_set")
        x_np = np.random.uniform(size=in_shape).astype(A.dtype)
        v_np = np.random.uniform(size=v_shape).astype(V.dtype)
        b_np = np.asarray(begin).astype("int32")
        e_np = np.asarray(end).astype("int32")
        out_npy = tvm.topi.testing.strided_set_python(x_np, v_np, begin, end, strides) + 1
        data_nd = tvm.nd.array(x_np, ctx)
        v_nd = tvm.nd.array(v_np, ctx)
        b_nd = tvm.nd.array(b_np, ctx)
        e_nd = tvm.nd.array(e_np, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
        if strides is not None:
            foo(data_nd, v_nd, b_nd, e_nd, s_nd, out_nd)
        else:
            foo(data_nd, v_nd, b_nd, e_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
        check_device(device)


def verify_gather(data, axis, indices):
    data = np.asarray(data)
    indices = np.asarray(indices)

    var_data = te.placeholder(shape=data.shape, dtype=data.dtype.name, name="data")
    var_indices = te.placeholder(shape=indices.shape, dtype=indices.dtype.name, name="indices")
    out_tensor = topi.gather(var_data, axis, var_indices)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)

        func = tvm.build(s, [var_data, var_indices, out_tensor], device, name="gather")
        out_npys = tvm.topi.testing.gather_python(data, axis, indices)

        data_nd = tvm.nd.array(data, ctx)
        indices_nd = tvm.nd.array(indices, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=data.dtype.name)
        func(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_gather_nd(src_shape, indices_src, indices_dtype):
    src_dtype = "float32"
    indices_src = np.array(indices_src, dtype=indices_dtype)
    A = te.placeholder(shape=src_shape, dtype=src_dtype, name="A")
    indices = te.placeholder(shape=indices_src.shape, dtype=indices_dtype, name="indices")
    out_tensor = topi.gather_nd(a=A, indices=indices)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(out_tensor)

        func = tvm.build(s, [A, indices, out_tensor], device, name="gather_nd")
        shape_size = 1
        for i in range(len(src_shape)):
            shape_size = shape_size * src_shape[i]
        data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
        out_npys = tvm.topi.testing.gather_nd_python(data_npy, indices_src)

        data_nd = tvm.nd.array(data_npy, ctx)
        indices_nd = tvm.nd.array(indices_src, ctx)
        out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
        func(data_nd, indices_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npys)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_arange(start, stop, step):
    if start is None and step is None:
        A = topi.arange(stop)
        a_np = np.arange(stop)
    elif start is None:
        A = topi.arange(stop, step=step)
        a_np = np.arange(stop, step=step)
    elif step is None:
        A = topi.arange(start, stop)
        a_np = np.arange(start, stop)
    else:
        A = topi.arange(start, stop, step)
        a_np = np.arange(start, stop, step)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(A)
        f = tvm.build(s, [A], device, name="arange")
        a_nd = tvm.nd.empty(a_np.shape, dtype="float32", ctx=ctx)
        f(a_nd)
        tvm.testing.assert_allclose(a_nd.asnumpy(), a_np)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_repeat(in_shape, repeats, axis):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.repeat(A, repeats, axis)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="repeat")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.repeat(data_npy, repeats, axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_tile(in_shape, reps):
    A = te.placeholder(shape=in_shape, name="A")
    B = topi.tile(A, reps)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(B)
        foo = tvm.build(s, [A, B], device, name="tile")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        out_npy = np.tile(data_npy, reps)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(B.dtype), ctx)
        foo(data_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_where(in_shape):
    Cond = te.placeholder(shape=in_shape, name="cond")
    dtype = Cond.dtype
    A = te.placeholder(shape=in_shape, name="A")
    B = te.placeholder(shape=in_shape, name="B")
    C = topi.where(Cond, A, B)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_broadcast_schedule(device)(C)
        f = tvm.build(s, [Cond, A, B, C], device, name="where")
        cond_npy = np.random.uniform(low=-1, high=1, size=in_shape).astype(dtype)
        x_npy = np.random.uniform(size=in_shape).astype(dtype)
        y_npy = np.random.uniform(size=in_shape).astype(dtype)
        out_npy = np.where(cond_npy, x_npy, y_npy)
        cond_nd = tvm.nd.array(cond_npy, ctx)
        x_nd = tvm.nd.array(x_npy, ctx)
        y_nd = tvm.nd.array(y_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(C.dtype), ctx)
        f(cond_nd, x_nd, y_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_one_hot(indices_shape, depth, on_value, off_value, axis, dtype):
    indices = te.placeholder(shape=indices_shape, name="indices", dtype="int32")
    on_value_const = tvm.tir.const(on_value, dtype)
    off_value_const = tvm.tir.const(off_value, dtype)
    one_hot_result = topi.transform.one_hot(
        indices, on_value_const, off_value_const, depth, axis, dtype
    )

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(one_hot_result)
        fn = tvm.build(s, [indices, one_hot_result], device, name="one_hot")
        indices_npy = np.random.randint(0, depth, size=indices_shape).astype(indices.dtype)
        out_npy = tvm.topi.testing.one_hot(indices_npy, on_value, off_value, depth, axis, dtype)
        indices_nd = tvm.nd.array(indices_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(one_hot_result.dtype), ctx)
        fn(indices_nd, out_nd)
        out_topi = out_nd.asnumpy()
        tvm.testing.assert_allclose(out_topi, out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_unravel_index(indices, shape, dtype):
    x_data = np.array(indices).astype(dtype)
    y_data = np.array(shape).astype(dtype)
    if len(x_data.shape) == 1:
        dst_shape = [y_data.shape[0], x_data.shape[0]]
    else:
        dst_shape = [y_data.shape[0]]

    X = te.placeholder(shape=x_data.shape, dtype=dtype, name="X")
    Y = te.placeholder(shape=y_data.shape, dtype=dtype, name="Y")
    Z = topi.unravel_index(X, Y)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(Z)
        foo = tvm.build(s, [X, Y, Z], device, name="unravel_index")

        out_npy = np.unravel_index(x_data, y_data)
        datax_nd = tvm.nd.array(x_data, ctx)
        datay_nd = tvm.nd.array(y_data, ctx)
        out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=Z.dtype)
        foo(datax_nd, datay_nd, out_nd)
        tvm.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


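# sparse_to_dense scatters sparse_values at sparse_indices into a tensor of
# output_shape; all other positions take default_value when one is given.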
def verify_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape, expected):
    sparse_indices_data = np.array(sparse_indices)
    sparse_values_data = np.array(sparse_values)
    output_shape_data = np.array(output_shape)
    default_value_data = np.array(default_value)

    A = te.placeholder(
        shape=sparse_indices_data.shape, name="sparse_indices", dtype=str(sparse_indices_data.dtype)
    )
    B = te.placeholder(
        shape=sparse_values_data.shape, name="sparse_values", dtype=str(sparse_values_data.dtype)
    )
    if default_value is None:
        args = [A, B]
        D = topi.sparse_to_dense(A, output_shape, B)
    else:
        C = te.placeholder(shape=(), name="default_value", dtype=str(default_value_data.dtype))
        args = [A, B, C]
        D = topi.sparse_to_dense(A, output_shape, B, C)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(D)

        foo = tvm.build(s, args + [D], device, name="sparse_to_dense")

        sparse_indices_nd = tvm.nd.array(sparse_indices_data, ctx)
        sparse_values_nd = tvm.nd.array(sparse_values_data, ctx)
        out_nd = tvm.nd.empty(output_shape_data, ctx=ctx, dtype=B.dtype)

        if default_value is None:
            foo(sparse_indices_nd, sparse_values_nd, out_nd)
        else:
            default_value_nd = tvm.nd.array(default_value_data, ctx)
            foo(sparse_indices_nd, sparse_values_nd, default_value_nd, out_nd)

        tvm.testing.assert_allclose(out_nd.asnumpy(), np.array(expected))

    for device, ctx in tvm.testing.enabled_targets():
        check_device(device, ctx)


def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
    input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
    diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
    matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(matrix_set_diag_result)
        fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], device, name="matrix_set_diag")
        input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype)
        diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype)
        out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align)
        input_nd = tvm.nd.array(input_npy, ctx)
        diagonal_nd = tvm.nd.array(diagonal_npy, ctx)
        out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), ctx)
        fn(input_nd, diagonal_nd, out_nd)
        out_topi = out_nd.asnumpy()
        tvm.testing.assert_allclose(out_topi, out_npy)

    for target, ctx in tvm.testing.enabled_targets():
        check_device(target, ctx)


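# adv_index mirrors NumPy advanced indexing: the reference result is simply
# np_data[tuple(np_indices)].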
def verify_adv_index(data_shape, index_shapes):
    dtype = "float32"
    data = te.placeholder(shape=data_shape, name="data", dtype=dtype)
    indices = []
    np_data = np.random.uniform(size=data_shape).astype(dtype)
    np_indices = []
    for i, index_shape in enumerate(index_shapes):
        limit = data_shape[i]
        np_indices.append(np.random.uniform(0, limit - 1, size=index_shape).astype("int64"))
        indices.append(te.placeholder(shape=index_shape, name="index_{}".format(i), dtype="int64"))
    np_out = np_data[tuple(np_indices)]
    out = topi.adv_index(data, indices)

    def check_device(device, ctx):
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(out)

        func = tvm.build(s, [data] + indices + [out], device, name="adv_index")

        nd_list = [tvm.nd.array(np_data, ctx)]
        for np_index in np_indices:
            nd_list.append(tvm.nd.array(np_index, ctx))
        nd_list.append(tvm.nd.empty(out.shape, ctx=ctx, dtype=data.dtype))

        func(*nd_list)
        tvm.testing.assert_allclose(nd_list[-1].asnumpy(), np.array(np_out))

    for target, ctx in tvm.testing.enabled_targets():
        check_device(target, ctx)


@tvm.testing.uses_gpu
def test_strided_slice():
    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
    verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
    verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])


@tvm.testing.uses_gpu
def test_strided_set():
    verify_strided_set((3, 4, 3), (3, 2, 2), [0, 3, 0], [4, 1, 4], [1, -1, 2])
    verify_strided_set((3, 4, 3), (3, 1, 2), [0, 0, 0], [4, -5, 4], [1, -1, 2])
    verify_strided_set((3, 4, 3), (1, 3, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
    verify_strided_set((3, 4, 3), (1, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
    verify_strided_set((3, 4, 3), (1, 2, 2), [1, 0, 0], [2, 2, 3], [1, 1, 2])
    verify_strided_set((3, 4, 3), (1, 2, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
    verify_strided_set((3, 4, 3), (1, 2, 3), [1, 1, 0], [2, 3, 3], [1])
    verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1, 0], [4, 4, 3])
    verify_strided_set((3, 4, 3), (2, 3, 3), [1, 1], [4, 4, 3])


@tvm.testing.uses_gpu
def test_expand_dims():
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)


@tvm.testing.uses_gpu
def test_reinterpret():
    verify_reinterpret((1000,), "float32", "int32", lambda shape: np.random.randn(*shape) * 1000)
    verify_reinterpret((1000,), "float16", "int16", lambda shape: np.random.randn(*shape) * 100)
    verify_reinterpret(
        (1000,), "int16", "uint16", lambda shape: np.random.randint(-1000, 1000, size=shape)
    )
    verify_reinterpret(
        (1000,), "uint32", "int32", lambda shape: np.random.randint(0, 2 ** 32 - 1, size=shape)
    )


@tvm.testing.uses_gpu
def test_transpose():
    verify_transpose((3, 10, 2), (1, 0, 2))
    verify_transpose((3, 10, 5), (2, 0, 1))
    verify_transpose((3, 10), None)


@tvm.testing.uses_gpu
def test_reshape():
    verify_reshape((1, 2, 3, 4), (2, 3, 4))
    verify_reshape((4, 2, 3, 4), (2, 4, 12))
    verify_reshape((4, 2, 3, 4), (2, 48))
    verify_reshape((16,), (2, 2, 2, 2))
    verify_reshape((4, 0), (2, 0, 2))


@tvm.testing.uses_gpu
def test_where():
    verify_where((1, 2, 3, 4))


@tvm.testing.uses_gpu
def test_squeeze():
    verify_squeeze((1, 2, 3, 4), 0)
    verify_squeeze((1, 2, 1, 4), None)
    verify_squeeze((1, 1, 1, 4), (1, 2))
    verify_squeeze((1, 1, 1, 1), None)

    # a special case to trigger inline let expression
    A = te.placeholder((2,), "float32", "A")
    E = topi.squeeze(A)
    C = te.compute((1,), lambda i: E[(2 * A[0] - 1).astype("int32")])
    for device in ["cuda", "opencl"]:
        ctx = tvm.context(device, 0)
        if tvm.testing.device_enabled(device):
            with tvm.target.Target(device):
                s = tvm.topi.testing.get_injective_schedule(device)(C)
                func = tvm.build(s, [A, C])
            a = tvm.nd.array(np.array((1, 2)).astype("float32"), ctx=ctx)
            c = tvm.nd.empty((1,), dtype="float32", ctx=ctx)
            func(a, c)
            assert c.asnumpy()[0] == 2


@tvm.testing.uses_gpu
def test_concatenate():
    verify_concatenate([(2,), (2,), (2,)], -1)
    verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
    verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
    verify_concatenate([(5, 6, 7, 3), (16, 6, 7, 3), (12, 6, 7, 3), (8, 6, 7, 3), (2, 6, 7, 3)], 0)
    verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1)


@tvm.testing.uses_gpu
def test_stack():
    verify_stack([(2,), (2,), (2,)], -1)
    verify_stack([(2,), (2,), (2,)], 1)
    verify_stack([(2,), (2,), (2,)], 0)
    verify_stack([(2, 2, 4), (2, 2, 4), (2, 2, 4)], 1)
    verify_stack([(2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4), (2, 2, 3, 4)], -1)


@tvm.testing.uses_gpu
def test_split():
    verify_split((2, 12, 3), 3, 1)
    verify_split((2, 12, 3), [2, 4], 1)
    verify_split((10, 12, 24), [5, 7, 9], -1)


@tvm.testing.uses_gpu
def test_flip():
    verify_flip((3, 4, 3), 1)
    verify_flip((3, 4, 3), 0)
    verify_flip((3, 4, 3), 2)
    verify_flip((3, 4, 3), -1)
    verify_flip((3, 4, 3), -3)
    verify_flip((3, 4, 3), -2)


@tvm.testing.requires_llvm
def test_expand_like():
    verify_expand_like((3,), (2, 3), [0])
    verify_expand_like((2,), (2, 3), [1])
    verify_expand_like((3, 4), (3, 5, 4), [1])
    verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3])


@tvm.testing.uses_gpu
def test_take():
    verify_take((4,), [1])
    verify_take((4,), [[0, 1, 2, 3]])
    verify_take((3, 3, 3), [[11, 25]])
    verify_take((4,), [[0, 1], [2, 3]])
    verify_take((4,), [1], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 0)
    verify_take((2, 2), [[[1, 0], [0, 1]]], 1)
    verify_take((4, 3, 5, 6), [[2, 1, 0, 0]], -2)
    verify_take((3, 4), [-5, 20])
    verify_take((3, 4), [-5, 20], mode="wrap")
    verify_take((3, 4), [-1, 2], axis=0)
    verify_take((3, 4), [-1, 2], axis=0, mode="wrap")
    verify_take((3, 4), [-1, 2], axis=1)
    verify_take((3, 4), [-1, 2], axis=1, mode="wrap")
    verify_take((3, 3, 3), [[11, 25]], mode="fast")
    verify_take((3, 4), [0, 2], axis=0, mode="fast")
    verify_take((3, 4), [0, 2], axis=1, mode="fast")


@tvm.testing.uses_gpu
def test_gather():
    verify_gather([[1, 2], [3, 4]], 1, [[0, 0], [1, 0]])
    verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(1, 7, 5)))
    verify_gather(np.random.randn(4, 7, 5), 0, np.random.randint(low=0, high=4, size=(4, 7, 5)))
    verify_gather(np.random.randn(4, 7, 5), 1, np.random.randint(low=0, high=7, size=(4, 10, 5)))
    verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 2)))
    verify_gather(np.random.randn(4, 7, 5), 2, np.random.randint(low=0, high=5, size=(4, 7, 10)))


@tvm.testing.uses_gpu
def test_gather_nd():
    for indices_dtype in ["int32", "float32"]:
        verify_gather_nd((4,), [[1.8]], indices_dtype)
        verify_gather_nd((4,), [[1, 3, 2]], indices_dtype)
        verify_gather_nd((2, 3), [[1]], indices_dtype)
        verify_gather_nd((2, 3), [[1], [0]], indices_dtype)
        verify_gather_nd((2, 3), [[1, 0], [0, 2]], indices_dtype)
        verify_gather_nd((2, 3, 4), [[1, 0], [0, 2]], indices_dtype)
        verify_gather_nd((2, 3, 4), [[1, 0], [0, 2], [3, 1]], indices_dtype)
        verify_gather_nd(
            (2, 3, 4), [[[1, 0], [0, 1]], [[0, 2], [1, 2]], [[3, 1], [0, 2]]], indices_dtype
        )
        verify_gather_nd((2, 3, 4, 5), [[1, 0], [0, 2]], indices_dtype)
        verify_gather_nd((2, 3, 4, 5), [[1, 0], [2, 1], [3, 2], [4, 2]], indices_dtype)


@tvm.testing.uses_gpu
def test_arange():
    verify_arange(None, 20, None)
    verify_arange(None, 20, 2)
    verify_arange(1, 20, None)
    verify_arange(1, 20, 2)
    verify_arange(1, 20, 1.5)
    verify_arange(1, 20.5, None)
    verify_arange(1, 20, 3)
    verify_arange(20, 1, -1)
    verify_arange(20, 1, -1.5)


@tvm.testing.uses_gpu
def test_repeat():
    verify_repeat((2,), 1, 0)
    verify_repeat((3, 2), 2, 0)
    verify_repeat((3, 2, 4), 3, 1)
    verify_repeat((1, 3, 2, 4), 4, -1)


@tvm.testing.uses_gpu
def test_tile():
    verify_tile((3, 2), (2, 3))
    verify_tile((3, 2, 5), (2,))
    verify_tile((3,), (2, 3, 3))
    verify_tile((4, 0), (5,))


@tvm.testing.uses_gpu
def test_layout_transform():
    in_shape = (1, 32, 8, 8)
    A = te.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.layout_transform(A, "NCHW", "NCHW16c")

    input = np.random.uniform(size=in_shape).astype(A.dtype)
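    # NumPy reference for NCHW -> NCHW16c: move channels last, split the 32
    # channels into (2, 16), then order the axes as (N, C_outer, H, W, c16).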
    output = np.transpose(input, axes=(0, 2, 3, 1))
    output = np.reshape(output, newshape=(1, 8, 8, 2, 16))
    output = np.transpose(output, axes=(0, 3, 1, 2, 4))

    def check_device(device, ctx):
        tvm_input = tvm.nd.array(input, ctx)
        tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=B.dtype)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        f = tvm.build(s, [A, B], device, name="layout_transform")
        f(tvm_input, tvm_output)
        tvm.testing.assert_allclose(tvm_output.asnumpy(), output)

    for backend, ctx in tvm.testing.enabled_targets():
        check_device(backend, ctx)


@tvm.testing.uses_gpu
def test_shape():
    in_shape = (8, 7, 13)
    dtype = "int32"
    A = te.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.shape(A, dtype)

    input = np.random.uniform(size=in_shape).astype(A.dtype)
    output = np.asarray(in_shape).astype(dtype)

    def check_device(device, ctx):
        tvm_input = tvm.nd.array(input, ctx)
        tvm_output = tvm.nd.empty(output.shape, ctx=ctx, dtype=dtype)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        f = tvm.build(s, [A, B], device, name="shape")
        f(tvm_input, tvm_output)
        tvm.testing.assert_allclose(tvm_output.asnumpy(), output)

    for backend, ctx in tvm.testing.enabled_targets():
        check_device(backend, ctx)


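# sequence_mask is exercised with the time axis first (axis=0) and the batch
# axis first (axis=1), and with two different mask values.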
@tvm.testing.uses_gpu
def test_sequence_mask():
    for in_shape in (5, 10), (3, 4, 5, 4):
        for axis in [0, 1]:
            for mask_value in [0.0, 1.0]:
                max_length = in_shape[axis]
                batch_size = in_shape[1 - axis]
                A = te.placeholder(shape=in_shape, dtype="float32", name="A")
                B = te.placeholder(shape=(batch_size,), dtype="int32", name="B")
                C = topi.sequence_mask(A, B, axis=axis, mask_value=mask_value)
                A_data = np.random.normal(0, 1, in_shape).astype(np.float32)
                B_data = np.random.randint(1, max_length, (batch_size,)).astype(np.int32)
                C_gt_data = tvm.topi.testing.sequence_mask(A_data, B_data, mask_value, axis)

                def check_device(device, ctx):
                    tvm_A = tvm.nd.array(A_data, ctx)
                    tvm_B = tvm.nd.array(B_data, ctx)
                    tvm_C = tvm.nd.empty(in_shape, ctx=ctx, dtype="float32")
                    print("Running on target: %s" % device)
                    with tvm.target.Target(device):
                        s = tvm.topi.testing.get_injective_schedule(device)(C)
                    f = tvm.build(s, [A, B, C], device, name="SequenceMask")
                    f(tvm_A, tvm_B, tvm_C)
                    tvm.testing.assert_allclose(tvm_C.asnumpy(), C_gt_data)

                for backend, ctx in tvm.testing.enabled_targets():
                    check_device(backend, ctx)


@tvm.testing.uses_gpu
def test_ndarray_size():
    in_shape = (5, 11, 7)
    dtype = "int32"
    A = te.placeholder(shape=in_shape, dtype="float32", name="A")
    B = topi.ndarray_size(A, dtype)

    input = np.random.uniform(size=in_shape).astype(A.dtype)
    output = np.asarray(np.size(input)).astype(dtype)

    def check_device(device, ctx):
        tvm_input = tvm.nd.array(input, ctx=ctx)
        tvm_output = tvm.nd.empty((), ctx=ctx, dtype=B.dtype)
        print("Running on target: %s" % device)
        with tvm.target.Target(device):
            s = tvm.topi.testing.get_injective_schedule(device)(B)
        f = tvm.build(s, [A, B], device, name="ndarray_size")
        f(tvm_input, tvm_output)
        tvm.testing.assert_allclose(tvm_output.asnumpy(), output)

    for backend, ctx in tvm.testing.enabled_targets():
        check_device(backend, ctx)


@tvm.testing.uses_gpu
def test_where_fusion():
    """integration test that where and zeros should be properly inlined"""

    def check_device(device, ctx):
        with tvm.target.Target(device):
            print("Running on target: %s" % device)
            conv2d_compute, conv2d_schedule = tvm.topi.testing.get_conv2d_nchw_implement(device)
            data = te.placeholder((2, 1, 2, 4), "int8", "data")
            w = te.placeholder((3, 1, 2, 2), "int8", "w")
            conv1 = conv2d_compute(data, w, 1, 0, 1, "int32")
            zeros = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(0, dtype="int32"))
            gt = topi.greater_equal(conv1, zeros)
            one = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(1, dtype="int32"))
            two = topi.full((2, 3, 1, 3), "int32", tvm.tir.const(2, dtype="int32"))
            where = topi.where(gt, one, two)
            add = topi.add(conv1, where)
            outs = [add]
            s = conv2d_schedule(outs)
            tvm.build(s, [data, w, add], target=device)

    for backend, ctx in tvm.testing.enabled_targets():
        check_device(backend, ctx)


@tvm.testing.uses_gpu
def test_one_hot():
    verify_one_hot((3,), 3, 1, 0, -1, "int32")
    verify_one_hot((3,), 3, 1.0, 0.0, -1, "float32")
    verify_one_hot((2, 2), 5, 2, -2, 0, "int32")
    verify_one_hot((2, 2), 5, 0.5, -0.5, 1, "float32")
    verify_one_hot((3, 2, 4, 5), 6, 1, 0, 1, "int32")
    verify_one_hot((3, 2, 4, 5), 6, 1.0, 0.0, 0, "float32")


@tvm.testing.uses_gpu
def test_unravel_index():
    for dtype in ["int32", "int64"]:
        verify_unravel_index([0, 1, 2, 3], [2, 2], dtype)
        verify_unravel_index([144], [5, 5, 5, 2], dtype)
        verify_unravel_index(144, [5, 5, 5, 2], dtype)
        verify_unravel_index([100, 13, 5], [5, 5, 5, 2], dtype)


@tvm.testing.uses_gpu
def test_sparse_to_dense():
    verify_sparse_to_dense(1, 3, 0, [5], [0, 3, 0, 0, 0])  # scalar
    verify_sparse_to_dense([0, 1, 4], [3, 3, 3], 0, [5], [3, 3, 0, 0, 3])  # vector
    verify_sparse_to_dense(
        [[0, 0], [1, 2]], [1, 2], 0, [3, 4], [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]]
    )  # nXd
    verify_sparse_to_dense(
        [[0, 0, 0], [1, 2, 3]],
        [1, 2],
        4,
        [2, 3, 4],
        [[[1, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 4]], [[4, 4, 4, 4], [4, 4, 4, 4], [4, 4, 4, 2]]],
    )  # nXd
    verify_sparse_to_dense(
        [0, 1, 4], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1]
    )  # floats
    verify_sparse_to_dense(1, 3, None, [5], [0, 3, 0, 0, 0])  # default value not specified

    # negative test cases
    # sparse indices should be ints
    # verify_sparse_to_dense([[0.1, 1.1, 4.1], [0,2,4]], [3.1, 3.1, 3.1], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
    # sparse_values should be 0d or 1d only
    # verify_sparse_to_dense([[0, 1, 4], [0, 2, 4]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])
    # sparse_indices should not be > 2d tensor
    # verify_sparse_to_dense([[[[0, 1, 4], [0, 2, 4]]]], [[[3.1, 3.1, 3.1]]], 3.5, [5], [3.1, 3.1, 3.5, 3.5, 3.1])


@tvm.testing.uses_gpu
def test_matrix_set_diag():
    for dtype in ["float32", "int32"]:
        verify_matrix_set_diag((2, 2), (2,), dtype)
        verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
        verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT")
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT")
        verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT")


@tvm.testing.uses_gpu
def test_adv_index():
    verify_adv_index((3, 4, 5), [(2,), (2,), (1,)])
    verify_adv_index((10, 15, 5), [(1, 1), (2, 7)])
    verify_adv_index((10, 5, 15), [(1, 2, 1), (1, 2, 7)])


if __name__ == "__main__":
    test_strided_slice()
    test_strided_set()
    test_concatenate()
    test_stack()
    test_transpose()
    test_expand_dims()
    test_reshape()
    test_reinterpret()
    test_where()
    test_squeeze()
    test_split()
    test_flip()
    test_reverse_sequence()
    test_expand_like()
    test_take()
    test_gather()
    test_gather_nd()
    test_arange()
    test_layout_transform()
    test_repeat()
    test_tile()
    test_shape()
    test_sequence_mask()
    test_ndarray_size()
    test_where_fusion()
    test_one_hot()
    test_unravel_index()
    test_sparse_to_dense()
    test_matrix_set_diag()
    test_adv_index()