# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,arguments-differ,no-else-return,unused-argument,missing-docstring
"""
Relay pass transformation infrastructure.
"""
import types
import inspect
import functools

import tvm
from tvm._ffi.runtime_ctypes import TVMContext
from tvm import relay
from . import _transform
from .base import RelayNode, register_relay_node
from .. import nd as _nd


@register_relay_node
class PassInfo(RelayNode):
    """The class contains the meta data required by a pass. It is the
    container of information needed by running an optimization or analysis.
    This class can be extended by adding new members when more meta data is
    needed.

    Parameters
    ----------
    opt_level : int
        The optimization level of this pass.

    name : str
        The pass name.

    required : List[str]
        The list of passes that are required by a certain pass. ``None`` is
        treated as an empty list.
    """

    def __init__(self, opt_level, name, required=None):
        # Normalize None to an empty list so the backend always receives a
        # list, consistent with Sequential/module_pass/function_pass.
        required = required if required else []
        self.__init_handle_by_constructor__(
            _transform.PassInfo, opt_level, name, required)


def _to_pass_name_list(pass_names, arg_name):
    """Normalize an optional collection of pass names into a list.

    Parameters
    ----------
    pass_names : Optional[Union[List[str], Set[str], Tuple[str]]]
        The pass names. Any falsy value yields an empty list.

    arg_name : str
        The argument name used in the error message.

    Returns
    -------
    ret : List[str]
        The pass names as a list.

    Raises
    ------
    TypeError
        If ``pass_names`` is a single string (which would otherwise be
        silently split into characters) or is not iterable.
    """
    if not pass_names:
        return []
    message = (arg_name + " is expected to be the type of " +
               "list/tuple/set.")
    if isinstance(pass_names, str):
        raise TypeError(message)
    try:
        return list(pass_names)
    except TypeError:
        raise TypeError(message)


@register_relay_node
class PassContext(RelayNode):
    """The basis where a Relay optimization/analysis runs on.
    Each pass context contains a number of auxiliary information that is used
    to help an optimization pass. Such information includes the error reporter
    to record errors during the optimization, etc.

    Parameters
    ----------
    opt_level : Optional[int]
        The optimization level of this pass.

    fallback_device : Optional[Union[int, str, TVMContext]]
        The fallback device type. It is also used as the default device for
        operators that are not annotated during heterogeneous execution.

    required_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are required by a certain pass.

    disabled_pass : Optional[Union[List[str], Set[str], Tuple[str]]]
        The list of passes that are disabled.

    Raises
    ------
    TypeError
        If ``fallback_device`` is not an int/str/TVMContext, or if
        ``required_pass``/``disabled_pass`` is not a collection of names.
    """
    def __init__(self,
                 opt_level=2,
                 fallback_device=_nd.cpu(),
                 required_pass=None,
                 disabled_pass=None):
        # The backend expects an integer device type.
        if isinstance(fallback_device, str):
            fallback_device = _nd.context(fallback_device).device_type
        elif isinstance(fallback_device, TVMContext):
            fallback_device = fallback_device.device_type
        if not isinstance(fallback_device, int):
            raise TypeError("fallback_device is expected to be the type of " +
                            "int/str/TVMContext.")

        # Validate before converting: checking the result of list(...) can
        # never fail and would let a bare string through as characters.
        required = _to_pass_name_list(required_pass, "required_pass")
        disabled = _to_pass_name_list(disabled_pass, "disabled_pass")

        self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
                                            fallback_device, required,
                                            disabled)

    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()


def build_config(opt_level=2,
                 fallback_device=_nd.cpu(),
                 required_pass=None,
                 disabled_pass=None):
    """Configure the build behavior by setting config variables.

    Parameters
    ----------
    opt_level: int, optional
        Optimization level. The optimization pass name and level are as the
        following:

        .. code-block:: python

            OPT_PASS_LEVEL = {
                "SimplifyInference": 0,
                "OpFusion": 1,
                "FoldConstant": 2,
                "FoldScaleAxis": 3,
                "AlterOpLayout": 3,
                "CanonicalizeOps": 3,
                "CanonicalizeCast": 3,
                "EliminateCommonSubexpr": 3,
                "CombineParallelConv2D": 4,
                "CombineParallelDense": 4
            }

    fallback_device : int, str, or tvm.TVMContext, optional
        The fallback device. It is also used as the default device for
        operators without specified device during heterogeneous execution.

    required_pass: set of str, optional
        Optimization passes that are required regardless of optimization level.

    disabled_pass: set of str, optional
        Optimization passes to be disabled during optimization.

    Returns
    -------
    pass_context: PassContext
        The pass context for optimizations.
    """
    return PassContext(opt_level, fallback_device, required_pass,
                       disabled_pass)


@register_relay_node
class Pass(RelayNode):
    """The base class of all passes. All methods here are just simple wrappers
    that are implemented in the backend. They are defined for users to
    conveniently interact with the base class.
    """

    @property
    def info(self):
        """Get the pass meta."""
        return _transform.Info(self)

    def __call__(self, mod):
        """Execute the pass. Note that for sequential pass, the dependency among
        different passes will be resolved in the backend.

        Parameters
        ----------
        mod : tvm.relay.Module
            The module that a certain optimization is performed on.

        Returns
        -------
        mod : tvm.relay.Module
            The updated module after applying this pass.
        """
        return _transform.RunPass(self, mod)


@register_relay_node
class ModulePass(Pass):
    """A pass that works on tvm.relay.Module. Users don't need to interact with
    this class directly. Instead, a module pass should be created through
    `module_pass`, because the design of the `module_pass` API is flexible
    enough to handle the creation of a module pass in different manners. In
    addition, all members of a module pass can be accessed from the base class.
    The same rule applies to FunctionPass as well.
    """


@register_relay_node
class FunctionPass(Pass):
    """A pass that works on each tvm.relay.Function in a module. A function
    pass class should be created through `function_pass`.
    """


@register_relay_node
class Sequential(Pass):
    """A pass that works on a sequence of pass objects. Multiple passes can be
    executed sequentially using this class.

    Some typical usage of the sequential pass are:
    1. Users provide a list of passes for optimization.
    2. Only an optimization level is provided so that the backend system has
       to glob all passes at this level and below to perform the optimizations.
    Note that users can also provide a series of passes that they don't want to
    apply when running a sequential pass. Pass dependency will be resolved in
    the backend as well.

    Parameters
    ----------
    passes : Optional[List[Pass]]
        A sequence of passes candidate for optimization.

    opt_level : Optional[int]
        The optimization level of this sequential pass.

    name : Optional[str]
        The name of the sequential pass.

    required : Optional[List[str]]
        The list of passes that the sequential pass is dependent on.
    """

    def __init__(self,
                 passes=None,
                 opt_level=2,
                 name="sequential",
                 required=None):
        passes = passes if passes else []
        if not isinstance(passes, (list, tuple)):
            raise TypeError("passes must be a list of Pass objects.")

        required = required if required else []
        if not isinstance(required, (list, tuple)):
            raise TypeError("Required is expected to be the type of list/tuple.")

        self.__init_handle_by_constructor__(_transform.Sequential,
                                            passes, opt_level, name, required)


def InferType():
    """Infer the type of an expr.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered type inference pass.
    """
    return _transform.InferType()


def FoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense. This pass will
    invoke both forward and backward scale folding.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to fold expressions.

    Note
    ----
    Internally, we will call backward_fold_scale_axis before using
    forward_fold_scale_axis as backward folding targets the common conv->bn
    pattern.
    """
    return _transform.FoldScaleAxis()


def BackwardFoldScaleAxis():
    """Backward fold axis scaling into weights of conv2d/dense.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to backward fold expressions.

    Note
    ----
    It is recommended to call backward_fold_scale_axis
    before using forward_fold_scale_axis as backward folding targets the common
    conv->bn pattern.
    """
    return _transform.BackwardFoldScaleAxis()


def RemoveUnusedFunctions(entry_functions=None):
    """Remove unused global relay functions in a relay module.

    Parameters
    ----------
    entry_functions: list[string]
        The set of entry functions to start from. Defaults to ``['main']``.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to remove unused functions.
    """
    if entry_functions is None:
        entry_functions = ['main']
    return _transform.RemoveUnusedFunctions(entry_functions)


def ForwardFoldScaleAxis():
    """Fold the scaling of axis into weights of conv2d/dense.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass to forward fold expressions.

    Note
    ----
    It is recommended to call backward_fold_scale_axis
    before using forward_fold_scale_axis, as backward folding targets the
    common conv->bn pattern.
    """
    return _transform.ForwardFoldScaleAxis()


def SimplifyInference():
    """Simplify the data-flow graph for inference phase. A simplified expression
    which is semantically equal to the input expression will be returned.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass to perform operator simplification.
    """
    return _transform.SimplifyInference()


def CanonicalizeOps():
    """Canonicalize special operators to basic operators.
    This can simplify followed analysis, e.g. expanding bias_add to
    expand_dims and broadcast_add.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass performing the canonicalization.
    """
    return _transform.CanonicalizeOps()


def DeadCodeElimination(inline_once=False):
    """Remove expressions that do not have any users (dead code).

    Parameters
    ----------
    inline_once: Optional[Bool]
        Whether to inline binding that occurs only once.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that eliminates the dead code in a Relay program.
    """
    return _transform.DeadCodeElimination(inline_once)


def FoldConstant():
    """Fold the constant expressions in a Relay program.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for constant folding.
    """
    return _transform.FoldConstant()


def FuseOps(fuse_opt_level=-1):
    """Fuse operators in an expr to a larger operator according to some rules.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from pass context.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass for operator fusion.
    """
    return _transform.FuseOps(fuse_opt_level)


def CombineParallelConv2D(min_num_branches=3):
    """Combine multiple conv2d operators into one.

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that combines parallel conv2d operators.
    """
    return _transform.CombineParallelConv2D(min_num_branches)


def CombineParallelDense(min_num_branches=3):
    # Raw docstring: the "\" in the diagram would otherwise be parsed as a
    # line continuation and merge two lines of the picture.
    r"""Combine multiple dense operators into one. For example:

                data
          /              \
     dense (2,2)         dense (2,2)
         |                 |
    elemwise/bcast (2,2)  elemwise/bcast (2,2)

    Would become:

                data
                 |
        batch_matmul+elemwise/bcast (2,2,2)

    Parameters
    ----------
    min_num_branches : int
        The minimum number of required parallel branches for performing this
        optimization.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that combines parallel dense operators.
    """
    return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
    """Alternate the layouts of operators or replace primitive operators with
    other expressions.
    This pass can be used for computing convolution in custom layouts or
    other general weight pre-transformation.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that alters the layout of operators.
    """
    return _transform.AlterOpLayout()


def Legalize(legalize_map_attr_name="FTVMLegalize"):
    """Legalizes an expression with another expression.
    This pass can be used to replace an expr with another expr for target
    dependent optimizations. For example, one expr, though semantically
    equivalent to the other, can have better performance on a target. This pass
    can be used to legalize the expr in a target-dependent manner.

    Parameters
    ----------
    legalize_map_attr_name : str
        The Op's attr name which corresponds to the legalize rule function.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that rewrites an expr.
    """
    return _transform.Legalize(legalize_map_attr_name)


def RewriteAnnotatedOps(fallback_device):
    """Rewrite the annotated program where annotation operators, e.g.
    `on_device`, mark which device an expression should be scheduled to.
    This pass helps heterogeneous execution where different operators may need
    to be allocated on various devices.

    Parameters
    ----------
    fallback_device : int
        The fallback device type. It is also used as the default device for
        operators with no annotated device.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that rewrites an expression with annotated
        `on_device` operators.
    """
    return _transform.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
    """Turn Graph Normal Form expression into A Normal Form Expression.
    The scope of the root expression is the global scope.
    The scope of any non root expression is the least common ancestor of all its scopes.
    Values are ordered by post-DFS order in each scope.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that transforms an expression into A Normal Form.
    """
    return _transform.ToANormalForm()


def ToCPS(expr, mod=None):
    """
    Turn expression into continuation passing style(CPS).

    Every intermediate compute will be passed to a continuation.

    Note that, despite its pass-style name, this function transforms ``expr``
    directly (it calls the same backend function as `to_cps`) and does not
    return a Pass.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    mod : Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    result: tvm.relay.Expr
        The expression transformed into CPS.
    """
    return _transform.to_cps(expr, mod)


def EtaExpand(expand_constructor=False, expand_global_var=False):
    """Add abstraction over a constructor or global variable bound to a function

    Parameters
    ----------
    expand_constructor: bool
        Whether to expand constructors.

    expand_global_var: bool
        Whether to expand global variables.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that eta expands an expression.
    """
    return _transform.EtaExpand(expand_constructor, expand_global_var)


def ToGraphNormalForm():
    """Turn a Relay program in A Normal Form into Graph Normal Form

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that transforms an expression into Graph Normal Form.
    """
    return _transform.ToGraphNormalForm()


def EliminateCommonSubexpr(fskip=None):
    """Eliminate common subexpressions.

    Parameters
    ----------
    fskip: Callable
        The callback function that decides whether an expression should be
        skipped.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that eliminates common subexpressions.
    """
    return _transform.EliminateCommonSubexpr(fskip)


def PartialEvaluate():
    """Evaluate the static fragment of the code.

    Note
    ----
    This function takes no expression; it returns a pass, and the
    transformation is carried out by the pass manager when the pass is applied
    to a module.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass that performs partial evaluation on an expression.
    """
    return _transform.PartialEvaluate()


def CanonicalizeCast():
    """
    Canonicalize cast expressions to make operator fusion more efficient.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that canonicalizes cast expression.
    """
    return _transform.CanonicalizeCast()


def LambdaLift():
    """
    Lift the closure to global function.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that lifts the lambda function.
    """
    return _transform.LambdaLift()


def PrintIR(show_meta_data=True):
    """
    Print the IR for a module to help debugging.

    Parameters
    ----------
    show_meta_data : bool
        A boolean flag to indicate if meta data should be printed.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that prints the module IR.
    """
    return _transform.PrintIR(show_meta_data)


def gradient(expr, mod=None, mode='higher_order'):
    """
    Transform the input function,
    returning a function that calculates the original result,
    paired with gradient of the input.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression, which is a Function or a GlobalVar.

    mod : Optional[tvm.relay.Module]
        The global module.

    mode : Optional[String]
        The mode of the automatic differentiation algorithm.
        'first_order' only works on first order code, but will not produce
        reference nor closure.
        'higher_order' works on all code using reference and closure.

    Returns
    -------
    expr : tvm.relay.Expr
        The transformed expression.

    Raises
    ------
    ValueError
        If ``mode`` is neither 'first_order' nor 'higher_order'.
    """
    if mode == 'first_order':
        return _transform.first_order_gradient(expr, mod)
    if mode == 'higher_order':
        return _transform.gradient(expr, mod)
    # ValueError subclasses Exception, so existing `except Exception` callers
    # keep working.
    raise ValueError("unknown mode: {!r}; expected 'first_order' or "
                     "'higher_order'".format(mode))


def to_cps(func, mod=None):
    """
    Turn expression into CPS expression.

    Every intermediate compute will be passed to a continuation.

    Parameters
    ----------
    func: tvm.relay.Function
        The input function.

    mod: Optional[tvm.relay.Module]
        The global module.

    Returns
    -------
    result: tvm.relay.Function
        The output function.
    """
    return _transform.to_cps(func, mod)


def un_cps(func):
    """
    Turn a CPS function into a Function without the continuation argument.

    Note that this will not give the exact same interface as before cps:
      If the input/output is higher order, they will still be in cps form.

    Parameters
    ----------
    func: tvm.relay.Function
        The input function

    Returns
    -------
    result: tvm.relay.Function
        The output function
    """
    return _transform.un_cps(func)


def _wrap_class_module_pass(pass_cls, pass_info):
    """Wrap a python class as module pass"""
    class PyModulePass(ModulePass):
        """Internal wrapper class to create a class instance."""
        def __init__(self, *args, **kwargs):
            # initialize handle in case pass_cls creation fails.
            self.handle = None
            inst = pass_cls(*args, **kwargs)
            # it is important not to capture self to
            # avoid a cyclic dependency
            def _pass_func(mod, ctx):
                return inst.transform_module(mod, ctx)
            self.__init_handle_by_constructor__(
                _transform.MakeModulePass, _pass_func, pass_info)
            self._inst = inst

        def __getattr__(self, name):
            # If construction failed before _inst was set, looking up
            # self._inst below would re-enter __getattr__ forever.
            if name == "_inst":
                raise AttributeError(name)
            # fall back to instance attribute if there is not any
            return self._inst.__getattribute__(name)

    functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__)
    PyModulePass.__name__ = pass_cls.__name__
    PyModulePass.__doc__ = pass_cls.__doc__
    PyModulePass.__module__ = pass_cls.__module__
    return PyModulePass


def module_pass(pass_func=None, opt_level=None, name=None, required=None):
    """Decorate a module pass.

    When pass_func is provided, this function returns the created module pass
    (or module pass class). Otherwise, it returns a decorator function.

    pass_func can also be a class type with a method transform_module.
    This function will create a decorated ModulePass using transform_module
    as the pass function.

    Parameters
    ----------
    pass_func : Optional[Callable[(Module, PassContext) ->Module]]
        The transformation function or class.

    opt_level : int
        The optimization level of this module pass. It must be provided.

    name : Optional[str]
        The name of the module pass. The name could be empty. In this case, the
        name of the optimization function will be used as the pass name.

    required : Optional[List[str]]
        The list of passes that the module pass is dependent on.

    Returns
    -------
    create_module_pass : Union[Callable, ModulePass]
        A decorator will be returned if pass_func is not provided,
        otherwise return the decorated result.
        The returned decorator has two behaviors depending on the input:
        A new ModulePass will be returned when we decorate a pass function.
        A new ModulePass class will be returned when we decorate a class type.

    Raises
    ------
    ValueError
        If opt_level is not provided.

    TypeError
        If required is not a list/tuple, or the decorated object is neither a
        class nor a function.

    Examples
    --------
    The following code block decorates a module pass class.

    .. code-block:: python

        @relay.transform.module_pass(opt_level=1)
        class CustomPipeline:
            def __init__(self, enable_fold):
                self.enable_fold = enable_fold
                self.cse = relay.transform.EliminateCommonSubexpr()
                self.const_fold = relay.transform.FoldConstant()

            def transform_module(self, mod, ctx):
                mod = self.cse(mod)
                if self.enable_fold:
                    mod = self.const_fold(mod)
                return mod

        # create an instance of customized pipeline
        pipeline = CustomPipeline(enable_fold=False)
        assert isinstance(pipeline, relay.transform.ModulePass)
        # run the pipeline.
        output_module = pipeline(input_module)

    The following code creates a module pass by decorating
    a user defined transform function.

    .. code-block:: python

        @relay.transform.module_pass(opt_level=2)
        def transform(mod, ctx):
            tp = relay.TensorType((10,), "float32")
            x = relay.var("x", tp)
            gv = relay.GlobalVar("var")
            func = relay.Function([x], relay.abs(x))
            new_mod = relay.Module({gv: func})
            new_mod.update(mod)
            return new_mod

        module_pass = transform
        assert isinstance(module_pass, relay.transform.ModulePass)
        assert module_pass.info.opt_level == 2

        # Given a module m, the optimization could be invoked as the following:
        updated_mod = module_pass(m)
        # Now a function abs should be added to the module m.
    """
    if opt_level is None:
        raise ValueError("Please provide opt_level for the module pass.")

    required = required if required else []
    if not isinstance(required, (list, tuple)):
        raise TypeError("Required is expected to be the type of " +
                        "list/tuple.")

    def create_module_pass(pass_arg):
        """Internal function that creates a module pass"""
        is_class = inspect.isclass(pass_arg)
        # Validate before touching pass_arg.__name__ so a bad argument gets
        # the intended TypeError rather than an AttributeError.
        if not is_class and not isinstance(
                pass_arg, (types.FunctionType, types.LambdaType)):
            raise TypeError("pass_func must be a callable for Module pass")
        fname = name if name else pass_arg.__name__
        info = PassInfo(opt_level, fname, required)
        if is_class:
            return _wrap_class_module_pass(pass_arg, info)
        return _transform.MakeModulePass(pass_arg, info)

    if pass_func:
        return create_module_pass(pass_func)
    return create_module_pass


def _wrap_class_function_pass(pass_cls, pass_info):
    """Wrap a python class as function pass"""
    class PyFunctionPass(FunctionPass):
        """Internal wrapper class to create a class instance."""
        def __init__(self, *args, **kwargs):
            # initialize handle in case pass_cls creation fails.
            self.handle = None
            inst = pass_cls(*args, **kwargs)
            # it is important not to capture self to
            # avoid a cyclic dependency
            def _pass_func(func, mod, ctx):
                return inst.transform_function(func, mod, ctx)
            self.__init_handle_by_constructor__(
                _transform.MakeFunctionPass, _pass_func, pass_info)
            self._inst = inst

        def __getattr__(self, name):
            # If construction failed before _inst was set, looking up
            # self._inst below would re-enter __getattr__ forever.
            if name == "_inst":
                raise AttributeError(name)
            # fall back to instance attribute if there is not any
            return self._inst.__getattribute__(name)

    functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
    PyFunctionPass.__name__ = pass_cls.__name__
    PyFunctionPass.__doc__ = pass_cls.__doc__
    PyFunctionPass.__module__ = pass_cls.__module__
    return PyFunctionPass


def function_pass(pass_func=None, opt_level=None, name=None, required=None):
    """Decorate a function pass.

    When pass_func is provided, this function returns the created function
    pass (or function pass class). Otherwise, it returns a decorator function.

    Parameters
    ----------
    pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
        The transformation function or class.

    opt_level : int
        The optimization level of this function pass. It must be provided.

    name : Optional[str]
        The name of the function pass. The name could be empty. In this case, the
        name of the optimization function will be used as the pass name.

    required : Optional[List[str]]
        The list of passes that the function pass is dependent on.

    Returns
    -------
    create_function_pass : Union[Callable, FunctionPass]

        A decorator will be returned if pass_func is not provided,
        otherwise return the decorated result.
        The returned decorator has two behaviors depending on the input:
        A new FunctionPass will be returned when we decorate a pass function.
        A new FunctionPass class will be returned when we decorate a class type.

    Raises
    ------
    ValueError
        If opt_level is not provided.

    TypeError
        If required is not a list/tuple, or the decorated object is neither a
        class nor a function.

    Examples
    --------
    The following code block decorates a function pass class.

    .. code-block:: python

        @relay.transform.function_pass(opt_level=1)
        class TestReplaceFunc:
            def __init__(self, new_func):
                self.new_func = new_func

            def transform_function(self, func, mod, ctx):
                # just for demo purposes
                # transform func to new_func
                return self.new_func

        x = relay.var("x", shape=(10, 20))
        f1 = relay.Function([x], x)
        f2 = relay.Function([x], relay.log(x))
        # fpass is now a special pass that replaces every
        # function to f1
        fpass = TestReplaceFunc(f1)
        # now every function in input_mod is replaced by f1
        res_mod = fpass(input_mod)


    The following code creates a function pass by decorating
    a user defined transform function.

    .. code-block:: python

        @relay.transform.function_pass(opt_level=2)
        def transform(func, mod, ctx):
            # my transformations here.
            return func

        function_pass = transform
        assert isinstance(function_pass, relay.transform.FunctionPass)
        assert function_pass.info.opt_level == 2

        # Given a module m, the optimization could be invoked as the following:
        updated_mod = function_pass(m)
        # Now `transform` has been applied to every function in the provided
        # module m (here it returns each function unchanged), and the updated
        # module is returned.
    """

    if opt_level is None:
        raise ValueError("Please provide opt_level for the function pass.")

    required = required if required else []
    if not isinstance(required, (list, tuple)):
        raise TypeError("Required is expected to be the type of " +
                        "list/tuple.")

    def create_function_pass(pass_arg):
        """Internal function that creates a function pass"""
        is_class = inspect.isclass(pass_arg)
        # Validate before touching pass_arg.__name__ so a bad argument gets
        # the intended TypeError rather than an AttributeError.
        if not is_class and not isinstance(
                pass_arg, (types.FunctionType, types.LambdaType)):
            raise TypeError("pass_func must be a callable for Function pass")
        fname = name if name else pass_arg.__name__
        info = PassInfo(opt_level, fname, required)
        if is_class:
            return _wrap_class_function_pass(pass_arg, info)
        return _transform.MakeFunctionPass(pass_arg, info)

    if pass_func:
        return create_function_pass(pass_func)
    return create_function_pass


@function_pass(opt_level=1)
class ChangeBatch:
    """
    Change the batch size.

    Parameters
    ----------
    data: Dict[relay.Var, int]
        A dictionary of all the params to change.
        The keys are all params, and the values are which dimension holds the batch.

    batch_size: int
        The batch size to change to.

    Returns
    -------
    pass: FunctionPass
        The pass.
    """
    def __init__(self, data, batch_size=16):
        self.data = data
        self.batch_size = batch_size

    def transform_function(self, func, mod, ctx):
        # Rebuild the function with ret_type=None so the (now stale) return
        # type is not carried over.
        func = relay.Function(func.params, func.body, None, func.type_params, func.attrs)
        change_batch = self
        class ChangeBatchMutator(tvm.relay.ExprMutator):
            def visit_var(self, var):
                if var in change_batch.data:
                    # assumes var carries a TensorType annotation -- TODO confirm
                    ty = var.type_annotation
                    new_shape = list(ty.shape)
                    new_shape[change_batch.data[var]] = change_batch.batch_size
                    return relay.Var(var.name_hint, relay.TensorType(new_shape, ty.dtype))
                else:
                    return var
        return ChangeBatchMutator().visit(func)