# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, too-many-arguments, too-many-nested-blocks
"""Argwhere operator"""
import tvm
from tvm import hybrid

# NOTE: ``output_tensor`` used in the hybrid scripts below is not defined or
# imported in this module; it is presumably an intrinsic supplied by
# ``tvm.hybrid`` when the decorated functions are parsed/executed.
# The hybrid-script bodies only use the restricted Python subset that the
# hybrid parser understands, so keep them in this explicit loop form.


@hybrid.script
def hybrid_argwhere_1d(output_shape, condition):
    """Find the indices of elements of a 1-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : tuple/list of int or Expr
        Shape of the int32 output tensor, as passed to ``output_tensor``.
        The output is written as ``a[row, 0]``, so it must be 2-D, and its
        first dimension must be at least the number of non-zero elements
        in ``condition``.

    condition : tvm.Tensor
        1-D tensor with boolean values (elements are compared against 0).

    Returns
    -------
    out : tvm.Tensor
        int32 tensor of indices of non-zero elements, one per row, in
        increasing order. Rows past the number of non-zero elements are
        not written by this script.
    """
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    # valid_index is both the count of non-zero elements found so far and
    # the next output row to fill.
    valid_index = 0
    for i1 in range(a1):
        if condition[i1] != 0:
            a[valid_index, 0] = i1
            valid_index += 1
    return a


@hybrid.script
def hybrid_argwhere_2d(output_shape, condition):
    """Find the indices of elements of a 2-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : tuple/list of int or Expr
        Shape of the int32 output tensor, as passed to ``output_tensor``.
        The output is written as ``a[row, 0..1]``, so it must be 2-D with at
        least 2 columns, and its first dimension must be at least the number
        of non-zero elements in ``condition``.

    condition : tvm.Tensor
        2-D tensor with boolean values (elements are compared against 0).

    Returns
    -------
    out : tvm.Tensor
        int32 tensor of indices of non-zero elements, one ``(i1, i2)``
        coordinate per row, in row-major order. Rows past the number of
        non-zero elements are not written by this script.
    """
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    # valid_index is both the count of non-zero elements found so far and
    # the next output row to fill.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            if condition[i1, i2] != 0:
                a[valid_index, 0] = i1
                a[valid_index, 1] = i2
                valid_index += 1
    return a


@hybrid.script
def hybrid_argwhere_3d(output_shape, condition):
    """Find the indices of elements of a 3-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : tuple/list of int or Expr
        Shape of the int32 output tensor, as passed to ``output_tensor``.
        The output is written as ``a[row, 0..2]``, so it must be 2-D with at
        least 3 columns, and its first dimension must be at least the number
        of non-zero elements in ``condition``.

    condition : tvm.Tensor
        3-D tensor with boolean values (elements are compared against 0).

    Returns
    -------
    out : tvm.Tensor
        int32 tensor of indices of non-zero elements, one ``(i1, i2, i3)``
        coordinate per row, in row-major order. Rows past the number of
        non-zero elements are not written by this script.
    """
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    # valid_index is both the count of non-zero elements found so far and
    # the next output row to fill.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                if condition[i1, i2, i3] != 0:
                    a[valid_index, 0] = i1
                    a[valid_index, 1] = i2
                    a[valid_index, 2] = i3
                    valid_index += 1
    return a


@hybrid.script
def hybrid_argwhere_4d(output_shape, condition):
    """Find the indices of elements of a 4-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : tuple/list of int or Expr
        Shape of the int32 output tensor, as passed to ``output_tensor``.
        The output is written as ``a[row, 0..3]``, so it must be 2-D with at
        least 4 columns, and its first dimension must be at least the number
        of non-zero elements in ``condition``.

    condition : tvm.Tensor
        4-D tensor with boolean values (elements are compared against 0).

    Returns
    -------
    out : tvm.Tensor
        int32 tensor of indices of non-zero elements, one
        ``(i1, i2, i3, i4)`` coordinate per row, in row-major order. Rows
        past the number of non-zero elements are not written by this script.
    """
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    a4 = condition.shape[3]
    # valid_index is both the count of non-zero elements found so far and
    # the next output row to fill.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                for i4 in range(a4):
                    if condition[i1, i2, i3, i4] != 0:
                        a[valid_index, 0] = i1
                        a[valid_index, 1] = i2
                        a[valid_index, 2] = i3
                        a[valid_index, 3] = i4
                        valid_index += 1
    return a


@hybrid.script
def hybrid_argwhere_5d(output_shape, condition):
    """Find the indices of elements of a 5-D tensor that are non-zero.

    Parameters
    ----------
    output_shape : tuple/list of int or Expr
        Shape of the int32 output tensor, as passed to ``output_tensor``.
        The output is written as ``a[row, 0..4]``, so it must be 2-D with at
        least 5 columns, and its first dimension must be at least the number
        of non-zero elements in ``condition``.

    condition : tvm.Tensor
        5-D tensor with boolean values (elements are compared against 0).

    Returns
    -------
    out : tvm.Tensor
        int32 tensor of indices of non-zero elements, one
        ``(i1, i2, i3, i4, i5)`` coordinate per row, in row-major order.
        Rows past the number of non-zero elements are not written by this
        script.
    """
    a = output_tensor(output_shape, "int32")
    a1 = condition.shape[0]
    a2 = condition.shape[1]
    a3 = condition.shape[2]
    a4 = condition.shape[3]
    a5 = condition.shape[4]
    # valid_index is both the count of non-zero elements found so far and
    # the next output row to fill.
    valid_index = 0
    for i1 in range(a1):
        for i2 in range(a2):
            for i3 in range(a3):
                for i4 in range(a4):
                    for i5 in range(a5):
                        if condition[i1, i2, i3, i4, i5] != 0:
                            a[valid_index, 0] = i1
                            a[valid_index, 1] = i2
                            a[valid_index, 2] = i3
                            a[valid_index, 3] = i4
                            a[valid_index, 4] = i5
                            valid_index += 1
    return a


# Rank of ``condition`` -> hybrid script implementing argwhere for that rank.
_ARGWHERE_BY_RANK = {
    1: hybrid_argwhere_1d,
    2: hybrid_argwhere_2d,
    3: hybrid_argwhere_3d,
    4: hybrid_argwhere_4d,
    5: hybrid_argwhere_5d,
}


@tvm.target.generic_func
def argwhere(output_shape, condition):
    """Find the indices of elements of a tensor that are non-zero.

    Parameters
    ----------
    output_shape : tvm.Tensor
        Tensor whose ``shape`` is used as the shape of the int32 output.
        The output is 2-D: one row per non-zero element and one column per
        dimension of ``condition``.

    condition : tvm.Tensor
        Tensor of rank 1 to 5 with boolean values.

    Returns
    -------
    out : tvm.Tensor
        Indices of non-zero elements, one coordinate per row, in row-major
        order.

    Raises
    ------
    ValueError
        If the rank of ``condition`` is not in the range 1 to 5.
    """
    rank = len(condition.shape)
    impl = _ARGWHERE_BY_RANK.get(rank)
    if impl is None:
        # Covers both rank 0 and rank > 5; the old message claimed only the
        # latter was unsupported.
        raise ValueError(
            "argwhere only supports condition tensors of rank 1 to 5, "
            "got rank {}".format(rank))
    return impl(output_shape.shape, condition)