1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
18"""Argwhere operator"""
19import tvm
20from tvm import hybrid
21
@hybrid.script
def hybrid_argwhere_1d(output_shape, condition):
    """Collect the positions of the non-zero entries of a 1-D tensor.

    Parameters
    ----------
    output_shape : shape
        Shape of the int32 output buffer (``argwhere`` passes the ``.shape``
        of its own ``output_shape`` tensor). NOTE(review): presumably
        (number of non-zero elements, 1) -- confirm with the caller.

    condition : tvm.Tensor
        1-D tensor; an element is selected when it compares ``!= 0``.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds, in column 0, the position of
        the ``k``-th non-zero element of ``condition`` (ascending order).
        Rows past the number of non-zero elements are not written here.
    """
    # Hybrid script accepts only a restricted Python subset, so the scan is
    # spelled out with plain ``for``/``if`` statements.
    out = output_tensor(output_shape, "int32")
    extent0 = condition.shape[0]
    # Next output row to fill; equals the number of hits found so far.
    count = 0
    for idx0 in range(extent0):
        if condition[idx0] != 0:
            out[count, 0] = idx0
            count += 1
    return out
44
@hybrid.script
def hybrid_argwhere_2d(output_shape, condition):
    """Collect the coordinates of the non-zero entries of a 2-D tensor.

    Parameters
    ----------
    output_shape : shape
        Shape of the int32 output buffer (``argwhere`` passes the ``.shape``
        of its own ``output_shape`` tensor). NOTE(review): presumably
        (number of non-zero elements, 2) -- confirm with the caller.

    condition : tvm.Tensor
        2-D tensor; an element is selected when it compares ``!= 0``.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds the (dim0, dim1) coordinates of
        the ``k``-th non-zero element of ``condition`` in row-major order.
        Rows past the number of non-zero elements are not written here.
    """
    # Hybrid script accepts only a restricted Python subset, so the scan is
    # spelled out with plain nested ``for``/``if`` statements.
    out = output_tensor(output_shape, "int32")
    extent0 = condition.shape[0]
    extent1 = condition.shape[1]
    # Next output row to fill; equals the number of hits found so far.
    count = 0
    for idx0 in range(extent0):
        for idx1 in range(extent1):
            if condition[idx0, idx1] != 0:
                out[count, 0] = idx0
                out[count, 1] = idx1
                count += 1
    return out
70
@hybrid.script
def hybrid_argwhere_3d(output_shape, condition):
    """Collect the coordinates of the non-zero entries of a 3-D tensor.

    Parameters
    ----------
    output_shape : shape
        Shape of the int32 output buffer (``argwhere`` passes the ``.shape``
        of its own ``output_shape`` tensor). NOTE(review): presumably
        (number of non-zero elements, 3) -- confirm with the caller.

    condition : tvm.Tensor
        3-D tensor; an element is selected when it compares ``!= 0``.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds the (dim0, dim1, dim2)
        coordinates of the ``k``-th non-zero element of ``condition`` in
        row-major order. Rows past the number of non-zero elements are not
        written here.
    """
    # Hybrid script accepts only a restricted Python subset, so the scan is
    # spelled out with plain nested ``for``/``if`` statements.
    out = output_tensor(output_shape, "int32")
    extent0 = condition.shape[0]
    extent1 = condition.shape[1]
    extent2 = condition.shape[2]
    # Next output row to fill; equals the number of hits found so far.
    count = 0
    for idx0 in range(extent0):
        for idx1 in range(extent1):
            for idx2 in range(extent2):
                if condition[idx0, idx1, idx2] != 0:
                    out[count, 0] = idx0
                    out[count, 1] = idx1
                    out[count, 2] = idx2
                    count += 1
    return out
99
@hybrid.script
def hybrid_argwhere_4d(output_shape, condition):
    """Collect the coordinates of the non-zero entries of a 4-D tensor.

    Parameters
    ----------
    output_shape : shape
        Shape of the int32 output buffer (``argwhere`` passes the ``.shape``
        of its own ``output_shape`` tensor). NOTE(review): presumably
        (number of non-zero elements, 4) -- confirm with the caller.

    condition : tvm.Tensor
        4-D tensor; an element is selected when it compares ``!= 0``.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds the (dim0, ..., dim3)
        coordinates of the ``k``-th non-zero element of ``condition`` in
        row-major order. Rows past the number of non-zero elements are not
        written here.
    """
    # Hybrid script accepts only a restricted Python subset, so the scan is
    # spelled out with plain nested ``for``/``if`` statements.
    out = output_tensor(output_shape, "int32")
    extent0 = condition.shape[0]
    extent1 = condition.shape[1]
    extent2 = condition.shape[2]
    extent3 = condition.shape[3]
    # Next output row to fill; equals the number of hits found so far.
    count = 0
    for idx0 in range(extent0):
        for idx1 in range(extent1):
            for idx2 in range(extent2):
                for idx3 in range(extent3):
                    if condition[idx0, idx1, idx2, idx3] != 0:
                        out[count, 0] = idx0
                        out[count, 1] = idx1
                        out[count, 2] = idx2
                        out[count, 3] = idx3
                        count += 1
    return out
131
@hybrid.script
def hybrid_argwhere_5d(output_shape, condition):
    """Collect the coordinates of the non-zero entries of a 5-D tensor.

    Parameters
    ----------
    output_shape : shape
        Shape of the int32 output buffer (``argwhere`` passes the ``.shape``
        of its own ``output_shape`` tensor). NOTE(review): presumably
        (number of non-zero elements, 5) -- confirm with the caller.

    condition : tvm.Tensor
        5-D tensor; an element is selected when it compares ``!= 0``.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds the (dim0, ..., dim4)
        coordinates of the ``k``-th non-zero element of ``condition`` in
        row-major order. Rows past the number of non-zero elements are not
        written here.
    """
    # Hybrid script accepts only a restricted Python subset, so the scan is
    # spelled out with plain nested ``for``/``if`` statements.
    out = output_tensor(output_shape, "int32")
    extent0 = condition.shape[0]
    extent1 = condition.shape[1]
    extent2 = condition.shape[2]
    extent3 = condition.shape[3]
    extent4 = condition.shape[4]
    # Next output row to fill; equals the number of hits found so far.
    count = 0
    for idx0 in range(extent0):
        for idx1 in range(extent1):
            for idx2 in range(extent2):
                for idx3 in range(extent3):
                    for idx4 in range(extent4):
                        if condition[idx0, idx1, idx2, idx3, idx4] != 0:
                            out[count, 0] = idx0
                            out[count, 1] = idx1
                            out[count, 2] = idx2
                            out[count, 3] = idx3
                            out[count, 4] = idx4
                            count += 1
    return out
166
@tvm.target.generic_func
def argwhere(output_shape, condition):
    """Find the indices of elements of a tensor that are non-zero.

    Parameters
    ----------
    output_shape : tvm.Tensor
        Tensor whose ``.shape`` is used as the shape of the int32 output;
        only its shape is read here. NOTE(review): presumably
        (number of non-zero elements, rank of ``condition``) -- confirm
        against the caller that builds it.

    condition : tvm.Tensor
        Tensor of rank 1 to 5 whose non-zero elements (``!= 0``) are located.

    Returns
    -------
    out : tvm.Tensor
        int32 tensor whose ``k``-th row holds the coordinates of the ``k``-th
        non-zero element of ``condition`` in row-major order, one column per
        dimension.

    Raises
    ------
    ValueError
        If ``condition`` is a scalar (rank 0) or has rank higher than 5.
    """
    # Hybrid script cannot loop over a variable number of dimensions, so
    # there is one hand-unrolled kernel per supported rank.
    kernels = {
        1: hybrid_argwhere_1d,
        2: hybrid_argwhere_2d,
        3: hybrid_argwhere_3d,
        4: hybrid_argwhere_4d,
        5: hybrid_argwhere_5d,
    }
    rank = len(condition.shape)
    kernel = kernels.get(rank)
    if kernel is None:
        if rank < 1:
            # Report scalars explicitly instead of with the misleading
            # "rank higher than 5" message.
            raise ValueError("Does not support rank 0 (scalar) condition in argwhere")
        raise ValueError("Does not support rank higher than 5 in argwhere")
    return kernel(output_shape.shape, condition)
192