# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""OpenCL codegen smoke tests.

Each test builds small kernels for the ``opencl`` target and launches them
once. Only compilation and launch are exercised; outputs are not checked.
"""
import tvm
from tvm import te
import tvm.testing

target = "opencl"


def _build_and_run(ctx, n, A, C):
    """Build ``[A, C]`` for ``target`` with ``C``'s axis on threadIdx.x and launch once.

    Parameters
    ----------
    ctx : device context the buffers are allocated on.
    n : length of the 1-D input/output buffers.
    A : input placeholder tensor.
    C : 1-D compute tensor to schedule and build.

    Both buffers are allocated with ``A.dtype`` and left uninitialized, since
    callers only need the kernel to compile and run.
    """
    s = te.create_schedule(C.op)
    s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
    fun = tvm.build(s, [A, C], target)

    a = tvm.nd.empty((n,), A.dtype, ctx)
    c = tvm.nd.empty((n,), A.dtype, ctx)
    # Only need to test compiling here
    fun(a, c)


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_ternary_expression():
    def check_ternary(ctx, n, dtype, make_ternary):
        # make_ternary is tvm.tir.if_then_else or tvm.tir.Select; both are
        # called as (condition, true_value, false_value).
        A = te.placeholder((n,), name="A", dtype=dtype)
        true_value = tvm.tir.const(1, dtype=dtype)
        false_value = tvm.tir.const(3, dtype=dtype)
        max_lhs = tvm.tir.const(2, dtype=dtype)
        max_rhs = make_ternary(A[0] > 0, true_value, false_value)
        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
        _build_and_run(ctx, n, A, C)

    ctx = tvm.context(target, 0)

    # Same order as before: all if_then_else dtypes first, then all Select dtypes.
    for make_ternary in (tvm.tir.if_then_else, tvm.tir.Select):
        for dtype in ("int8", "uint8", "int16", "uint16"):
            check_ternary(ctx, 1, dtype, make_ternary)


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_inf_nan():
    def check_inf_nan(ctx, n, value, dtype):
        # The kernel just writes the constant `value` (inf / -inf / nan)
        # into every output element.
        A = te.placeholder((n,), name="A", dtype=dtype)
        inf_value = tvm.tir.const(value, dtype=dtype)
        C = te.compute((n,), lambda i: inf_value, name="C")
        _build_and_run(ctx, n, A, C)

    ctx = tvm.context(target, 0)

    for value in (-float("inf"), float("inf"), float("nan")):
        for dtype in ("float32", "float64"):
            check_inf_nan(ctx, 1, value, dtype)


@tvm.testing.requires_gpu
@tvm.testing.requires_opencl
def test_opencl_max():
    def check_max(ctx, n, dtype):
        A = te.placeholder((n,), name="A", dtype=dtype)
        max_lhs = A[0] + tvm.tir.const(1, dtype=dtype)
        max_rhs = tvm.tir.const(0, dtype=dtype)
        C = te.compute((n,), lambda i: tvm.te.max(max_lhs, max_rhs), name="C")
        _build_and_run(ctx, n, A, C)

    ctx = tvm.context(target, 0)

    for dtype in ("int8", "uint8", "int16", "uint16", "float32", "float64"):
        check_max(ctx, 1, dtype)


if __name__ == "__main__":
    test_opencl_ternary_expression()
    test_opencl_inf_nan()
    # Previously omitted: run the max test too when executing the file directly.
    test_opencl_max()