# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""TVM definitions for element-wise comparison (logic) operators.

Every comparison is registered through ``defop`` twice, once as a ``<name>_cpu``
and once as a ``<name>_gpu`` variant. Two families are registered:
tensor-tensor comparisons (``equal``, ``less``, ...) and tensor-scalar
comparisons (``equal_scalar``, ``less_scalar``, ...).
"""

import tvm
from .. import defop, AllTypes

# Threads per block in the GPU schedule: the fused output loop is split by this
# factor, and the inner part is bound to ``threadIdx.x``.
_GPU_THREADS_PER_BLOCK = 64


def _make_logic_attrs(compute_func, target, auto_broadcast=False):
    """Build the ``defop`` attribute dict shared by all comparison operators.

    Parameters
    ----------
    compute_func : callable
        Called as ``compute_func(op, dtype, ndim)``; returns ``(s, a, b, c)``.
    target : str
        Target device string, e.g. ``'cpu'`` or ``'gpu'``.
    auto_broadcast : bool
        When true, ``'auto_broadcast': True`` is included in the dict;
        otherwise the key is omitted entirely.

    Returns
    -------
    dict
        Keys, in order: ``compute_func``, ``target``, optionally
        ``auto_broadcast``, ``itype`` (every type in ``AllTypes`` plus
        ``'bool'``) and ``ndim`` (ranks 0 through 5).
    """
    attrs = {'compute_func': compute_func, 'target': target}
    if auto_broadcast:
        attrs['auto_broadcast'] = True
    # A fresh list per dict, so callers never share a mutable value.
    attrs['itype'] = AllTypes + ['bool']
    attrs['ndim'] = list(range(6))
    return attrs


# Tensor-tensor comparisons. Each entry maps an op name to a function that
# computes the output element at index ``idx`` from inputs ``a`` and ``b``.
_bin_logic_op_map = {
    'equal': lambda a, b, *idx: a[idx] == b[idx],
    'not_equal': lambda a, b, *idx: a[idx] != b[idx],
    'greater': lambda a, b, *idx: a[idx] > b[idx],
    'less': lambda a, b, *idx: a[idx] < b[idx],
    'greater_equal': lambda a, b, *idx: a[idx] >= b[idx],
    'less_equal': lambda a, b, *idx: a[idx] <= b[idx],
}


def _compute_binary_logic(op, dtype, ndim):
    """Build the compute and default schedule for a tensor-tensor comparison.

    Parameters
    ----------
    op : str
        Key into ``_bin_logic_op_map``.
    dtype : str
        Element dtype of both input placeholders.
    ndim : int
        Rank of the inputs and the output. Every dimension is a fresh
        symbolic ``size_var``.

    Returns
    -------
    tuple
        ``(s, a, b, c)``: the schedule, the two input placeholders and the
        output tensor.
    """
    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='a')
    b = tvm.placeholder([tvm.size_var() for _ in range(ndim)], dtype=dtype, name='b')
    # a and b are indexed with the same idx. Broadcasting is presumably handled
    # by defop when auto_broadcast=True — NOTE(review): confirm in defop.
    c = tvm.compute([tvm.size_var() for _ in range(ndim)],
                    lambda *idx: _bin_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.create_schedule(c.op)
    return s, a, b, c


_bin_logic_cpu_attrs = _make_logic_attrs(_compute_binary_logic, 'cpu', auto_broadcast=True)

_bin_logic_gpu_attrs = _make_logic_attrs(_compute_binary_logic, 'gpu', auto_broadcast=True)


def _binary_logic_cpu(compute_func, op, itype, ndim):
    """CPU schedule: fuse all output axes into one loop and parallelize it.

    Parameters
    ----------
    compute_func : callable
        Called as ``compute_func(op, itype, ndim)``; returns ``(s, a, b, c)``.
    op : str
        Operator name forwarded to ``compute_func``.
    itype : str
        Input dtype forwarded to ``compute_func``.
    ndim : int
        Rank forwarded to ``compute_func``.

    Returns
    -------
    tuple
        ``(s, [a, b, c])``: the schedule and its argument list.
    """
    s, a, b, c = compute_func(op, itype, ndim)
    stage = s[c]
    # NOTE(review): for ndim == 0, c.op.axis is empty. This relies on
    # Stage.fuse() accepting zero axes — confirm for the pinned TVM version.
    fused = stage.fuse(*c.op.axis)
    stage.parallel(fused)
    return s, [a, b, c]


def _binary_logic_gpu(compute_func, op, itype, ndim):
    """GPU schedule: fuse all output axes, then split them across blocks and threads.

    The fused loop is split by ``_GPU_THREADS_PER_BLOCK``. The outer part is
    bound to ``blockIdx.x`` and the inner part to ``threadIdx.x``.

    Parameters
    ----------
    compute_func : callable
        Called as ``compute_func(op, itype, ndim)``; returns ``(s, a, b, c)``.
    op : str
        Operator name forwarded to ``compute_func``.
    itype : str
        Input dtype forwarded to ``compute_func``.
    ndim : int
        Rank forwarded to ``compute_func``.

    Returns
    -------
    tuple
        ``(s, [a, b, c])``: the schedule and its argument list.
    """
    s, a, b, c = compute_func(op, itype, ndim)
    stage = s[c]
    fused = stage.fuse(*c.op.axis)
    bx, tx = stage.split(fused, factor=_GPU_THREADS_PER_BLOCK)
    stage.bind(bx, tvm.thread_axis('blockIdx.x'))
    stage.bind(tx, tvm.thread_axis('threadIdx.x'))
    return s, [a, b, c]


# Register the tensor-tensor logic ops (broadcasting supported), as a CPU and a
# GPU variant for each op.
for op_name in _bin_logic_op_map:
    defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu)


# Tensor-scalar comparisons. The scalar ``b`` is always a float64 variable
# (see _compute_binary_scalar_logic), so each element of ``a`` is cast to
# ``b.dtype`` before the comparison.
_bin_scalar_logic_op_map = {
    'equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) == b,
    'not_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) != b,
    'greater_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) > b,
    'less_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) < b,
    'greater_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) >= b,
    'less_equal_scalar': lambda a, b, *idx: a[idx].astype(b.dtype) <= b,
}


def _compute_binary_scalar_logic(op, dtype, ndim):
    """Build the compute and default schedule for a tensor-scalar comparison.

    Parameters
    ----------
    op : str
        Key into ``_bin_scalar_logic_op_map``.
    dtype : str
        Element dtype of the tensor placeholder ``a``.
    ndim : int
        Rank of ``a`` and of the output. Every dimension is a fresh symbolic
        ``size_var``.

    Returns
    -------
    tuple
        ``(s, a, b, c)``: the schedule, the tensor placeholder ``a``, the
        float64 scalar variable ``b`` and the output tensor ``c``.
    """
    a = tvm.placeholder([tvm.size_var() for _ in range(ndim)], name='a', dtype=dtype)
    b = tvm.var('b', dtype='float64')
    c = tvm.compute([tvm.size_var() for _ in range(ndim)],
                    lambda *idx: _bin_scalar_logic_op_map[op](a, b, *idx), name='c')
    s = tvm.create_schedule(c.op)
    return s, a, b, c


# The scalar variants carry no 'auto_broadcast' key: only one tensor operand.
_bin_scalar_logic_cpu_attrs = _make_logic_attrs(_compute_binary_scalar_logic, 'cpu')

_bin_scalar_logic_gpu_attrs = _make_logic_attrs(_compute_binary_scalar_logic, 'gpu')


# Register the tensor-scalar logic ops. They reuse the tensor-tensor schedules,
# since both compute functions return (s, a, b, c).
for op_name in _bin_scalar_logic_op_map:
    defop(name='{}_cpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu)