1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17# pylint: disable=invalid-name 18"""Common topi utilities""" 19from __future__ import absolute_import as _abs 20from numbers import Integral 21 22import tvm 23from tvm.api import layout, bijective_layout 24from . import tag 25 26class InvalidShapeError(ValueError): 27 """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" 28 pass 29 30def traverse_inline(s, final_op, callback): 31 """Traverse computation graph and do auto inline 32 33 Parameters 34 ---------- 35 s: schedule 36 The schedule 37 final_op: Operation 38 The final output operator. 39 callback: callable 40 The callback function on each op 41 """ 42 visited = set() 43 44 def _traverse(op): 45 if op in visited: 46 return 47 visited.add(op) 48 if tag.is_injective(op.tag): 49 if op not in s.outputs: 50 s[op].compute_inline() 51 for tensor in op.input_tensors: 52 if isinstance(tensor.op, tvm.tensor.ComputeOp): 53 _traverse(tensor.op) 54 callback(op) 55 56 _traverse(final_op) 57 58 59def prod(x): 60 """Get the product of every items in the tuple. 61 62 Parameters 63 ---------- 64 x: tuple 65 Input tuple 66 67 Returns 68 ------- 69 value : Expr 70 The result value 71 """ 72 if not x: 73 return tvm.const(1, "int32") 74 res = x[0] 75 for i in range(1, len(x)): 76 res = res * x[i] 77 return res 78 79 80def get_const_int(expr): 81 """Verifies expr is integer and get the constant value. 82 83 Parameters 84 ---------- 85 expr : tvm.Expr or int 86 The input expression. 87 88 Returns 89 ------- 90 out_value : int 91 The output. 92 """ 93 if isinstance(expr, Integral): 94 return expr 95 if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): 96 expr = tvm.ir_pass.Simplify(expr) 97 if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): 98 raise ValueError("Expect value to be constant int") 99 return int(expr.value) 100 101 102def get_const_float(expr): 103 """Verifies expr is a floating point and get the constant value. 104 105 Parameters 106 ---------- 107 expr : tvm.Expr or float 108 The input expression. 109 110 Returns 111 ------- 112 out_value : float 113 The output. 114 """ 115 if isinstance(expr, float): 116 return float(expr) 117 if not isinstance(expr, tvm.expr.FloatImm): 118 expr = tvm.ir_pass.Simplify(expr) 119 if not isinstance(expr, tvm.expr.FloatImm): 120 raise ValueError("Expect value to be constant float") 121 return float(expr.value) 122 123 124def equal_const_int(expr, value): 125 """Returns if expr equals value. 126 127 Parameters 128 ---------- 129 expr : tvm.Expr 130 The input expression. 131 132 Returns 133 ------- 134 equal : bool 135 Whether they equals. 136 """ 137 if isinstance(expr, Integral): 138 return expr == value 139 if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): 140 expr = tvm.ir_pass.Simplify(expr) 141 if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): 142 return False 143 return expr.value == value 144 145 146def get_const_tuple(in_tuple): 147 """Verifies input tuple is IntImm or Var, returns tuple of int or Var. 148 149 Parameters 150 ---------- 151 in_tuple : tuple of Expr 152 The input. 153 154 Returns 155 ------- 156 out_tuple : tuple of int 157 The output. 158 """ 159 ret = [] 160 for elem in in_tuple: 161 if isinstance(elem, tvm.expr.Var): 162 ret.append(elem) 163 elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)): 164 elem = tvm.ir_pass.Simplify(elem) 165 if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): 166 ret.append(elem) 167 else: 168 ret.append(get_const_int(elem)) 169 return tuple(ret) 170 171 172def get_float_tuple(in_tuple): 173 """Verifies input tuple is FloatImm, returns tuple of float. 174 175 Parameters 176 ---------- 177 in_tuple : tuple of Expr 178 The input. 179 180 Returns 181 ------- 182 out_tuple : tuple of float 183 The output. 184 """ 185 return tuple(get_const_float(elem) for elem in in_tuple) 186 187 188def simplify(expr): 189 """Simplify the expression if it is Expr, directly return if it is int. 190 191 Parameters 192 ---------- 193 expr : Expr or int 194 The input. 195 196 Returns 197 ------- 198 out : Expr or int 199 The simplified output 200 """ 201 return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr 202 203 204def ravel_index(indices, shape): 205 """Flatten the index tuple to 1D 206 207 Parameters 208 ---------- 209 indices : tuple of int or tvm.expr.IntImm 210 The input coordinates 211 212 shape : tuple of int 213 Shape of the tensor. 214 215 Returns 216 ------- 217 idx : int or Expr 218 The index after flattening 219 """ 220 idx = None 221 for i, (shape_val, ind) in enumerate(zip(shape, indices)): 222 if i != 0: 223 idx = idx * shape_val + ind 224 else: 225 idx = ind 226 return idx 227 228 229def unravel_index(idx, shape): 230 """Convert the flattened ind to the coordinate array 231 232 Parameters 233 ---------- 234 idx : int or tvm.expr.IntImm 235 The 1D index 236 237 shape : tuple of int 238 Shape of the tensor 239 240 Returns 241 ------- 242 indices : tuple of int or tvm.expr.IntImm 243 Corresponding coordinate of the 1D index 244 """ 245 idxd = tvm.indexdiv 246 idxm = tvm.indexmod 247 indices = [] 248 for i in range(len(shape) - 1, -1, -1): 249 indices.append(idxm(idx, shape[i])) 250 idx = idxd(idx, shape[i]) 251 indices = indices[::-1] 252 return indices 253 254 255def const_matrix(matrix, name="const_matrix"): 256 """convert a const numpy 2-dimensional matrix to tvm tensor 257 258 Parameters 259 ---------- 260 matrix: numpy.ndarray 261 Const input array 262 name: str, optional 263 The name of output op 264 265 Returns 266 ------- 267 tensor: Tensor 268 The created tensor 269 """ 270 row, col = matrix.shape 271 dtype = str(matrix.dtype) 272 idxm = tvm.indexmod 273 274 def select_array(i, j): 275 now = tvm.const(0.0, dtype) 276 for ii in range(row): 277 for jj in range(col): 278 now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj), 279 tvm.const(matrix[ii][jj], dtype), 280 now) 281 return now 282 283 return tvm.compute(matrix.shape, select_array, name=name) 284 285 286def get_max_power2_factor(n, max_value=None): 287 """Get max factor of n in power of 2. If max_value is specificed, max factor 288 value will be no more max_value, 289 290 Parameter 291 --------- 292 n : int 293 The input value 294 295 max_value : int, optional 296 The max value for the factor 297 298 Returns 299 ------- 300 factor : int 301 The max factor in power of 2. 302 """ 303 x = 1 304 while n % 2 == 0: 305 if max_value is not None and max_value < x * 2: 306 break 307 x *= 2 308 n /= 2 309 return x 310 311 312def get_shape(src_shape, src_layout, dst_layout): 313 """Given a source shape, a source layout and a destination layout, infer 314 the destination shape. 315 316 Parameter 317 --------- 318 src_shape : tuple of int or IntImm 319 Source shape 320 321 src_layout : str or Layout 322 Source layout 323 324 dst_layout : str or Layout 325 Destination layout 326 327 Returns 328 ------- 329 dst_shape : tuple of int 330 Destination shape 331 """ 332 if src_layout == dst_layout: 333 return get_const_tuple(src_shape) 334 335 if isinstance(src_layout, str): 336 src_layout = layout(src_layout) 337 if isinstance(dst_layout, str): 338 dst_layout = layout(dst_layout) 339 340 assert len(src_layout) == len(dst_layout), \ 341 "Incompatible layout %s vs %s" % (src_layout, dst_layout) 342 343 layout_mapping = bijective_layout(src_layout, dst_layout) 344 dst_indices = layout_mapping.forward_index( 345 tvm.convert([i for i in range(len(src_layout))])) 346 347 return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices])) 348