# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
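"""Codegen tests for the OpenCL target: each case builds and launches a small
kernel to exercise the generated code without checking outputs."""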
import tvm
from tvm import te
import tvm.testing

target = "opencl"


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_ternary_expression():
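    """Build kernels using if_then_else and Select on narrow integer dtypes;
    only checks that OpenCL compilation succeeds."""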
    def check_if_then_else(ctx, n, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        true_value = tvm.tir.const(1, dtype=dtype)
        false_value = tvm.tir.const(3, dtype=dtype)
        max_lhs = tvm.tir.const(2, dtype=dtype)
        max_rhs = tvm.tir.if_then_else(A[0] > 0, true_value, false_value)
        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
        s = te.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, C], target)

        a = tvm.nd.empty((n,), A.dtype, ctx)
        c = tvm.nd.empty((n,), A.dtype, ctx)
        # Only need to test compiling here
        fun(a, c)

    def check_select(ctx, n, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        true_value = tvm.tir.const(1, dtype=dtype)
        false_value = tvm.tir.const(3, dtype=dtype)
        max_lhs = tvm.tir.const(2, dtype=dtype)
        max_rhs = tvm.tir.Select(A[0] > 0, true_value, false_value)
        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
        s = te.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, C], target)

        a = tvm.nd.empty((n,), A.dtype, ctx)
        c = tvm.nd.empty((n,), A.dtype, ctx)
        # Only need to test compiling here
        fun(a, c)

    ctx = tvm.context(target, 0)

    check_if_then_else(ctx, 1, "int8")
    check_if_then_else(ctx, 1, "uint8")
    check_if_then_else(ctx, 1, "int16")
    check_if_then_else(ctx, 1, "uint16")
    check_select(ctx, 1, "int8")
    check_select(ctx, 1, "uint8")
    check_select(ctx, 1, "int16")
    check_select(ctx, 1, "uint16")


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_inf_nan():
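    """Build kernels that emit +/-inf and nan constants for float32/float64;
    only checks that OpenCL compilation succeeds."""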
    def check_inf_nan(ctx, n, value, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        inf_value = tvm.tir.const(value, dtype=dtype)
        C = te.compute((n,), lambda i: inf_value, name="C")
        s = te.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, C], target)
        a = tvm.nd.empty((n,), A.dtype, ctx)
        c = tvm.nd.empty((n,), A.dtype, ctx)
        # Only need to test compiling here
        fun(a, c)

    ctx = tvm.context(target, 0)

    check_inf_nan(ctx, 1, -float("inf"), "float32")
    check_inf_nan(ctx, 1, -float("inf"), "float64")
    check_inf_nan(ctx, 1, float("inf"), "float32")
    check_inf_nan(ctx, 1, float("inf"), "float64")
    check_inf_nan(ctx, 1, float("nan"), "float32")
    check_inf_nan(ctx, 1, float("nan"), "float64")


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_max():
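    """Build kernels taking the max of an input element and a constant across
    integer and floating-point dtypes; only checks that compilation succeeds."""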
    def check_max(ctx, n, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        max_lhs = A[0] + tvm.tir.const(1, dtype=dtype)
        max_rhs = tvm.tir.const(0, dtype=dtype)
        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
        s = te.create_schedule(C.op)
        s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, C], target)

        a = tvm.nd.empty((n,), A.dtype, ctx)
        c = tvm.nd.empty((n,), A.dtype, ctx)
        # Only need to test compiling here
        fun(a, c)

    ctx = tvm.context(target, 0)

    check_max(ctx, 1, "int8")
    check_max(ctx, 1, "uint8")
    check_max(ctx, 1, "int16")
    check_max(ctx, 1, "uint16")
    check_max(ctx, 1, "float32")
    check_max(ctx, 1, "float64")

if __name__ == "__main__":
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    test_opencl_max()