1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=no-else-return
18# pylint: disable=unidiomatic-typecheck
19"""
This file contains the set of analysis passes for Relay, which exposes an
interface for configuring the passes and scripting them in Python.
22"""
23from tvm.ir import IRModule
24from tvm.relay import transform, build_module
25from tvm.runtime.ndarray import cpu
26
27from . import _ffi_api
28from .feature import Feature
29
30
def context_analysis(mod, default_context):
    """Run device-context analysis over every IR node of a Relay program.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose IR nodes are analyzed.

    default_context : tvm.runtime.TVMContext
        Context assigned to an IR node when no other context applies.

    Returns
    -------
    result
        Whatever the ``ContextAnalysis`` FFI entry point produces
        (presumably a mapping from IR node to context -- confirm against
        the C++ implementation).
    """
    analysis = _ffi_api.ContextAnalysis
    return analysis(mod, default_context)
44
45
def post_order_visit(expr, fvisit):
    """Walk the IR in post-order DFS and apply ``fvisit`` to each node.

    Every node is visited at most once, even if it is shared.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Root expression of the traversal.

    fvisit : function
        Callback invoked on each visited node.
    """
    visit = _ffi_api.post_order_visit
    return visit(expr, fvisit)
60
61
def well_formed(expr):
    """Report whether every Var in ``expr`` is bound exactly once.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to validate.

    Returns
    -------
    well_form : bool
        True when the expression is well formed.
    """
    checker = _ffi_api.well_formed
    return checker(expr)
76
77
def check_kind(t, mod=None):
    """Check that a type is well kinded and return its kind.

    A type is ill kinded when, for example, it is a tensor of tensors, or a
    tuple type whose fields are shapes rather than types.

    Parameters
    ----------
    t : tvm.relay.Type
        The type to check.

    mod : Optional[tvm.IRModule]
        The global module. When omitted, the FFI checker is called with
        the type only.

    Returns
    -------
    kind : tvm.relay.TypeKind
        The kind of ``t``.

    Examples
    --------
    .. code:: python

        tp = relay.TypeVar("tp1", relay.TypeKind.ShapeVar)
        assert check_kind(tp) == relay.TypeKind.ShapeVar

        tup = relay.TupleType([relay.TypeVar("tp2", relay.TypeKind.Type)])
        assert check_kind(tup) == relay.TypeKind.Type
    """
    # The FFI entry point accepts an optional module, so only forward it
    # when the caller actually supplied one.
    if mod is None:
        return _ffi_api.check_kind(t)
    return _ffi_api.check_kind(t, mod)
107
108
def check_constant(expr):
    """Determine whether ``expr`` is a constant expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to inspect.

    Returns
    -------
    result : bool
        True if the expression is constant.
    """
    is_const = _ffi_api.check_constant
    return is_const(expr)
123
124
def check_basic_block_normal_form(expr):
    """Determine whether ``expr`` is in basic block normal form.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to inspect.

    Returns
    -------
    result : bool
        True if the expression is in basic block normal form.
    """
    checker = _ffi_api.check_basic_block_normal_form
    return checker(expr)
139
140
def free_vars(expr):
    """Collect the free Vars of ``expr`` in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to scan.

    Returns
    -------
    free : List[tvm.relay.Var]
        Free variables, ordered by post-DFS traversal.

    Note
    ----
    Post-DFS ordering is convenient for neural networks: weights used by
    earlier layers usually appear first in the result.
    """
    collector = _ffi_api.free_vars
    return collector(expr)
161
162
def bound_vars(expr):
    """Collect the bound Vars of ``expr`` in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to scan.

    Returns
    -------
    bound : List[tvm.relay.Var]
        Bound variables, ordered by post-DFS traversal.
    """
    collector = _ffi_api.bound_vars
    return collector(expr)
177
178
def all_vars(expr):
    """Collect every Var (free and bound) of ``expr`` in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to scan.

    Returns
    -------
    all_vars : List[tvm.relay.Var]
        All variables, ordered by post-DFS traversal.
    """
    collector = _ffi_api.all_vars
    return collector(expr)
193
194
def free_type_vars(expr, mod=None):
    """Collect the free type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        Expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    free : List[tvm.relay.TypeVar]
        Free type variables in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.free_type_vars(expr, mod)
213
214
def bound_type_vars(expr, mod=None):
    """Collect the bound type variables of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        Expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    bound : List[tvm.relay.TypeVar]
        Bound type variables in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.bound_type_vars(expr, mod)
233
234
def all_type_vars(expr, mod=None):
    """Collect every type variable (free and bound) of an expression or type.

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        Expression or type to scan.

    mod : Optional[tvm.IRModule]
        The global module; a fresh empty ``IRModule`` is used when omitted.

    Returns
    -------
    all_type_vars : List[tvm.relay.TypeVar]
        All type variables in post-DFS order.
    """
    if mod is None:
        mod = IRModule()
    return _ffi_api.all_type_vars(expr, mod)
253
254
def all_dtypes(expr):
    """Return the set of data types appearing anywhere in ``expr``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to scan.

    Returns
    -------
    ret : Set[String]
        Data type names used by the expression, e.g. ``{'int8', 'int32'}``.
    """
    # The FFI returns a container; convert it to a Python set so
    # duplicates collapse and membership tests are cheap.
    dtypes = _ffi_api.all_dtypes(expr)
    return set(dtypes)
269
270
def collect_device_info(expr):
    """Build the device allocation map for ``expr``.

    Device ids are propagated from the ``device_copy`` operators.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to analyze.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        Mapping from each expression to its device type.
    """
    collector = _ffi_api.CollectDeviceInfo
    return collector(expr)
286
287
def collect_device_annotation_ops(expr):
    """Gather the device annotation ops contained in ``expr``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to analyze.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        Mapping from each annotation expression to its device type.
    """
    collector = _ffi_api.CollectDeviceAnnotationOps
    return collector(expr)
303
304
def get_total_mac_number(expr):
    """Count the multiply-accumulate (MAC) operations in a model.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression describing the model.

    Returns
    -------
    result : int64
        Total number of MACs in the model.
    """
    counter = _ffi_api.GetTotalMacNumber
    return counter(expr)
320
321
def unmatched_cases(match, mod=None):
    """
    Finds cases that the match expression does not catch, if any.

    Parameters
    ----------
    match : tvm.relay.Match
        The match expression

    mod : Optional[tvm.IRModule]
        The module (defaults to an empty module)

    Returns
    -------
    missing_patterns : [tvm.relay.Pattern]
        Patterns that the match expression does not catch.
    """
    # Honor the documented default explicitly, the same way the
    # *_type_vars helpers do, instead of forwarding a raw None to the FFI.
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.unmatched_cases(match, use_mod)
340
341
def detect_feature(a, b=None):
    """
    Detect which features a Relay program uses.

    Parameters
    ----------
    a : Union[tvm.relay.Expr, tvm.IRModule]
      Expression or module to inspect.

    b : Optional[Union[tvm.relay.Expr, tvm.IRModule]]
      Expression or module to inspect.
      The two arguments must not both be expressions or both be modules.

    Returns
    -------
    features : Set[Feature]
      Features used by the program.
    """
    # The FFI expects (expression, module); accept either argument order.
    expr, module = (b, a) if isinstance(a, IRModule) else (a, b)
    detected = _ffi_api.detect_feature(expr, module)
    return {Feature(int(code)) for code in detected}
363
364
def extract_fused_functions(mod):
    """Extract only the fused primitive functions of a module.

    The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
    and ExtractFusedFunctions in that order.

    Parameters
    ----------
    mod : tvm.relay.IRModule
        The module to run the extraction pass on.

    Returns
    -------
    ret : Dict[tvm.ir.GlobalVar, tvm.relay.function.Function]
        A plain dictionary copy of the extracted module's ``functions`` map,
        containing only fused primitive functions. The keys are the module's
        global function keys (presumably GlobalVars named after each
        function's hash -- confirm against the C++ pass).
    """
    ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
    # Copy the module's function map into a regular Python dict.
    return dict(ret_mod.functions.items())
385
386
def search_fc_transpose(expr):
    """Find fc weight names matching ``y = nn.dense(x, transpose(w, [1, 0]))``.

    Used by ``data_dep_optimization.simplify_fc_transpose``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression to search.

    Returns
    -------
    ret : Array[String]
        Names of the weight variables ``w`` that appear in the pattern
        ``y = nn.dense(x, transpose(w, [1, 0]))``.
    """
    return _ffi_api.search_fc_transpose(expr)
403
404
def get_calibration_data(mod, data):
    """Get the calibration data of a given relay graph.

    The module is run on the graph runtime (on ``cpu(0)``) to capture the
    input and output values of each function. Results are keyed by each
    function's GlobalVar; each value is a dict with keys ``"inputs"`` and
    ``"outputs"``.

    Limitations:
    1. The input module (graph) cannot have control flows.
    2. The input arguments of each function cannot be tuples (outputs can be tuples).
    3. Only top-level functions are handled (nested functions are not).
    4. Only functions with the `Compiler` attribute set are handled.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module for collecting the calibration data.

    data : Dict[str, NDArray]
        Inputs for running the module, passed as keyword arguments.

    Returns
    -------
    data : Dict[tvm.relay.GlobalVar, Dict[str, NDArray]]
        For each function, the slices of the executor's outputs that
        correspond to its inputs and outputs.
    """
    # Each entry is (offset, #inputs, #outputs) into the flattened results.
    output_map = _ffi_api.get_calibrate_output_map(mod)

    calib_mod = transform.Inline()(_ffi_api.get_calibrate_module(mod))
    executor = build_module.create_executor("graph", mod=calib_mod, ctx=cpu(0))
    results = executor.evaluate()(**data)

    calib_data = {}
    for gvar, indices in output_map.items():
        start = int(indices[0])
        num_inputs = int(indices[1])
        num_outputs = int(indices[2])
        # Inputs come first in the flattened results, followed by outputs.
        split = start + num_inputs
        calib_data[gvar] = {
            "inputs": results[start:split],
            "outputs": results[split : split + num_outputs],
        }

    return calib_data
451