1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18import tvm
19from .. import defop, AllTypes
20
def _make_elemwise_cmp(cmp):
    """Wrap a two-argument comparator into the ``(a, b, *idx)`` signature
    expected by the compute lambdas: both operands are indexed with the
    same index tuple before comparing."""
    return lambda a, b, *idx: cmp(a[idx], b[idx])


# Table of broadcastable tensor-tensor comparison ops: op name -> expression
# builder used inside `tvm.compute`.
_bin_logic_op_map = {
    'equal': _make_elemwise_cmp(lambda x, y: x == y),
    'not_equal': _make_elemwise_cmp(lambda x, y: x != y),
    'greater': _make_elemwise_cmp(lambda x, y: x > y),
    'less': _make_elemwise_cmp(lambda x, y: x < y),
    'greater_equal': _make_elemwise_cmp(lambda x, y: x >= y),
    'less_equal': _make_elemwise_cmp(lambda x, y: x <= y),
}
29
30
def _compute_binary_logic(op, dtype, ndim):
    """Build the TVM compute and a default schedule for a broadcastable
    binary comparison op.

    Parameters
    ----------
    op : str
        Key into ``_bin_logic_op_map`` selecting the comparison.
    dtype : str
        Element type of both input tensors.
    ndim : int
        Rank of the input/output tensors.

    Returns
    -------
    tuple
        ``(schedule, a, b, c)`` where ``a``/``b`` are the input placeholders
        and ``c`` the comparison output tensor.
    """
    def fresh_shape():
        # Each tensor gets its own independent symbolic shape; the vars must
        # not be shared between tensors (broadcast resolution happens later).
        return [tvm.size_var() for _ in range(ndim)]

    lhs = tvm.placeholder(fresh_shape(), dtype=dtype, name='a')
    rhs = tvm.placeholder(fresh_shape(), dtype=dtype, name='b')
    out = tvm.compute(fresh_shape(),
                      lambda *idx: _bin_logic_op_map[op](lhs, rhs, *idx),
                      name='c')
    sched = tvm.create_schedule(out.op)
    return sched, lhs, rhs, out
38
39
# Keyword attributes forwarded to `defop` for the CPU variants of the
# broadcastable binary logic ops registered below.
_bin_logic_cpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'cpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],  # every base dtype plus bool
    'ndim': list(range(6))  # generate kernels for 0-d through 5-d tensors
}

# Same attributes as the CPU variant, but targeting GPU code generation.
_bin_logic_gpu_attrs = {
    'compute_func': _compute_binary_logic,
    'target': 'gpu',
    'auto_broadcast': True,
    'itype': AllTypes + ['bool'],  # every base dtype plus bool
    'ndim': list(range(6))  # generate kernels for 0-d through 5-d tensors
}
55
56
def _binary_logic_cpu(compute_func, op, itype, ndim):
    """CPU schedule for a binary logic op: fuse every output axis into a
    single loop and parallelize it.

    Returns ``(schedule, [a, b, c])`` as produced by ``compute_func``.
    """
    sched, lhs, rhs, out = compute_func(op, itype, ndim)
    fused_axis = sched[out].fuse(*out.op.axis)
    sched[out].parallel(fused_axis)
    return sched, [lhs, rhs, out]
63
64
def _binary_logic_gpu(compute_func, op, itype, ndim):
    """GPU schedule for a binary logic op: fuse all output axes, split the
    fused loop by a factor of 64, and bind the outer/inner halves to CUDA
    block and thread indices respectively.

    Returns ``(schedule, [a, b, c])`` as produced by ``compute_func``.
    """
    sched, lhs, rhs, out = compute_func(op, itype, ndim)
    fused_axis = sched[out].fuse(*out.op.axis)
    block_ax, thread_ax = sched[out].split(fused_axis, factor=64)
    sched[out].bind(block_ax, tvm.thread_axis('blockIdx.x'))
    sched[out].bind(thread_ax, tvm.thread_axis('threadIdx.x'))
    return sched, [lhs, rhs, out]
73
74
# register binary element-wise logic ops with broadcasting supported
# (iterating the dict directly yields its keys; `.keys()` was redundant)
for op_name in _bin_logic_op_map:
    defop(name='{}_cpu'.format(op_name), op=op_name, **_bin_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name, **_bin_logic_gpu_attrs)(_binary_logic_gpu)
79
80
def _make_scalar_cmp(cmp):
    """Wrap a two-argument comparator into the ``(a, b, *idx)`` signature
    used by the compute lambdas: ``a`` is a tensor, ``b`` a scalar, and the
    indexed element of ``a`` is promoted to ``b.dtype`` before comparing."""
    return lambda a, b, *idx: cmp(a[idx].astype(b.dtype), b)


# Note that `b.dtype` is hard-coded as 'float64' (see
# `_compute_binary_scalar_logic`); we always promote `a`'s elements to it.
_bin_scalar_logic_op_map = {
    'equal_scalar': _make_scalar_cmp(lambda x, y: x == y),
    'not_equal_scalar': _make_scalar_cmp(lambda x, y: x != y),
    'greater_scalar': _make_scalar_cmp(lambda x, y: x > y),
    'less_scalar': _make_scalar_cmp(lambda x, y: x < y),
    'greater_equal_scalar': _make_scalar_cmp(lambda x, y: x >= y),
    'less_equal_scalar': _make_scalar_cmp(lambda x, y: x <= y),
}
91
92
def _compute_binary_scalar_logic(op, dtype, ndim):
    """Build the TVM compute and a default schedule for comparing an
    ``ndim``-dimensional tensor against a scalar.

    The scalar is declared as a float64 var; the op lambdas in
    ``_bin_scalar_logic_op_map`` promote tensor elements to float64 before
    the comparison.

    Returns ``(schedule, a, b, c)`` with placeholder ``a``, scalar var ``b``
    and output tensor ``c``.
    """
    def fresh_shape():
        # Independent symbolic shapes for input and output tensors.
        return [tvm.size_var() for _ in range(ndim)]

    tensor_in = tvm.placeholder(fresh_shape(), name='a', dtype=dtype)
    scalar_in = tvm.var('b', dtype='float64')
    out = tvm.compute(fresh_shape(),
                      lambda *idx: _bin_scalar_logic_op_map[op](tensor_in, scalar_in, *idx),
                      name='c')
    sched = tvm.create_schedule(out.op)
    return sched, tensor_in, scalar_in, out
100
101
# Keyword attributes forwarded to `defop` for the CPU variants of the
# tensor-vs-scalar logic ops registered below. Unlike the tensor-tensor
# attrs above, no 'auto_broadcast' key is set here.
_bin_scalar_logic_cpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'cpu',
    'itype': AllTypes + ['bool'],  # every base dtype plus bool
    'ndim': list(range(6))  # generate kernels for 0-d through 5-d tensors
}

# Same attributes as the CPU variant, but targeting GPU code generation.
_bin_scalar_logic_gpu_attrs = {
    'compute_func': _compute_binary_scalar_logic,
    'target': 'gpu',
    'itype': AllTypes + ['bool'],  # every base dtype plus bool
    'ndim': list(range(6))  # generate kernels for 0-d through 5-d tensors
}
115
116
# register binary element-wise scalar logic ops
# (iterating the dict directly yields its keys; `.keys()` was redundant)
for op_name in _bin_scalar_logic_op_map:
    defop(name='{}_cpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_cpu_attrs)(_binary_logic_cpu)
    defop(name='{}_gpu'.format(op_name), op=op_name,
          **_bin_scalar_logic_gpu_attrs)(_binary_logic_gpu)
123