# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools

import tvm
from tvm._ffi.runtime_ctypes import TVMContext
from tvm import relay
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd


@register_relay_node
class PassInfo(RelayNode):
    """This class contains the meta data required by a pass. It is the
    container of information needed to run an optimization or analysis.
    This class can be extended by adding new members when more meta data
    is needed.

    Parameters
    ----------
    opt_level : int
        The optimization level of this pass.

    name : str
        The pass name.

    required : List[str]
        The list of passes that are required by a certain pass.
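
    Examples
    --------
    A minimal sketch of constructing the meta data directly ("MyPass" is
    just a placeholder name used for illustration):

    .. code-block:: python

        info = relay.transform.PassInfo(opt_level=2, name="MyPass",
                                        required=["InferType"])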
    """

    def __init__(self, opt_level, name, required=None):
        self.__init_handle_by_constructor__(
            _transform.PassInfo, opt_level, name, required)


@register_relay_node
class PassContext(RelayNode):
    """The basis on which a Relay optimization/analysis runs.
    Each pass context carries auxiliary information used by optimization
    passes, such as the error reporter that records errors raised during
    the optimization.

    Parameters
    ----------
    opt_level : Optional[int]
        The optimization level of this pass.

    fallback_device : Optional[Union[int, str, TVMContext]]
        The fallback device type. It is also used as the default device for
        operators that are not annotated during heterogeneous execution.

    required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are required by a certain pass.

    disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are disabled.
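
    Examples
    --------
    A minimal sketch of scoping optimizations with a context (assumes a
    Relay module ``mod`` and a pass ``custom_pass`` already exist):

    .. code-block:: python

        with relay.transform.PassContext(opt_level=3,
                                         disabled_pass=["AlterOpLayout"]):
            # passes invoked inside this block observe the context
            mod = custom_pass(mod)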
    """
    def __init__(self,
                 opt_level=2,
                 fallback_device=_nd.cpu(),
                 required_pass=None,
                 disabled_pass=None):
        if isinstance(fallback_device, str):
            fallback_device = _nd.context(fallback_device).device_type
        elif isinstance(fallback_device, TVMContext):
            fallback_device = fallback_device.device_type
        if not isinstance(fallback_device, int):
            raise TypeError("fallback_device is expected to be the type of " +
                            "int/str/TVMContext.")

        if required_pass and not isinstance(required_pass, (list, tuple, set)):
            raise TypeError("required_pass is expected to be the type of " +
                            "list/tuple/set.")
        required = list(required_pass) if required_pass else []

        if disabled_pass and not isinstance(disabled_pass, (list, tuple, set)):
            raise TypeError("disabled_pass is expected to be the type of " +
                            "list/tuple/set.")
        disabled = list(disabled_pass) if disabled_pass else []

        self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
                                            fallback_device, required,
                                            disabled)

    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()


def build_config(opt_level=2,
                 fallback_device=_nd.cpu(),
                 required_pass=None,
                 disabled_pass=None):
    """Configure the build behavior by setting config variables.

    Parameters
    ----------
    opt_level: int, optional
        Optimization level. The optimization level at which each pass is
        enabled is listed below:

        .. code-block:: python

            OPT_PASS_LEVEL = {
                "SimplifyInference": 0,
                "OpFusion": 1,
                "FoldConstant": 2,
                "FoldScaleAxis": 3,
                "AlterOpLayout": 3,
                "CanonicalizeOps": 3,
                "CanonicalizeCast": 3,
                "EliminateCommonSubexpr": 3,
                "CombineParallelConv2D": 4,
                "CombineParallelDense": 4
            }

    fallback_device : int, str, or tvm.TVMContext, optional
        The fallback device. It is also used as the default device for
        operators without specified device during heterogeneous execution.

    required_pass: set of str, optional
        Optimization passes that are required regardless of optimization level.

    disabled_pass: set of str, optional
        Optimization passes to be disabled during optimization.

    Returns
    -------
    pass_context: PassContext
        The pass context for optimizations.
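
    Examples
    --------
    A minimal sketch of how the returned context is typically used
    (``mod``, ``target`` and ``params`` are assumed to be defined elsewhere):

    .. code-block:: python

        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
            graph, lib, params = relay.build(mod, target, params=params)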
    """
    return PassContext(opt_level, fallback_device, required_pass,
                       disabled_pass)


@register_relay_node
class Pass(RelayNode):
    """The base class of all passes. All methods here are just simple wrappers
    that are implemented in the backend. They are defined for users to
    conveniently interact with the base class.
    """

    @property
    def info(self):
        """Get the pass meta."""
        return _transform.Info(self)

    def __call__(self, mod):
        """Execute the pass. Note that for a sequential pass, the dependency
        among different passes will be resolved in the backend.

        Parameters
        ----------
        mod : tvm.relay.Module
            The module that a certain optimization is performed on.

        Returns
        -------
        mod : tvm.relay.Module
            The updated module after applying this pass.
        """
        return _transform.RunPass(self, mod)


@register_relay_node
class ModulePass(Pass):
    """A pass that works on tvm.relay.Module. Users don't need to interact with
    this class directly. Instead, a module pass should be created through
    `module_pass`, because the design of the `module_pass` API is flexible
    enough to handle the creation of a module pass in different manners. In
    addition, all members of a module pass can be accessed from the base class.
    The same rule applies to FunctionPass as well.
    """


@register_relay_node
class FunctionPass(Pass):
    """A pass that works on each tvm.relay.Function in a module. A function
    pass class should be created through `function_pass`.
    """


@register_relay_node
class Sequential(Pass):
    """A pass that works on a sequence of pass objects. Multiple passes can be
    executed sequentially using this class.

    Typical usages of the sequential pass are:

    1. Users provide a list of passes for optimization.
    2. Only an optimization level is provided so that the backend system has
       to glob all passes at this level and below to perform the optimizations.

    Note that users can also provide a series of passes that they don't want
    to apply when running a sequential pass. Pass dependency will be resolved
    in the backend as well.

    Parameters
    ----------
    passes : Optional[List[Pass]]
        A sequence of candidate passes for optimization.

    opt_level : Optional[int]
        The optimization level of this sequential pass.

    name : Optional[str]
        The name of the sequential pass.

    required : Optional[List[str]]
        The list of passes that the sequential pass is dependent on.
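
    Examples
    --------
    A minimal sketch of composing and running a pipeline (assumes a Relay
    module ``mod`` is available):

    .. code-block:: python

        seq = relay.transform.Sequential([
            relay.transform.SimplifyInference(),
            relay.transform.FoldConstant(),
        ])
        with relay.transform.PassContext(opt_level=3):
            mod = seq(mod)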
    """

    def __init__(self,
                 passes=None,
                 opt_level=2,
                 name="sequential",
                 required=None):
        passes = passes if passes else []
        if not isinstance(passes, (list, tuple)):
            raise TypeError("passes must be a list of Pass objects.")

        required = required if required else []
        if not isinstance(required, (list, tuple)):
            raise TypeError("Required is expected to be the type of list/tuple.")

        self.__init_handle_by_constructor__(_transform.Sequential,
                                            passes, opt_level, name, required)


def InferType():
    """Infer the type of an expr.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered type inference pass.
    """
    return _transform.InferType()


def FoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense. This pass will
    invoke both forward and backward scale folding.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to fold expressions.

    Note
    ----
    Internally, we will call backward_fold_scale_axis before using
    forward_fold_scale_axis as backward folding targets the common conv->bn
    pattern.
    """
    return _transform.FoldScaleAxis()


def BackwardFoldScaleAxis():
    """Backward fold axis scaling into weights of conv2d/dense.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to backward fold expressions.

    Note
    ----
    It is recommended to call backward_fold_scale_axis
    before using forward_fold_scale_axis as backward folding targets the common
    conv->bn pattern.
    """
    return _transform.BackwardFoldScaleAxis()


def RemoveUnusedFunctions(entry_functions=None):
    """Remove unused global relay functions in a relay module.

    Parameters
    ----------
    entry_functions: list[str]
        The list of entry functions to start from.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to remove unused functions.
    """
    if entry_functions is None:
        entry_functions = ['main']
    return _transform.RemoveUnusedFunctions(entry_functions)


def ForwardFoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to forward fold expressions.

    Note
    ----
    It is recommended to call backward_fold_scale_axis
    before using forward_fold_scale_axis, as backward folding targets the
    common conv->bn pattern.
    """
    return _transform.ForwardFoldScaleAxis()


def SimplifyInference():
    """Simplify the data-flow graph for inference phase. A simplified expression
    which is semantically equal to the input expression will be returned.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass to perform operator simplification.
    """
    return _transform.SimplifyInference()


def CanonicalizeOps():
    """Canonicalize special operators to basic operators.
    This can simplify subsequent analysis, e.g. expanding bias_add to
    expand_dims and broadcast_add.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass performing the canonicalization.
    """
    return _transform.CanonicalizeOps()


def DeadCodeElimination(inline_once=False):
    """Remove expressions that do not have any users (dead code).

    Parameters
    ----------
    inline_once: Optional[bool]
        Whether to inline bindings that occur only once.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that eliminates the dead code in a Relay program.
    """
    return _transform.DeadCodeElimination(inline_once)


def FoldConstant():
    """Fold the constant expressions in a Relay program.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for constant folding.
    """
    return _transform.FoldConstant()


def FuseOps(fuse_opt_level=-1):
    """Fuse operators in an expr to a larger operator according to some rules.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from the pass context.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for operator fusion.
    """
    return _transform.FuseOps(fuse_opt_level)


def CombineParallelConv2D(min_num_branches=3):
    """Combine multiple conv2d operators into one.

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that combines parallel conv2d operators.
    """
    return _transform.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
    """Combine multiple dense operators into one. For example:

                data
          /              \
     dense (2,2)         dense (2,2)
         |                 |
    elemwise/bcast (2,2)  elemwise/bcast (2,2)

    Would become:

             data
              |
        batch_matmul+elemwise/bcast (2,2,2)

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that combines parallel dense operators.
    """
    return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
    """Alter the layouts of operators or replace primitive operators with
    other expressions.
    This pass can be used for computing convolution in custom layouts or
    other general weight pre-transformation.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that alters the layout of operators.
    """
    return _transform.AlterOpLayout()


def Legalize(legalize_map_attr_name="FTVMLegalize"):
    """Legalizes an expression with another expression.
    This pass can be used to replace an expr with another expr for target
    dependent optimizations. For example, one expr, though semantically
    equivalent to the other, can have better performance on a target. This pass
    can be used to legalize the expr in a target-dependent manner.

    Parameters
    ----------
    legalize_map_attr_name : str
        The Op's attr name which corresponds to the legalize rule function.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that rewrites an expr.
    """
    return _transform.Legalize(legalize_map_attr_name)


def RewriteAnnotatedOps(fallback_device):
    """Rewrite the annotated program where annotation operators, e.g.
    `on_device`, mark which device an expression should be scheduled to.
    This pass helps heterogeneous execution where different operators may need
    to be allocated on various devices.

    Parameters
    ----------
    fallback_device : int
        The fallback device type. It is also used as the default device for
        operators with no annotated device.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that rewrites an expression with annotated
        `on_device` operators.
    """
    return _transform.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
    """Turn a Graph Normal Form expression into an A Normal Form expression.
    The scope of the root expression is the global scope.
    The scope of any non-root expression is the least common ancestor of all of
    its usages.
    Values are ordered by post-DFS order in each scope.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that transforms an expression into A Normal Form.
    """
    return _transform.ToANormalForm()


def ToCPS(expr, mod=None):
    """
    Turn an expression into continuation passing style (CPS).

    Every intermediate compute will be passed to a continuation.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    result : tvm.relay.Expr
        The expression in CPS form.
    """
    return _transform.to_cps(expr, mod)


def EtaExpand(expand_constructor=False, expand_global_var=False):
    """Add abstraction over a constructor or global variable bound to a function.

    Parameters
    ----------
    expand_constructor: bool
        Whether to expand constructors.

    expand_global_var: bool
        Whether to expand global variables.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that eta expands an expression.
    """
    return _transform.EtaExpand(expand_constructor, expand_global_var)


def ToGraphNormalForm():
    """Turn a Relay program in A Normal Form into Graph Normal Form.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that transforms an expression into Graph Normal Form.
    """
    return _transform.ToGraphNormalForm()


def EliminateCommonSubexpr(fskip=None):
    """Eliminate common subexpressions.

    Parameters
    ----------
    fskip: Callable
        The callback function that decides whether an expression should be
        skipped.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that eliminates common subexpressions.
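
    Examples
    --------
    A minimal sketch of an ``fskip`` callback that keeps ``cast`` calls out
    of the elimination (the predicate is only an example):

    .. code-block:: python

        def fskip(expr):
            # return True to skip this expression during elimination
            if isinstance(expr, relay.expr.Call) and expr.op.name == "cast":
                return True
            return False

        cse = relay.transform.EliminateCommonSubexpr(fskip)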
    """
    return _transform.EliminateCommonSubexpr(fskip)


def PartialEvaluate():
    """Evaluate the static fragment of the code.

    Note
    ----
    This transformation could be either `Module -> Module` or `Expr -> Expr`.
    It will directly transform the input expression to a new one if the target
    expression is provided. Otherwise, it will rely on the pass manager to
    carry out the transformation.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that performs partial evaluation on an expression.
    """
    return _transform.PartialEvaluate()


def CanonicalizeCast():
    """
    Canonicalize cast expressions to make operator fusion more efficient.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that canonicalizes cast expressions.
    """
    return _transform.CanonicalizeCast()


def LambdaLift():
    """
    Lift closures to global functions.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that lifts the lambda function.
    """
    return _transform.LambdaLift()


def PrintIR(show_meta_data=True):
    """
    Print the IR for a module to help debugging.

    Parameters
    ----------
    show_meta_data : bool
        A boolean flag to indicate if meta data should be printed.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that prints the module IR.
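
    Examples
    --------
    A minimal sketch of using PrintIR to inspect the module between two
    passes of a pipeline (assumes a Relay module ``mod`` exists):

    .. code-block:: python

        seq = relay.transform.Sequential([
            relay.transform.FoldConstant(),
            relay.transform.PrintIR(),  # dumps the IR after constant folding
            relay.transform.FuseOps(),
        ])
        mod = seq(mod)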
    """
    return _transform.PrintIR(show_meta_data)


def gradient(expr, mod=None, mode='higher_order'):
    """
    Transform the input function,
    returning a function that calculates the original result,
    paired with the gradient of the input.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression, which is a Function or a GlobalVar.

    mod : Optional[tvm.relay.Module]
        The module that contains the definitions the expression may refer to.

    mode : Optional[str]
        The mode of the automatic differentiation algorithm.
        'first_order' only works on first-order code, but will not produce
        references or closures.
        'higher_order' works on all code, using references and closures.

    Returns
    -------
    expr : tvm.relay.Expr
      The transformed expression.
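
    Examples
    --------
    A minimal sketch of differentiating a small function (the input is
    type-checked first via the InferType pass on its enclosing module):

    .. code-block:: python

        x = relay.var("x", shape=(1, 10))
        func = relay.Function([x], x * x)
        mod = relay.Module.from_expr(func)
        mod = relay.transform.InferType()(mod)
        # back_func computes the original result paired with the gradients
        back_func = relay.transform.gradient(mod["main"], mod)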
    """
    if mode == 'first_order':
        return _transform.first_order_gradient(expr, mod)
    if mode == 'higher_order':
        return _transform.gradient(expr, mod)
    raise ValueError("unknown mode: " + str(mode))


def to_cps(func, mod=None):
    """
    Turn an expression into a CPS expression.

    Every intermediate compute will be passed to a continuation.

    Parameters
    ----------
    func: tvm.relay.Function
        The input function.

    mod: Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    result: tvm.relay.Function
      The output function.
    """
    return _transform.to_cps(func, mod)


def un_cps(func):
    """
    Turn a CPS function into a Function without the continuation argument.

    Note that this will not give the exact same interface as before CPS:
      if the input/output is higher order, they will still be in CPS form.

    Parameters
    ----------
    func: tvm.relay.Function
        The input function.

    Returns
    -------
    result: tvm.relay.Function
        The output function.
    """
    return _transform.un_cps(func)


def _wrap_class_module_pass(pass_cls, pass_info):
    """Wrap a python class as a module pass."""
    class PyModulePass(ModulePass):
        """Internal wrapper class to create a class instance."""
        def __init__(self, *args, **kwargs):
            # initialize handle in case pass_cls creation failed
            self.handle = None
            inst = pass_cls(*args, **kwargs)
            # it is important not to capture self to
            # avoid a cyclic dependency
            def _pass_func(mod, ctx):
                return inst.transform_module(mod, ctx)
            self.__init_handle_by_constructor__(
                _transform.MakeModulePass, _pass_func, pass_info)
            self._inst = inst

        def __getattr__(self, name):
            # fall back to the wrapped instance's attributes
            return self._inst.__getattribute__(name)

    functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
    PyModulePass.__name__ = pass_cls.__name__
    PyModulePass.__doc__ = pass_cls.__doc__
    PyModulePass.__module__ = pass_cls.__module__
    return PyModulePass


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
    """Decorate a module pass.

    If pass_func is provided, this function returns the created module pass.
    Otherwise, it returns a decorator that creates a module pass from the
    decorated function or class.

    pass_func can also be a class type with a method transform_module.
    This function will create a decorated ModulePass using transform_module
    as the pass function.

    Parameters
    ----------
    pass_func : Optional[Callable[(Module, PassContext) -> Module]]
        The transformation function or class.

    opt_level : int
        The optimization level of this module pass.

    name : Optional[str]
        The name of the module pass. The name could be empty. In this case, the
        name of the optimization function will be used as the pass name.

    required : Optional[List[str]]
        The list of passes that the module pass is dependent on.

    Returns
    -------
    create_module_pass : Union[Callable, ModulePass]
        A decorator will be returned if pass_func is not provided;
        otherwise the decorated result is returned.
        The returned decorator has two behaviors depending on the input:
        A new ModulePass will be returned when we decorate a pass function.
        A new ModulePass class will be returned when we decorate a class type.

    Examples
    --------
    The following code block decorates a module pass class.

    .. code-block:: python

        @relay.transform.module_pass(opt_level=2)
        class CustomPipeline:
            def __init__(self, enable_fold):
                self.enable_fold = enable_fold
                self.cse = relay.transform.EliminateCommonSubexpr()
                self.const_fold = relay.transform.FoldConstant()

            def transform_module(self, mod, ctx):
                mod = self.cse(mod)
                if self.enable_fold:
                    mod = self.const_fold(mod)
                return mod

        # create an instance of customized pipeline
        pipeline = CustomPipeline(enable_fold=False)
        assert isinstance(pipeline, transform.ModulePass)
        # run the pipeline.
        output_module = pipeline(input_module)

    The following code creates a module pass by decorating
    a user defined transform function.

    .. code-block:: python

        @relay.transform.module_pass(opt_level=2)
        def transform(mod, ctx):
            tp = relay.TensorType((10,), "float32")
            x = relay.var("x", tp)
            gv = relay.GlobalVar("var")
            func = relay.Function([x], relay.abs(x))
            new_mod = relay.Module({gv: func})
            new_mod.update(mod)
            return new_mod

        module_pass = transform
        assert isinstance(module_pass, transform.ModulePass)
        assert module_pass.info.opt_level == 2

        # Given a module m, the optimization could be invoked as the following:
        updated_mod = module_pass(m)
        # Now a function abs should be added to the module m.
    """
    if opt_level is None:
        raise ValueError("Please provide opt_level for the module pass.")

    required = required if required else []
    if not isinstance(required, (list, tuple)):
        raise TypeError("Required is expected to be the type of " +
                        "list/tuple.")

    def create_module_pass(pass_arg):
        """Internal function that creates a module pass"""
        fname = name if name else pass_arg.__name__
        info = PassInfo(opt_level, fname, required)
        if inspect.isclass(pass_arg):
            return _wrap_class_module_pass(pass_arg, info)
        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
            raise TypeError("pass_func must be a callable for Module pass")
        return _transform.MakeModulePass(pass_arg, info)

    if pass_func:
        return create_module_pass(pass_func)
    return create_module_pass


def _wrap_class_function_pass(pass_cls, pass_info):
    """Wrap a python class as a function pass."""
    class PyFunctionPass(FunctionPass):
        """Internal wrapper class to create a class instance."""
        def __init__(self, *args, **kwargs):
            # initialize handle in case pass_cls creation failed
            self.handle = None
            inst = pass_cls(*args, **kwargs)
            # it is important not to capture self to
            # avoid a cyclic dependency
            def _pass_func(func, mod, ctx):
                return inst.transform_function(func, mod, ctx)
            self.__init_handle_by_constructor__(
                _transform.MakeFunctionPass, _pass_func, pass_info)
            self._inst = inst

        def __getattr__(self, name):
            # fall back to the wrapped instance's attributes
            return self._inst.__getattribute__(name)

    functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
    PyFunctionPass.__name__ = pass_cls.__name__
    PyFunctionPass.__doc__ = pass_cls.__doc__
    PyFunctionPass.__module__ = pass_cls.__module__
    return PyFunctionPass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
    """Decorate a function pass.

    If pass_func is provided, this function returns the created function pass.
    Otherwise, it returns a decorator that creates a function pass from the
    decorated function or class.

    Parameters
    ----------
    pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
        The transformation function or class.

    opt_level : int
        The optimization level of this function pass.

    name : Optional[str]
        The name of the function pass. The name could be empty. In this case, the
        name of the optimization function will be used as the pass name.

    required : Optional[List[str]]
        The list of passes that the function pass is dependent on.

    Returns
    -------
    create_function_pass : Union[Callable, FunctionPass]
        A decorator will be returned if pass_func is not provided;
        otherwise the decorated result is returned.
        The returned decorator has two behaviors depending on the input:
        A new FunctionPass will be returned when we decorate a pass function.
        A new FunctionPass class will be returned when we decorate a class type.

    Examples
    --------
    The following code block decorates a function pass class.

    .. code-block:: python

        @relay.transform.function_pass(opt_level=1)
        class TestReplaceFunc:
            def __init__(self, new_func):
                self.new_func = new_func

            def transform_function(self, func, mod, ctx):
                # just for demo purposes
                # transform func to new_func
                return self.new_func

        x = relay.var("x", shape=(10, 20))
        f1 = relay.Function([x], x)
        f2 = relay.Function([x], relay.log(x))
        # fpass is now a special pass that replaces every
        # function with f1
        fpass = TestReplaceFunc(f1)
        # now every function in input_mod is replaced by f1
        res_mod = fpass(input_mod)

    The following code creates a function pass by decorating
    a user defined transform function.

    .. code-block:: python

        @relay.transform.function_pass(opt_level=2)
        def transform(func, mod, ctx):
            # my transformations here.
            return func

        function_pass = transform
        assert isinstance(function_pass, transform.FunctionPass)
        assert function_pass.info.opt_level == 2

        # Given a module m, the optimization could be invoked as the following:
        updated_mod = function_pass(m)
        # Now the transformation defined above should have been applied to
        # every function in the provided module m, and the updated module
        # will be returned.
    """

    if opt_level is None:
        raise ValueError("Please provide opt_level for the function pass.")

    required = required if required else []
    if not isinstance(required, (list, tuple)):
        raise TypeError("Required is expected to be the type of " +
                        "list/tuple.")

    def create_function_pass(pass_arg):
        """Internal function that creates a function pass"""
        fname = name if name else pass_arg.__name__
        info = PassInfo(opt_level, fname, required)
        if inspect.isclass(pass_arg):
            return _wrap_class_function_pass(pass_arg, info)
        if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
            raise TypeError("pass_func must be a callable for Function pass")
        return _transform.MakeFunctionPass(pass_arg, info)

    if pass_func:
        return create_function_pass(pass_func)
    return create_function_pass


@function_pass(opt_level=1)
class ChangeBatch:
    """
    Change the batch size.

    Parameters
    ----------
    data: Dict[relay.Var, int]
      A dictionary of all the params to change.
      The keys are the params, and the values are the dimension that holds
      the batch.

    batch_size: int
      The batch size to change to.

    Returns
    -------
    pass: FunctionPass
      The pass.
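
    Examples
    --------
    A minimal sketch (``mod`` is assumed to be a Relay module whose ``main``
    function takes a parameter with the batch in dimension 0):

    .. code-block:: python

        x = mod["main"].params[0]
        change_batch = relay.transform.ChangeBatch({x: 0}, batch_size=32)
        mod = change_batch(mod)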
    """
    def __init__(self, data, batch_size=16):
        self.data = data
        self.batch_size = batch_size

    def transform_function(self, func, mod, ctx):
        func = relay.Function(func.params, func.body, None, func.type_params, func.attrs)
        change_batch = self
        class ChangeBatchMutator(tvm.relay.ExprMutator):
            def visit_var(self, var):
                if var in change_batch.data:
                    ty = var.type_annotation
                    new_shape = list(ty.shape)
                    new_shape[change_batch.data[var]] = change_batch.batch_size
                    return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
                else:
                    return var
        return ChangeBatchMutator().visit(func)