# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Common topi utilities"""
from __future__ import absolute_import as _abs
from numbers import Integral

import tvm
from tvm.api import layout, bijective_layout
from . import tag

class InvalidShapeError(ValueError):
    """Invalid shape for a topi function (e.g. calling the winograd template for a non-3x3 kernel)."""
    pass

def traverse_inline(s, final_op, callback):
    """Traverse the computation graph and auto-inline injective operations.

    Parameters
    ----------
    s: schedule
        The schedule
    final_op: Operation
        The final output operator.
    callback: callable
        The callback function called on each visited op
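
    Examples
    --------
    A minimal sketch of the usual pattern in a schedule function; ``outs`` and
    the ``"conv2d_nchw"`` tag below are illustrative, not part of this module:

    .. code-block:: python

        s = tvm.create_schedule([x.op for x in outs])

        def _callback(op):
            # schedule only the compute stage we care about
            if op.tag == "conv2d_nchw":
                conv = op.output(0)
                # apply scheduling primitives to s[conv] here

        traverse_inline(s, outs[0].op, _callback)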
    """
    visited = set()

    def _traverse(op):
        if op in visited:
            return
        visited.add(op)
        if tag.is_injective(op.tag):
            if op not in s.outputs:
                s[op].compute_inline()
            for tensor in op.input_tensors:
                if isinstance(tensor.op, tvm.tensor.ComputeOp):
                    _traverse(tensor.op)
        callback(op)

    _traverse(final_op)


def prod(x):
    """Get the product of every item in the tuple.

    Parameters
    ----------
    x: tuple
        Input tuple

    Returns
    -------
    value : Expr
        The result value
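
    Examples
    --------
    A minimal sketch; the symbolic dimension ``n`` below is illustrative:

    .. code-block:: python

        prod((2, 3, 4))       # -> 24
        n = tvm.var("n")
        prod((2, n, 4))       # -> a tvm Expr equivalent to 2 * n * 4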
    """
    if not x:
        return tvm.const(1, "int32")
    res = x[0]
    for i in range(1, len(x)):
        res = res * x[i]
    return res


def get_const_int(expr):
    """Verifies expr is an integer and gets the constant value.

    Parameters
    ----------
    expr : tvm.Expr or int
        The input expression.

    Returns
    -------
    out_value : int
        The output.
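
    Examples
    --------
    A minimal sketch; the placeholder below is illustrative:

    .. code-block:: python

        A = tvm.placeholder((1, 3, 224, 224), name="A")
        get_const_int(A.shape[1])   # -> 3
        get_const_int(7)            # -> 7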
    """
    if isinstance(expr, Integral):
        return expr
    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
        expr = tvm.ir_pass.Simplify(expr)
    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
        raise ValueError("Expect value to be constant int")
    return int(expr.value)


def get_const_float(expr):
    """Verifies expr is a floating point value and gets the constant value.

    Parameters
    ----------
    expr : tvm.Expr or float
        The input expression.

    Returns
    -------
    out_value : float
        The output.
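
    Examples
    --------
    A minimal sketch:

    .. code-block:: python

        get_const_float(2.5)                        # -> 2.5
        get_const_float(tvm.const(2.5, "float32"))  # -> 2.5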
    """
    if isinstance(expr, float):
        return float(expr)
    if not isinstance(expr, tvm.expr.FloatImm):
        expr = tvm.ir_pass.Simplify(expr)
    if not isinstance(expr, tvm.expr.FloatImm):
        raise ValueError("Expect value to be constant float")
    return float(expr.value)


def equal_const_int(expr, value):
    """Returns whether expr equals value.

    Parameters
    ----------
    expr : tvm.Expr
        The input expression.

    value : int
        The value to compare against.

    Returns
    -------
    equal : bool
        Whether they are equal.
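
    Examples
    --------
    A minimal sketch; the placeholder below is illustrative:

    .. code-block:: python

        A = tvm.placeholder((1, 3, 224, 224), name="A")
        equal_const_int(A.shape[1], 3)    # -> True
        equal_const_int(tvm.var("n"), 3)  # -> False (not a constant)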
    """
    if isinstance(expr, Integral):
        return expr == value
    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
        expr = tvm.ir_pass.Simplify(expr)
    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
        return False
    return expr.value == value


def get_const_tuple(in_tuple):
    """Verifies that the input tuple is IntImm or Var, returns a tuple of int or Var.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of int
        The output.
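
    Examples
    --------
    A minimal sketch; the placeholder below is illustrative:

    .. code-block:: python

        A = tvm.placeholder((1, 3, 224, 224), name="A")
        get_const_tuple(A.shape)   # -> (1, 3, 224, 224)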
    """
    ret = []
    for elem in in_tuple:
        if isinstance(elem, tvm.expr.Var):
            ret.append(elem)
        elif not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm, int)):
            elem = tvm.ir_pass.Simplify(elem)
            if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)):
                ret.append(elem)
            else:
                # keep elements that simplify to a constant instead of dropping them
                ret.append(get_const_int(elem))
        else:
            ret.append(get_const_int(elem))
    return tuple(ret)


def get_float_tuple(in_tuple):
    """Verifies that the input tuple is FloatImm, returns a tuple of float.

    Parameters
    ----------
    in_tuple : tuple of Expr
        The input.

    Returns
    -------
    out_tuple : tuple of float
        The output.
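
    Examples
    --------
    A minimal sketch:

    .. code-block:: python

        scales = (tvm.const(0.5, "float32"), tvm.const(2.0, "float32"))
        get_float_tuple(scales)   # -> (0.5, 2.0)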
    """
    return tuple(get_const_float(elem) for elem in in_tuple)


def simplify(expr):
    """Simplify the expression if it is an Expr, return it directly if it is an int.

    Parameters
    ----------
    expr : Expr or int
        The input.

    Returns
    -------
    out : Expr or int
        The simplified output
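
    Examples
    --------
    A minimal sketch; the symbolic dimension ``n`` below is illustrative:

    .. code-block:: python

        n = tvm.var("n")
        simplify(n * 1 + 0)   # -> a tvm Expr, simplified where possible
        simplify(8)           # -> 8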
    """
    return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr


def ravel_index(indices, shape):
    """Flatten the index tuple to 1D

    Parameters
    ----------
    indices : tuple of int or tvm.expr.IntImm
        The input coordinates

    shape : tuple of int
        Shape of the tensor.

    Returns
    -------
    idx : int or Expr
        The index after flattening
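
    Examples
    --------
    A minimal sketch with plain Python ints (row-major flattening):

    .. code-block:: python

        ravel_index((1, 2), (3, 4))   # -> 1 * 4 + 2 = 6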
    """
    idx = None
    for i, (shape_val, ind) in enumerate(zip(shape, indices)):
        if i != 0:
            idx = idx * shape_val + ind
        else:
            idx = ind
    return idx


def unravel_index(idx, shape):
    """Convert a flattened index to the coordinate array

    Parameters
    ----------
    idx : int or tvm.expr.IntImm
        The 1D index

    shape : tuple of int
        Shape of the tensor

    Returns
    -------
    indices : list of int or tvm.expr.IntImm
        Corresponding coordinate of the 1D index
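
    Examples
    --------
    A minimal sketch with plain Python ints:

    .. code-block:: python

        unravel_index(6, (3, 4))   # -> indices equivalent to [1, 2]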
    """
    idxd = tvm.indexdiv
    idxm = tvm.indexmod
    indices = []
    for i in range(len(shape) - 1, -1, -1):
        indices.append(idxm(idx, shape[i]))
        idx = idxd(idx, shape[i])
    indices = indices[::-1]
    return indices


def const_matrix(matrix, name="const_matrix"):
    """Convert a const numpy 2-dimensional matrix to a tvm tensor.

    Parameters
    ----------
    matrix: numpy.ndarray
        Const input array
    name: str, optional
        The name of output op

    Returns
    -------
    tensor: Tensor
        The created tensor
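
    Examples
    --------
    A minimal sketch; the 2x2 identity matrix below is illustrative:

    .. code-block:: python

        import numpy as np
        M = const_matrix(np.array([[1.0, 0.0],
                                   [0.0, 1.0]], dtype="float32"), name="M")
        # M is a (2, 2) Tensor whose values are baked into its compute rule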
    """
    row, col = matrix.shape
    dtype = str(matrix.dtype)
    idxm = tvm.indexmod

    def select_array(i, j):
        now = tvm.const(0.0, dtype)
        for ii in range(row):
            for jj in range(col):
                now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj),
                                      tvm.const(matrix[ii][jj], dtype),
                                      now)
        return now

    return tvm.compute(matrix.shape, select_array, name=name)


def get_max_power2_factor(n, max_value=None):
    """Get the max factor of n that is a power of 2. If max_value is specified,
    the factor will be no more than max_value.

    Parameters
    ----------
    n : int
        The input value

    max_value : int, optional
        The max value for the factor

    Returns
    -------
    factor : int
        The max factor in power of 2.
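
    Examples
    --------
    A minimal sketch:

    .. code-block:: python

        get_max_power2_factor(48)      # -> 16, since 48 = 16 * 3
        get_max_power2_factor(48, 8)   # -> 8, capped by max_value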
    """
    x = 1
    while n % 2 == 0:
        if max_value is not None and max_value < x * 2:
            break
        x *= 2
        n //= 2
    return x


def get_shape(src_shape, src_layout, dst_layout):
    """Given a source shape, a source layout and a destination layout, infer
    the destination shape.

    Parameters
    ----------
    src_shape : tuple of int or IntImm
        Source shape

    src_layout : str or Layout
        Source layout

    dst_layout : str or Layout
        Destination layout

    Returns
    -------
    dst_shape : tuple of int
        Destination shape
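
    Examples
    --------
    A minimal sketch converting an NCHW shape to NHWC:

    .. code-block:: python

        get_shape((1, 3, 224, 224), "NCHW", "NHWC")   # -> (1, 224, 224, 3)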
    """
    if src_layout == dst_layout:
        return get_const_tuple(src_shape)

    if isinstance(src_layout, str):
        src_layout = layout(src_layout)
    if isinstance(dst_layout, str):
        dst_layout = layout(dst_layout)

    assert len(src_layout) == len(dst_layout), \
        "Incompatible layout %s vs %s" % (src_layout, dst_layout)

    layout_mapping = bijective_layout(src_layout, dst_layout)
    dst_indices = layout_mapping.forward_index(
        tvm.convert([i for i in range(len(src_layout))]))

    return get_const_tuple(tuple([src_shape[i.value] for i in dst_indices]))