# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return
# pylint: disable=unidiomatic-typecheck
"""
This file contains the set of passes for Relay, which exposes an interface for
configuring the passes and scripting them in Python.
"""
from tvm.ir import IRModule
from tvm.relay import transform, build_module
from tvm.runtime.ndarray import cpu

from . import _ffi_api
from .feature import Feature


def context_analysis(mod, default_context):
    """Analyze the device context information of each IR node in a Relay
    program.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module.

    default_context : tvm.runtime.TVMContext
        The default context allocated to an IR node.
    """
    return _ffi_api.ContextAnalysis(mod, default_context)


def post_order_visit(expr, fvisit):
    """Recursively visit the IR in post-DFS order, applying fvisit to each
    node. Each node is guaranteed to be visited only once.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    fvisit : function
        The visitor function to be applied.
    """
    return _ffi_api.post_order_visit(expr, fvisit)


def well_formed(expr):
    """Check that each Var is only bound once (well formed).

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    well_form : bool
        Whether the input expression is well formed
    """
    return _ffi_api.well_formed(expr)


def check_kind(t, mod=None):
    """Check that the type is well kinded and return the kind.
    For example, this means that a type cannot contain a tensor of tensors,
    or be a tuple type of two shapes.

    Parameters
    ----------
    t : tvm.relay.Type
        The type to check

    mod : Optional[tvm.IRModule]
        The global module.

    Returns
    -------
    kind : Kind
        the kind of t

    Examples
    --------
    .. code:: python

        assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Shape)])) == Shape
        assert check_kind(relay.TupleType([relay.TypeParam('tp1', relay.Kind.Type)])) == Type
    """
    if mod is not None:
        return _ffi_api.check_kind(t, mod)
    else:
        return _ffi_api.check_kind(t)


def check_constant(expr):
    """Check whether an expression is constant

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    result : bool
        Whether the expression is constant.
    """
    return _ffi_api.check_constant(expr)


def check_basic_block_normal_form(expr):
    """Check whether an expression is in the basic block form

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    result : bool
        Whether the expression is in the basic block form.
    """
    return _ffi_api.check_basic_block_normal_form(expr)


def free_vars(expr):
    """Get free Vars from expression expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    free : List[tvm.relay.Var]
        The list of free variables in post-DFS order.

    Note
    ----
    The post-DFS ordering of the Vars is useful in neural networks:
    usually it means that the weights of earlier layers are ordered first.
    """
    return _ffi_api.free_vars(expr)


def bound_vars(expr):
    """Get bound vars from expression expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    bound : List[tvm.relay.Var]
        The list of bound variables in post-DFS order.
    """
    return _ffi_api.bound_vars(expr)


def all_vars(expr):
    """Get all vars from expression expr in post-DFS order.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    all_vars : List[tvm.relay.Var]
        The list of all variables in post-DFS order.
    """
    return _ffi_api.all_vars(expr)


def free_type_vars(expr, mod=None):
    """Get free type variables from expression/type expr

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        The input expression/type

    mod : Optional[tvm.IRModule]
        The global module (defaults to an empty module)

    Returns
    -------
    free : List[tvm.relay.TypeVar]
        The list of free type variables in post-DFS order
    """
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.free_type_vars(expr, use_mod)


def bound_type_vars(expr, mod=None):
    """Get bound type variables from expression/type expr

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        The input expression/type

    mod : Optional[tvm.IRModule]
        The global module (defaults to an empty module)

    Returns
    -------
    bound : List[tvm.relay.TypeVar]
        The list of bound type variables in post-DFS order
    """
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.bound_type_vars(expr, use_mod)


def all_type_vars(expr, mod=None):
    """Get all type variables from expression/type expr

    Parameters
    ----------
    expr : Union[tvm.relay.Expr,tvm.relay.Type]
        The input expression/type

    mod : Optional[tvm.IRModule]
        The global module (defaults to an empty module)

    Returns
    -------
    all_type_vars : List[tvm.relay.TypeVar]
        The list of all type variables in post-DFS order
    """
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.all_type_vars(expr, use_mod)


def all_dtypes(expr):
    """Collect set of all data types used in `expr`.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression

    Returns
    -------
    ret : Set[String]
        Set of data types used in the expression (e.g., `{'int8', 'int32'}`)
    """
    return set(_ffi_api.all_dtypes(expr))


def collect_device_info(expr):
    """Collect the device allocation map for the given expression. The device
    ids are propagated from the `device_copy` operators.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        A dictionary mapping tvm.relay.Expr to device type.
    """
    return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
    """Collect the device annotation ops for the given expression.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Dict[tvm.relay.Expr, int]
        A dictionary mapping tvm.relay.Expr to device type where the keys are
        annotation expressions.
    """
    return _ffi_api.CollectDeviceAnnotationOps(expr)


def get_total_mac_number(expr):
    """
    Count the number of MACs (multiply-accumulate) of a model

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    result : int64
      The number of MACs (multiply-accumulate) of a model
    """
    return _ffi_api.GetTotalMacNumber(expr)


def unmatched_cases(match, mod=None):
    """
    Finds cases that the match expression does not catch, if any.

    Parameters
    ----------
    match : tvm.relay.Match
        The match expression

    mod : Optional[tvm.IRModule]
        The module (defaults to an empty module)

    Returns
    -------
    missing_patterns : [tvm.relay.Pattern]
        Patterns that the match expression does not catch.
    """
    # Resolve the documented default here, the same way the *_type_vars
    # helpers do, instead of forwarding None across the FFI boundary.
    use_mod = mod if mod is not None else IRModule()
    return _ffi_api.unmatched_cases(match, use_mod)


def detect_feature(a, b=None):
    """
    Detect the feature used in a relay program.

    Parameters
    ----------
    a : Union[tvm.relay.Expr, tvm.IRModule]
      The input expression or module.

    b : Optional[Union[tvm.relay.Expr, tvm.IRModule]]
      The input expression or module.
      The two arguments cannot both be expressions or both be modules.

    Returns
    -------
    features : Set[Feature]
        Features used in the program.
    """
    # Normalize the argument order so the FFI always receives
    # (expression-or-None, module-or-None), whichever order the caller used.
    if isinstance(a, IRModule):
        a, b = b, a
    return {Feature(int(x)) for x in _ffi_api.detect_feature(a, b)}


def extract_fused_functions(mod):
    """Pass to extract IRModule of only fused primitive functions.

    The ExtractFusedFunctions pass invokes SimplifyInference, FuseOps(3),
    and ExtractFusedFunctions in that order

    Parameters
    ----------
    mod : tvm.relay.IRModule
        The input module.

    Returns
    -------
    ret : Dict[int, tvm.relay.function.Function]
        A dictionary holding only the fused primitive functions, keyed exactly
        as in the function map of the module produced by the pass.
    """
    ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
    # Copy the module's function map into a plain Python dict.
    return dict(ret_mod.functions.items())


def search_fc_transpose(expr):
    """Search fc weight name in the pattern: y = nn.dense(x, transpose(w, [1, 0]))

    This function is used in the data_dep_optimization.simplify_fc_transpose method

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    ret : Array[String]
        Array of weight variable name in pattern y = nn.dense(x, transpose(w, [1, 0]))
    """
    return _ffi_api.search_fc_transpose(expr)


def get_calibration_data(mod, data):
    """Get the calibration data of a given relay graph

    This pass uses the graph runtime to get the calibration data of a module, which
    includes the input and output values of each function. The returned data uses
    the GlobalVar of each function as a key. Users can further access the inputs and
    outputs by using `inputs` or `outputs` as the key.

    Following are some limitations:
    1. The input module (graph) cannot have control flows.
    2. The input arguments of each function cannot be tuples (outputs can be tuples).
    3. We only handle top-level functions (i.e., nested function is not handled).
    4. We only handle functions with `Compiler` attribute being set.

    Parameters
    ----------
    mod : tvm.IRModule
        The input module for collecting the calibration data

    data : Dict[str, NDArray]
        The input data for running the module

    Returns
    -------
    data : Dict[tvm.relay.GlobalVar, Dict[str, NDArray]]
        For each function's GlobalVar, a dictionary with the keys "inputs" and
        "outputs" holding the slices of the executed module's results that
        correspond to that function's input and output values.
    """
    # Each entry maps a GlobalVar to (offset, number of inputs, number of
    # outputs), locating that function's values in the flattened results
    # of the calibration module executed below.
    output_map = _ffi_api.get_calibrate_output_map(mod)

    mod = _ffi_api.get_calibrate_module(mod)
    mod = transform.Inline()(mod)

    ref_ex = build_module.create_executor("graph", mod=mod, ctx=cpu(0))
    ref_res = ref_ex.evaluate()(**data)

    calib_data = {}
    for gvar, indices in output_map.items():
        offset = int(indices[0])
        in_len = int(indices[1])
        out_len = int(indices[2])
        # Inputs occupy [offset, offset + in_len); outputs follow immediately.
        value = {
            "inputs": ref_res[offset : offset + in_len],
            "outputs": ref_res[offset + in_len : offset + in_len + out_len],
        }
        calib_data[gvar] = value

    return calib_data