# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility for converting Relay code into a Python script with equivalent semantics"""
import ast
from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
import re

import tvm
from tvm import relay
from tvm.relay.adt import Pattern
from tvm.relay.backend import compile_engine
from tvm.relay.expr import Expr, GlobalVar, Var
from tvm.relay.function import Function
from tvm.relay.expr_functor import ExprFunctor

# name of the variable the generated script assigns its final result to
OUTPUT_VAR_NAME = "_py_out"

# corresponds to:
#     import numpy
#     import tvm
#     from tvm import relay
#     from tvm import nd
#     from tvm.runtime import container as _container
#     from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
    ast.Import([alias("numpy", None)]),
    ast.Import([alias("tvm", None)]),
    ast.ImportFrom("tvm", [alias("relay", None)], 0),
    ast.ImportFrom("tvm", [alias("nd", None)], 0),
    ast.ImportFrom("tvm.runtime", [alias("container", "_container")], 0),
    ast.ImportFrom(
        "tvm.relay.backend.interpreter",
        [alias("RefValue", None), alias("ConstructorValue", None)],
        0,
    ),
]


class PythonConverter(ExprFunctor):
    """Functor for translating Relay programs into Python ASTs."""

    def __init__(self, mod, target) -> None:
        """Stores the module and target and resets the name counters.

        Parameters
        ----------
        mod
            Module whose global functions and type definitions are made
            available to the converted program.
        target
            Target used when lowering primitive functions.
        """
        super().__init__()
        self.mod = mod
        self.tgt = target
        self.engine = compile_engine.get()
        # counters used to make generated function/variable names unique
        self.fun_no = 0
        self.var_no = 0
        # maps Relay vars to the Python names generated for them
        self.var_map = {}

    def convert(self, prog: Expr):
        """This method converts the passed Relay expression into a Python
        AST object with equivalent semantics.

        The Python AST can be executed using exec(); it can be turned
        into text and inspected using astor.
        """
        optimized = self.optimize(prog)

        # start with conversion prelude (imports) and convert global defs
        body = []
        body += PROLOGUE
        body += self.convert_module()

        prog_body, extra_defs = self.visit(optimized)
        body += extra_defs

        # we finally must assign the final expression to the output var
        # so it can be read after running EXEC
        body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))

        # `type_ignores` is a required Module field for compile() from
        # Python 3.8 on; passing it by keyword is harmless on older versions.
        return ast.fix_missing_locations(ast.Module(body=body, type_ignores=[]))

    def optimize(self, prog: Expr):
        """Performs optimizations necessary to be able to generate code for prog."""
        # unwrap tuple wrappers (some op calls produce them)
        unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
        assert relay.analysis.well_formed(unwrapped)
        mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)

        # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
        # and fusion (to get primitive functions)
        opts = tvm.transform.Sequential(
            [relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
        )
        mod = opts(mod)
        optimized = mod["main"]
        # a non-function expression was wrapped into "main"; return only its body
        return optimized if isinstance(unwrapped, Function) else optimized.body

    def sanitize(self, name: str) -> str:
        """Removes any invalid characters (only underscores, numbers, and letters permitted)
        from the given name. Since we append a number and underscore to var names anyway,
        it doesn't matter if the name is the empty string."""
        return re.sub(r"\W", "", name)

    def generate_var_name(self, name_hint: str) -> str:
        """Generates a unique variable name starting from the hint."""
        name = "{}_var_{}".format(self.sanitize(name_hint), self.var_no)
        self.var_no += 1
        return name

    def generate_function_name(self, name_hint: str) -> str:
        """Generates a unique function name starting from the hint."""
        name = "{}_fun_{}".format(self.sanitize(name_hint), self.fun_no)
        self.fun_no += 1
        return name

    def get_var_name(self, var: Expr) -> str:
        """Returns the var name for the given Relay variable, generating
        and caching a fresh one the first time the variable is seen."""
        if var in self.var_map:
            return self.var_map[var]
        name = self.generate_var_name(var.name_hint)
        self.var_map[var] = name
        return name

    def include_var(self, var: Expr, assign=False):
        """Returns a variable AST node for the given Relay var depending on
        whether it must appear in an assignment or not."""
        name = self.get_var_name(var)
        return Name(name, Store() if assign else Load())

    def parse_name(self, name: str):
        """Given the name of a Python method with dots (e.g., 'relay.var'),
        returns an appropriate AST object corresponding to that name."""
        head, *rest = name.split(".")
        ret = Name(head, Load())
        for attribute in rest:
            ret = ast.Attribute(ret, attribute, Load())
        return ret

    def parse_numpy_array(self, arr):
        """Given a Numpy array, produces an appropriate Python array
        or numerical literal representing its contents."""

        def parse_single(i):
            # booleans need a NameConstant node; everything else is numeric
            return NameConstant(i) if isinstance(i, bool) else Num(i)

        if arr.ndim == 0:
            return parse_single(arr.item())
        if arr.ndim == 1:
            return ast.List([parse_single(i.item()) for i in arr], Load())

        # higher rank: recurse over the leading axis to build nested lists
        elts = [self.parse_numpy_array(row) for row in arr]
        return ast.List(elts, Load())

    def convert_fields(self, fields: [Expr]):
        """Given a list of call args or tuple fields, converts
        each and returns their ASTs and their defs lists (in order)."""
        bodies = []
        defs = []
        for field in fields:
            member_body, member_defs = self.visit(field)
            bodies.append(member_body)
            defs += member_defs
        return (bodies, defs)

    def convert_to_thunk(self, name_hint: str, expr: Expr):
        """Wraps the passed expression in a thunk.

        Returns the thunk's function definition and its generated name."""
        body, defs = self.visit(expr)
        thunk_name = self.generate_function_name(name_hint)
        thunk = self.create_def(thunk_name, [], defs + [Return(body)])
        return (thunk, thunk_name)

    def convert_func_node(self, func: Function, name_var=None):
        """Converts the given Relay function into a Python function, with
        special handling for named functions (locally or globally).

        Parameters
        ----------
        func
            The Relay function to convert.
        name_var
            None for an anonymous function, a GlobalVar for a global
            definition, or a Var for a let-bound (possibly recursive) one.

        Returns
        -------
        (function definition AST, name of the generated function)

        Raises
        ------
        ValueError
            If name_var is neither None, a GlobalVar, nor a Var.
        """
        if name_var is None:
            func_name = self.generate_function_name("_anon_func")
        elif isinstance(name_var, GlobalVar):
            func_name = str(name_var.name_hint)
        elif isinstance(name_var, Var):
            func_name = self.get_var_name(name_var)
        else:
            # previously this fell through and failed later on an unbound name
            raise ValueError(
                "name_var must be None, a GlobalVar, or a Var, got {}".format(type(name_var))
            )

        var_names = [self.get_var_name(var) for var in func.params]
        body, defs = self.visit(func.body)
        ret = self.create_def(func_name, var_names, defs + [Return(body)])
        return (ret, func_name)

    def convert_module(self):
        """Converts all the global functions defined in the module and returns
        them as a list of definitions"""
        defs = []
        for var, func in self.mod.functions.items():
            # optimize the definition so any operators used are lowered
            opt_func = self.optimize(func)
            try:
                converted_func, _ = self.convert_func_node(opt_func, var)
                defs.append(converted_func)
            except TypeError:
                # TODO(wweic): fix conversion for Any
                # (deliberate best effort: such definitions are skipped)
                pass
        return defs

    def create_call(self, func_name: str, arguments):
        """Creates a simple function call."""
        return ast.Call(self.parse_name(func_name), arguments, [])

    def create_def(self, func_name: str, arguments: [str], body):
        """Wrapper over function definition AST node, whose constructor is inconvenient.

        Parameters
        ----------
        func_name
            Name of the function to define.
        arguments
            Names of the positional parameters.
        body
            List of statement AST nodes forming the function body.
        """
        inner_args = [ast.arg(argument, None) for argument in arguments]
        # Build the arguments node by keyword: Python 3.8 inserted
        # `posonlyargs` as the first field, so a positional construction
        # written for the older layout would put the parameters in the
        # wrong field.
        func_args = ast.arguments(
            posonlyargs=[],
            args=inner_args,
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        )
        return ast.FunctionDef(func_name, func_args, body, [], None)

    def create_op_call(self, op: Function, relay_args, py_args):
        """Lowers the passed primitive function, registers it in TVM's
        global compiler, and produces a call to the lowered function in
        the generated Python code.

        Returns the definition of a wrapper function and a call to it."""

        # compile the function and register globally
        cc_key = compile_engine.CCacheKey(op, self.tgt)
        func_hash = tvm.ir.structural_hash(op)
        op_name = "_lowered_op_{}".format(func_hash)
        # only jit and register when no function of that name exists yet
        if not tvm.get_global_func(op_name, allow_missing=True):
            jitted = self.engine.jit(cc_key, self.tgt)
            tvm.register_func(op_name, jitted)

        def convert_input(py_input, arg_type):
            """Use the types of the function arguments to determine whether we expect
            a tensor or tuple (returns list of inputs to the lowered op call)"""
            # equivalent: input.data
            if isinstance(arg_type, relay.TensorType):
                return [py_input]
            assert isinstance(arg_type, relay.TupleType)
            # convert each input.fields[i]
            ret = []
            for i, field_type in enumerate(arg_type.fields):
                ret += convert_input(
                    ast.Subscript(py_input, ast.Index(Num(i)), Load()), field_type
                )
            return ret

        def convert_output(ret_type):
            """Use the function return type to produce auxiliary variables to store outputs.
            Returns ([assignments of output vars], [extra arguments to pass to op call],
            expression collecting output)"""
            if isinstance(ret_type, relay.TensorType):
                output_var_name = self.generate_var_name("_out")
                output_var = Name(output_var_name, Load())
                shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
                # create a new NDArray of the right shape and dtype
                assign_output = Assign(
                    [Name(output_var_name, Store())],
                    self.create_call(
                        "nd.array", [self.create_call("numpy.empty", [shape, Str(ret_type.dtype)])]
                    ),
                )
                return ([assign_output], [output_var], output_var)
            assert isinstance(ret_type, relay.TupleType)
            assignments = []
            extra_args = []
            fields = []
            for t in ret_type.fields:
                inner_assignments, inner_args, inner_output = convert_output(t)
                assignments += inner_assignments
                extra_args += inner_args
                fields.append(inner_output)
            fields = [ast.List(fields, Load())]
            return (assignments, extra_args, self.create_call("_container.tuple_object", fields))

        # create a function to wrap the call of the lowered op and return
        # a call to that function
        wrap_name = self.generate_function_name("_{}_wrapper".format(op_name))
        wrap_args = [self.generate_var_name("_arg_{}".format(i)) for i in range(len(py_args))]

        inner_call_args = []
        for i, wrap_arg in enumerate(wrap_args):
            inner_call_args += convert_input(Name(wrap_arg, Load()), relay_args[i].checked_type)
        output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
        # equiv: _op = tvm.get_global_func(op_name)
        op_var = self.generate_var_name("_op")
        op_call = self.create_call("tvm.get_global_func", [Str(op_name)])
        op_assign = Assign([Name(op_var, Store())], op_call)
        # equiv: _op(args)
        inner_call = self.create_call(op_var, inner_call_args + aux_args)
        body = output_assignments + [op_assign, ast.Expr(inner_call), Return(output)]
        wrap_def = self.create_def(wrap_name, wrap_args, body)
        return wrap_def, self.create_call(wrap_name, py_args)

    def create_match_check(self, pattern: Pattern, data):
        """Given an ADT match pattern and a (Python) expression pointing to
        an ADT value, this generates a Python expression that checks if the
        ADT value matches the given pattern (returning True or False)."""

        # wildcard or var match everything
        if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
            return NameConstant(True)

        conds = []

        if isinstance(pattern, relay.PatternConstructor):
            # constructor patterns check whether the constructors match
            # and also the matches of any nested patterns

            # equiv: (arg.tag == pattern_constructor.tag)
            conds.append(
                ast.Compare(
                    ast.Attribute(data, "tag", Load()),
                    [ast.Eq()],
                    [ast.Num(pattern.constructor.tag)],
                )
            )

        assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
        # now check for any nested patterns
        for i, nested_pat in enumerate(pattern.patterns):
            # can safely skip var or wildcard patterns: they will
            # never cause a check to fail
            if not isinstance(nested_pat, relay.PatternConstructor):
                continue

            # index into the value corresponding to the subpattern
            field_index = ast.Subscript(
                ast.Attribute(data, "fields", Load()), ast.Index(Num(i)), Load()
            )
            conds.append(self.create_match_check(nested_pat, field_index))

        # a tuple pattern without nested constructor patterns produces no
        # checks at all, so it always matches (a BoolOp with no operands
        # is rejected when the AST is compiled)
        if not conds:
            return NameConstant(True)
        # if we do not need to check nested pattern, just return the single check
        if len(conds) == 1:
            return conds[0]
        # otherwise AND together any nested checks
        return ast.BoolOp(ast.And(), conds)

    def create_match_clause_body(self, pattern: Pattern, body: Expr):
        """Given a match clause pattern and a clause body,
        generates a Python function that when called with an ADT
        that matches the pattern, returns the result of evaluating
        the clause body. This function returns a function definition
        and the name of the generated function."""

        def collect_var_assignments(pat, val):
            """This helper function ensures that the pattern is used to
            properly assign all subfields of the given AST for use
            in the clause body

            E.g., for PatternConstructor(A, PatternVar(v), PatternWildcard(),
            PatternConstructor(B, PatternVar(w)))
            we would want to have
            v = a.fields[0]
            w = a.fields[2].fields[0]
            """
            if isinstance(pat, relay.PatternWildcard):
                return []
            if isinstance(pat, relay.PatternVar):
                return [Assign([self.include_var(pat.var, assign=True)], val)]
            # constructor pattern: assign each field of the value
            # based on subpatterns
            assignments = []
            for i, sub_pat in enumerate(pat.patterns):
                # we want the assignments for val.fields[i]
                field = ast.Subscript(
                    ast.Attribute(val, "fields", Load()), ast.Index(Num(i)), Load()
                )
                assignments += collect_var_assignments(sub_pat, field)
            return assignments

        func_name = self.generate_function_name("_match_clause_body")
        arg_name = self.generate_var_name("_match_clause_body")

        clause_body, defs = self.visit(body)
        assignments = collect_var_assignments(pattern, Name(arg_name, Load()))

        func_def = self.create_def(
            func_name, [arg_name], defs + assignments + [Return(clause_body)]
        )
        return (func_def, func_name)

    # Convention for the expr visitor: Each visit function returns a tuple of two members.
    #
    # The first is a Python AST comprised of a single *expression* that evaluates to an equivalent
    # result to the desired Relay expression (and executes all effects in the right order).
    #
    # The second is a list of function definition *statements* defining thunks and other
    # auxiliary functions needed in the translated AST object. The defs in the second object
    # will always have unique names and will never perform any effects, so as long as they
    # appear in the Python program before the first statement is executed, there should not
    # be any problems.

    def visit_var(self, var: Expr):
        """Converts a local variable into a load of its generated Python name."""
        return (self.include_var(var, assign=False), [])

    def visit_global_var(self, gvar: Expr):
        """Converts a global variable into a load of its (unmodified) name."""
        # we don't need to add numbers to global var names because
        # the *names* are checked for uniqueness in the mod
        return (Name(str(gvar.name_hint), Load()), [])

    def visit_let(self, letexp: Expr):
        """Translates a let binding into a function that is called with the
        value to bind. This properly accounts for scoping and ensures that
        the entire node produces an expression. Yes, this is somewhat ugly.

        let var = value in body
        =======================
        def let_thunk(var):
            return body
        let_thunk(value)
        """
        bind_body, bind_defs = self.visit(letexp.body)

        func_name = self.generate_function_name("_let_func")
        binding_func = self.create_def(
            func_name, [self.get_var_name(letexp.var)], bind_defs + [Return(bind_body)]
        )

        # we call the binding func with the intended value for the bound variable

        # special case: if the value is a function literal, we must ensure it can be
        # recursive by naming it after the var
        if isinstance(letexp.value, Function):
            value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
            return (
                self.create_call(func_name, [Name(value_name, Load())]),
                [value_def, binding_func],
            )

        value_body, value_defs = self.visit(letexp.value)
        value_defs.append(binding_func)
        binding_call = self.create_call(func_name, [value_body])
        return (binding_call, value_defs)

    def visit_tuple(self, tup: Expr):
        """Converts a tuple into a call to _container.tuple_object on its fields."""
        fields, ret_defs = self.convert_fields(tup.fields)
        fields = [ast.List(fields, Load())]
        return (self.create_call("_container.tuple_object", fields), ret_defs)

    def visit_tuple_getitem(self, tgi: Expr):
        """Converts a tuple projection into a subscript expression."""
        tup, tup_defs = self.visit(tgi.tuple_value)
        ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
        return (ret, tup_defs)

    def visit_if(self, if_block: Expr):
        """Converts an if expression into a Python conditional expression."""
        cond_body, cond_defs = self.visit(if_block.cond)
        true_body, true_defs = self.visit(if_block.true_branch)
        false_body, false_defs = self.visit(if_block.false_branch)

        # need to get the value out of a NDArray to check the condition
        # equivalent to: val.asnumpy()
        cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
        ret = ast.IfExp(cond_check, true_body, false_body)
        return (ret, cond_defs + true_defs + false_defs)

    def visit_constant(self, constant: Expr):
        """Proceeds by converting constant value to a numpy array
        and converting it to the appropriate value in the generated
        code (whether it be a Python scalar or a Numpy array)"""
        value = constant.data.asnumpy()
        # equiv: nd.array(numpy.array(<literal>, dtype=<dtype>))
        const_expr = ast.Call(
            ast.Attribute(Name("numpy", Load()), "array", Load()),
            [self.parse_numpy_array(value)],
            [ast.keyword("dtype", Str(constant.checked_type.dtype))],
        )
        return (self.create_call("nd.array", [const_expr]), [])

    def visit_function(self, func: Expr):
        """Converts a function literal into a named definition and a
        reference to that name."""
        # Python's lambdas are very restrictive, so we do "name" inline functions
        converted_func, func_name = self.convert_func_node(func)
        return (Name(func_name, Load()), [converted_func])

    def visit_call(self, call: Expr):
        """For calls, we must distinguish between ordinary functions,
        operators, and constructor calls."""
        func = call.op
        fields, field_defs = self.convert_fields(call.args)

        if isinstance(func, tvm.ir.Op):
            raise Exception("Operators should have been lowered and eliminated")

        if isinstance(func, relay.Constructor):
            # produce a constructor value
            return (
                self.create_call(
                    "ConstructorValue",
                    [ast.Num(func.tag), ast.List(fields, Load()), NameConstant(None)],
                ),
                field_defs,
            )

        # lowered operator: generate a call to a function that gets the PackedFunc
        # from TVM's registry
        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
            op_call_def, op_call = self.create_op_call(func, call.args, fields)
            return (op_call, field_defs + [op_call_def])

        # ordinary function
        converted_func, defs = self.visit(func)
        defs += field_defs
        return (ast.Call(converted_func, fields, []), defs)

    def visit_ref_create(self, ref: Expr):
        """Converts a reference creation into a RefValue constructor call."""
        val, defs = self.visit(ref.value)
        return (self.create_call("RefValue", [val]), defs)

    def visit_ref_read(self, read: Expr):
        """Converts a reference read into an access of the ref's `value` attribute."""
        ref, defs = self.visit(read.ref)
        return (ast.Attribute(ref, "value", Load()), defs)

    def visit_ref_write(self, write: Expr):
        """For writing refs, we wrap the update in a thunk
        (returning an empty tuple to match Relay's semantics)
        that we execute at the right time. This ensures such assignments
        can be properly nested, since assignments are statements
        in Python but expressions in Relay"""
        ref, ref_defs = self.visit(write.ref)
        val, val_defs = self.visit(write.value)
        thunk_name = self.generate_function_name("_ref_write_thunk")
        thunk = self.create_def(
            thunk_name,
            [],
            ref_defs
            + val_defs
            + [
                Assign([ast.Attribute(ref, "value", Store())], val),
                Return(self.create_call("_container.tuple_object", [])),
            ],
        )
        return (self.create_call(thunk_name, []), [thunk])

    def visit_match(self, match: Expr):
        """For matches, we wrap the entire expression in a thunk
        because it is easiest to implement them using if statements.
        For each clause, we generate a function that checks if the
        pattern matches. If yes, we call a function that assigns
        the variables appropriately and invokes the clause body."""
        data, defs = self.visit(match.data)
        data_var = self.generate_var_name("_match_data")

        # must ensure the data clause is executed exactly once
        thunk_body = [Assign([Name(data_var, Store())], data)]
        for clause in match.clauses:
            check_expr = self.create_match_check(clause.lhs, Name(data_var, Load()))
            body_def, body_name = self.create_match_clause_body(clause.lhs, clause.rhs)
            defs.append(body_def)

            # equiv: if check(data): return body(data)
            thunk_body.append(
                ast.If(
                    check_expr, [Return(self.create_call(body_name, [Name(data_var, Load())]))], []
                )
            )

        # finally if nothing matches we have a failed assert (should never happen)
        thunk_body.append(ast.Assert(NameConstant(False), Str("Match was not exhaustive")))

        thunk_name = self.generate_function_name("_match_thunk")
        thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
        return (self.create_call(thunk_name, []), [thunk_def])

    # these are both handled in the "call" case
    def visit_constructor(self, _):
        pass

    def visit_op(self, _):
        pass


def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
    """Converts the given Relay expression into a Python script (as a Python AST object).
    For easiest debugging, import the astor package and use to_source()."""
    mod = mod if mod is not None else tvm.IRModule()
    converter = PythonConverter(mod, target)
    return converter.convert(expr)


def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
    """Converts the given Relay expression into a Python script and
    executes it, returning the value the script assigns to its output variable."""
    mod = mod if mod is not None else tvm.IRModule()
    py_ast = to_python(expr, mod, target)
    code = compile(py_ast, "<string>", "exec")
    var_map = {OUTPUT_VAR_NAME: None}
    # pylint: disable=exec-used
    exec(code, var_map, var_map)
    return var_map[OUTPUT_VAR_NAME]