# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility for converting Relay code into a Python script with equivalent semantics"""
import ast
from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
import re

import tvm
from tvm import relay
from tvm.relay.adt import Pattern
from tvm.relay.backend import compile_engine
from tvm.relay.expr import Expr, Function, GlobalVar, Var
from tvm.relay.expr_functor import ExprFunctor

OUTPUT_VAR_NAME = '_py_out'

# corresponds to:
#     import numpy
#     import tvm
#     from tvm import relay
#     from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue
PROLOGUE = [
    ast.Import([alias('numpy', None)]),
    ast.Import([alias('tvm', None)]),
    ast.ImportFrom('tvm', [alias('relay', None)], 0),
    ast.ImportFrom('tvm.relay.backend.interpreter',
                   [alias('RefValue', None),
                    alias('TupleValue', None),
                    alias('TensorValue', None),
                    alias('ConstructorValue', None)],
                   0)
]


def _make_ast_node(node_type, **fields):
    """Instantiates the AST node class `node_type`, keeping only the keyword
    fields that this Python version's `node_type._fields` declares.

    The field lists of some AST nodes changed between Python releases
    (`ast.arguments` gained a leading `posonlyargs` and `ast.Module` gained a
    required `type_ignores` in 3.8; `ast.FunctionDef` gained `type_params` in
    3.12). Positional construction therefore silently misassigns fields on
    newer interpreters. Passing every field by keyword and dropping unknown
    ones yields a compilable tree on every version."""
    return node_type(**{key: value for key, value in fields.items()
                        if key in node_type._fields})


class PythonConverter(ExprFunctor):
    """Functor for translating Relay programs into Python ASTs."""

    def __init__(self, mod, target) -> None:
        super().__init__()
        self.mod = mod
        self.tgt = target
        self.engine = compile_engine.get()
        # counters used to make generated function/var names unique
        self.fun_no = 0
        self.var_no = 0
        # maps Relay vars to the Python identifiers generated for them
        self.var_map = {}


    def convert(self, prog: Expr):
        """This method converts the passed Relay expression into a Python
        AST object with equivalent semantics.

        The Python AST can be executed using exec(); it can be turned
        into text and inspected using astor.
        """
        optimized = self.optimize(prog)

        # start with conversion prelude (imports) and convert global defs
        body = []
        body += PROLOGUE
        body += self.convert_module()

        prog_body, extra_defs = self.visit(optimized)
        body += extra_defs

        # we finally must assign the final expression to the output var
        # so it can be read after running EXEC
        body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))

        # type_ignores is required by compile() on Python 3.8+
        module = _make_ast_node(ast.Module, body=body, type_ignores=[])
        return ast.fix_missing_locations(module)


    def optimize(self, prog: Expr):
        """Performs optimizations necessary to be able to generate code for prog."""
        # unwrap tuple wrappers (some op calls produce them)
        unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
        assert relay.analysis.well_formed(unwrapped)
        mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)

        # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
        # and fusion (to get primitive functions)
        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
                                           relay.transform.FuseOps(fuse_opt_level=0)])
        mod = opts(mod)
        optimized = mod['main']
        return optimized if isinstance(unwrapped, Function) else optimized.body


    def sanitize(self, name: str) -> str:
        """Removes any invalid characters (only underscores, numbers, and letters permitted)
        from the given name. Since we append a number and underscore to var names anyway,
        it doesn't matter if the name is the empty string."""
        return re.sub(r'\W', '', name)


    def generate_var_name(self, name_hint: str) -> str:
        """Generates a unique variable name starting from the hint."""
        name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no)
        self.var_no += 1
        return name


    def generate_function_name(self, name_hint: str) -> str:
        """Generates a unique function name starting from the hint."""
        name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no)
        self.fun_no += 1
        return name


    def get_var_name(self, var: Expr) -> str:
        """Returns the var name for the given Relay variable, generating
        (and memoizing) a fresh one on first use."""
        if var in self.var_map:
            return self.var_map[var]
        name = self.generate_var_name(var.name_hint)
        self.var_map[var] = name
        return name


    def include_var(self, var: Expr, assign=False):
        """Returns a variable AST node for the given Relay var depending on
        whether it must appear in an assignment or not."""
        name = self.get_var_name(var)
        return Name(name, Store() if assign else Load())


    def parse_name(self, name: str):
        """Given the name of a Python method with dots (e.g., 'relay.var'),
        returns an appropriate AST object corresponding to that name."""
        attributes = name.split('.')
        ret = Name(attributes[0], Load())
        for i in range(len(attributes) - 1):
            ret = ast.Attribute(ret, attributes[i+1], Load())
        return ret


    def parse_numpy_array(self, arr):
        """Given a Numpy array, produces an appropriate Python array
        or numerical literal representing its contents."""
        parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)
        if arr.ndim == 0:
            return parse_single(arr.item())
        if arr.ndim == 1:
            return ast.List([parse_single(i.item()) for i in arr], Load())

        # higher rank: recurse on each row to build nested lists
        elts = []
        for row in arr:
            elts.append(self.parse_numpy_array(row))
        return ast.List(elts, Load())


    def convert_fields(self, fields: [Expr]):
        """Given a list of call args or tuple fields, converts
        each and returns their ASTs and their defs lists (in order)."""
        bodies = []
        defs = []
        for field in fields:
            member_body, member_defs = self.visit(field)
            bodies.append(member_body)
            defs += member_defs
        return (bodies, defs)


    def convert_to_thunk(self, name_hint: str, expr: Expr):
        """Wraps the passed expression in a thunk."""
        body, defs = self.visit(expr)
        thunk_name = self.generate_function_name(name_hint)
        thunk = self.create_def(thunk_name, [], defs + [Return(body)])
        return (thunk, thunk_name)


    def convert_func_node(self, func: Function, name_var=None):
        """Converts the given Relay function into a Python function, with
        special handling for named functions (locally or globally).

        Raises ValueError if name_var is neither None, a GlobalVar, nor a Var."""
        if name_var is None:
            func_name = self.generate_function_name('_anon_func')
        elif isinstance(name_var, GlobalVar):
            func_name = name_var.name_hint
        elif isinstance(name_var, Var):
            func_name = self.get_var_name(name_var)
        else:
            # previously this fell through to an UnboundLocalError
            raise ValueError('Unsupported function name binding: {}'.format(type(name_var)))

        var_names = [self.get_var_name(var) for var in func.params]
        body, defs = self.visit(func.body)
        ret = self.create_def(func_name, var_names, defs + [Return(body)])
        return (ret, func_name)


    def convert_module(self):
        """Converts all the global functions defined in the module and returns
        them as a list of definitions"""
        defs = []
        for var, func in self.mod.functions.items():
            # optimize the definition so any operators used are lowered
            opt_func = self.optimize(func)
            try:
                converted_func, _ = self.convert_func_node(opt_func, var)
                defs.append(converted_func)
            except TypeError:
                # deliberate best-effort: functions that cannot be converted
                # are skipped
                # TODO(wweic): fix conversion for Any
                pass
        return defs


    def create_call(self, func_name: str, arguments):
        """Creates a simple function call."""
        return ast.Call(self.parse_name(func_name), arguments, [])


    def create_def(self, func_name: str, arguments: [str], body):
        """Wrapper over function definition AST node, whose constructor is inconvenient
        (and whose field layout differs between Python versions)."""
        params = [ast.arg(argument, None) for argument in arguments]
        func_args = _make_ast_node(ast.arguments,
                                   posonlyargs=[], args=params, vararg=None,
                                   kwonlyargs=[], kw_defaults=[], kwarg=None,
                                   defaults=[])
        return _make_ast_node(ast.FunctionDef,
                              name=func_name, args=func_args, body=body,
                              decorator_list=[], returns=None,
                              type_comment=None, type_params=[])


    def create_op_call(self, op: Function, relay_args, py_args):
        """Lowers the passed primitive function, registers it in TVM's
        global compiler, and produces a call to the lowered function in
        the generated Python code."""

        # compile the function and register globally
        cc_key = compile_engine.CCacheKey(op, self.tgt)
        func_hash = relay.analysis.structural_hash(op)
        op_name = '_lowered_op_{}'.format(func_hash)
        if not tvm.get_global_func(op_name, allow_missing=True):
            jitted = self.engine.jit(cc_key, self.tgt)
            tvm.register_func(op_name, jitted)

        def convert_input(py_input, arg_type):
            """Use the types of the function arguments to determine whether we expect
            a tensor or tuple (returns list of inputs to the lowered op call)"""
            # equivalent: input.data
            if isinstance(arg_type, relay.TensorType):
                return [ast.Attribute(py_input, 'data', Load())]
            assert isinstance(arg_type, relay.TupleType)
            # convert each input.fields[i]
            ret = []
            for i in range(len(arg_type.fields)):
                ret += convert_input(
                    ast.Subscript(
                        ast.Attribute(py_input, 'fields', Load()),
                        ast.Index(Num(i)), Load()),
                    arg_type.fields[i])
            return ret

        def convert_output(ret_type):
            """Use the function return type to produce auxiliary variables to store outputs.
            Returns ([assignments of output vars], [extra arguments to pass to op call],
            expression collecting output)"""
            if isinstance(ret_type, relay.TensorType):
                output_var_name = self.generate_var_name('_out')
                output_var = Name(output_var_name, Load())
                shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
                # create a new TensorValue of the right shape and dtype
                assign_output = Assign(
                    [Name(output_var_name, Store())],
                    self.create_call('TensorValue', [
                        self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
                    ]))
                # we pass the data field as an argument
                extra_arg = ast.Attribute(output_var, 'data', Load())
                return ([assign_output], [extra_arg], output_var)
            assert isinstance(ret_type, relay.TupleType)
            assignments = []
            extra_args = []
            fields = []
            for t in ret_type.fields:
                inner_assignments, inner_args, inner_output = convert_output(t)
                assignments += inner_assignments
                extra_args += inner_args
                fields.append(inner_output)
            return (assignments, extra_args, self.create_call('TupleValue', fields))

        # create a function to wrap the call of the lowered op and return
        # a call to that function
        wrap_name = self.generate_function_name('_{}_wrapper'.format(op_name))
        wrap_args = [self.generate_var_name('_arg_{}'.format(i)) for i in range(len(py_args))]

        inner_call_args = []
        for i in range(len(py_args)):
            inner_call_args += convert_input(Name(wrap_args[i], Load()),
                                             relay_args[i].checked_type)
        output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
        # equiv: _op = tvm.get_global_func(op_name)
        op_var = self.generate_var_name('_op')
        op_call = self.create_call('tvm.get_global_func', [Str(op_name)])
        op_assign = Assign([Name(op_var, Store())], op_call)
        # equiv: _op(args) -- outputs are written into the preallocated buffers
        inner_call = self.create_call(op_var, inner_call_args + aux_args)
        body = output_assignments + [op_assign, ast.Expr(inner_call), Return(output)]
        wrap_def = self.create_def(wrap_name, wrap_args, body)
        return wrap_def, self.create_call(wrap_name, py_args)


    def create_match_check(self, pattern: Pattern, data):
        """Given an ADT match pattern and a (Python) expression pointing to
        an ADT value, this generates a Python expression that checks if the
        ADT value matches the given pattern (returning True or False)."""

        # wildcard or var match everything
        if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
            return NameConstant(True)

        conds = []

        if isinstance(pattern, relay.PatternConstructor):
            # constructor patterns check whether the constructors match
            # and also the matches of any nested patterns

            # equiv: (arg.tag == pattern_constructor.tag)
            conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
                                     [ast.Eq()],
                                     [ast.Num(pattern.constructor.tag)]))

        assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
        # now check for any nested patterns
        for i in range(len(pattern.patterns)):
            nested_pat = pattern.patterns[i]
            # can safely skip var or wildcard patterns: they will never cause
            # a check to fail. Nested tuple patterns must still be visited,
            # since they may contain constructor patterns.
            if isinstance(nested_pat, (relay.PatternWildcard, relay.PatternVar)):
                continue

            # index into the value corresponding to the subpattern
            field_index = ast.Subscript(ast.Attribute(data, 'fields', Load()),
                                        ast.Index(Num(i)), Load())
            conds.append(self.create_match_check(nested_pat, field_index))

        # a tuple pattern made only of vars/wildcards always matches
        # (an empty BoolOp would not compile)
        if not conds:
            return NameConstant(True)
        # if we do not need to check nested pattern, just return the single check
        if len(conds) == 1:
            return conds[0]
        # otherwise AND together any nested checks
        return ast.BoolOp(ast.And(), conds)


    def create_match_clause_body(self, pattern: Pattern, body: Expr):
        """Given a match clause pattern and a clause body,
        generates a Python function that when called with an ADT
        that matches the pattern, returns the result of evaluating
        the clause body. This function returns a function definition
        and the name of the generated function."""

        def collect_var_assignments(pat, val):
            """This helper function ensures that the pattern is used to
            properly assign all subfields of the given AST for use
            in the clause body

            E.g., for PatternConstructor(A, PatternVar(v), PatternWildcard(),
                                         PatternConstructor(B, PatternVar(w)))
            we would want to have
            v = a.fields[0]
            w = a.fields[2].fields[0]
            """
            if isinstance(pat, relay.PatternWildcard):
                return []
            if isinstance(pat, relay.PatternVar):
                return [Assign([self.include_var(pat.var, assign=True)], val)]
            # constructor (or tuple) pattern: assign each field of the value
            # based on subpatterns
            assignments = []
            for i in range(len(pat.patterns)):
                # we want the assignments for val.fields[i]
                field = ast.Subscript(ast.Attribute(val, 'fields', Load()),
                                      ast.Index(Num(i)), Load())
                assignments += collect_var_assignments(pat.patterns[i], field)
            return assignments

        func_name = self.generate_function_name('_match_clause_body')
        arg_name = self.generate_var_name('_match_clause_body')

        clause_body, defs = self.visit(body)
        assignments = collect_var_assignments(pattern, Name(arg_name, Load()))

        func_def = self.create_def(func_name, [arg_name],
                                   defs + assignments + [Return(clause_body)])
        return (func_def, func_name)


    # Convention for the expr visitor: Each visit function returns a tuple of two members.
    #
    # The first is a Python AST comprised of a single *expression* that evaluates to an equivalent
    # result to the desired Relay expression (and executes all effects in the right order).
    #
    # The second is a list of function definition *statements* defining thunks and other
    # auxiliary functions needed in the translated AST object. The defs in the second object
    # will always have unique names and will never perform any effects, so as long as they
    # appear in the Python program before the first statement is executed, there should not
    # be any problems.

    def visit_var(self, var: Expr):
        return (self.include_var(var, assign=False), [])


    def visit_global_var(self, gvar: Expr):
        # we don't need to add numbers to global var names because
        # the *names* are checked for uniqueness in the mod
        return (Name(gvar.name_hint, Load()), [])


    def visit_let(self, letexp: Expr):
        """
        let var = value in body
        =======================
        def let_thunk(var):
            return body
        let_thunk(value)
        """
        # To properly account for scoping and ensure that the entire node produces an expression,
        # we translate the let binding as a function that we call with the value we intend to bind.
        # Yes, this is somewhat ugly.
        bind_body, bind_defs = self.visit(letexp.body)

        func_name = self.generate_function_name('_let_func')
        binding_func = self.create_def(func_name, [self.get_var_name(letexp.var)],
                                       bind_defs + [Return(bind_body)])

        # we call the binding func with the intended value for the bound variable

        # special case: if the value is a function literal, we must ensure it can be
        # recursive by naming it after the var
        if isinstance(letexp.value, Function):
            value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
            return (self.create_call(func_name, [Name(value_name, Load())]),
                    [value_def, binding_func])

        value_body, value_defs = self.visit(letexp.value)
        value_defs.append(binding_func)
        binding_call = self.create_call(func_name, [value_body])
        return (binding_call, value_defs)


    def visit_tuple(self, tup: Expr):
        fields, ret_defs = self.convert_fields(tup.fields)
        return (self.create_call('TupleValue', fields), ret_defs)


    def visit_tuple_getitem(self, tgi: Expr):
        tup, tup_defs = self.visit(tgi.tuple_value)
        ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
        return (ret, tup_defs)


    def visit_if(self, if_block: Expr):
        cond_body, cond_defs = self.visit(if_block.cond)
        true_body, true_defs = self.visit(if_block.true_branch)
        false_body, false_defs = self.visit(if_block.false_branch)

        # need to get the value out of a TensorValue to check the condition
        # equivalent to: val.asnumpy()
        cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
        ret = ast.IfExp(cond_check, true_body, false_body)
        return (ret, cond_defs + true_defs + false_defs)


    def visit_constant(self, constant: Expr):
        """Proceeds by converting constant value to a numpy array
        and converting it to the appropriate value in the generated
        code (whether it be a Python scalar or a Numpy array)"""
        value = constant.data.asnumpy()
        const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
                              [self.parse_numpy_array(value)],
                              [ast.keyword('dtype', Str(constant.checked_type.dtype))])
        return (self.create_call('TensorValue', [const_expr]), [])


    def visit_function(self, func: Expr):
        # Python's lambdas are very restrictive, so we name inline functions
        converted_func, func_name = self.convert_func_node(func)
        return (Name(func_name, Load()), [converted_func])


    def visit_call(self, call: Expr):
        """For calls, we must distinguish between ordinary functions,
        operators, and constructor calls."""
        func = call.op
        fields, field_defs = self.convert_fields(call.args)

        if isinstance(func, relay.Op):
            raise Exception('Operators should have been lowered and eliminated')

        if isinstance(func, relay.Constructor):
            # produce a constructor value
            return (self.create_call('ConstructorValue',
                                     [ast.Num(func.tag),
                                      ast.List(fields, Load()),
                                      NameConstant(None)]),
                    field_defs)

        # lowered operator: generate a call to a function that gets the PackedFunc
        # from TVM's registry
        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
            op_call_def, op_call = self.create_op_call(func, call.args, fields)
            return (op_call, field_defs + [op_call_def])

        # ordinary function
        converted_func, defs = self.visit(func)
        defs += field_defs
        return (ast.Call(converted_func, fields, []), defs)


    def visit_ref_create(self, ref: Expr):
        val, defs = self.visit(ref.value)
        return (self.create_call('RefValue', [val]), defs)


    def visit_ref_read(self, read: Expr):
        ref, defs = self.visit(read.ref)
        return (ast.Attribute(ref, 'value', Load()), defs)


    def visit_ref_write(self, write: Expr):
        """For writing refs, we wrap the update in a thunk
        (returning an empty tuple to match Relay's semantics)
        that we execute at the right time. This ensures such assignments
        can be properly nested, since assignments are statements
        in Python but expressions in Relay"""
        ref, ref_defs = self.visit(write.ref)
        val, val_defs = self.visit(write.value)
        thunk_name = self.generate_function_name('_ref_write_thunk')
        thunk = self.create_def(
            thunk_name, [],
            ref_defs + val_defs + [
                Assign([ast.Attribute(ref, 'value', Store())], val),
                Return(self.create_call('TupleValue', []))
            ])
        return (self.create_call(thunk_name, []), [thunk])


    def visit_match(self, match: Expr):
        """For matches, we wrap the entire expression in a thunk
        because it is easiest to implement them using if statements.
        For each clause, we generate a function that checks if the
        pattern matches. If yes, we call a function that assigns
        the variables appropriately and invokes the clause body."""
        data, defs = self.visit(match.data)
        data_var = self.generate_var_name('_match_data')

        # must ensure the data clause is executed exactly once
        thunk_body = [Assign([Name(data_var, Store())], data)]
        for clause in match.clauses:
            check_expr = self.create_match_check(clause.lhs, Name(data_var, Load()))
            body_def, body_name = self.create_match_clause_body(clause.lhs, clause.rhs)
            defs.append(body_def)

            # equiv: if check(data): return body(data)
            thunk_body.append(ast.If(
                check_expr,
                [Return(self.create_call(body_name, [Name(data_var, Load())]))],
                []
            ))

        # finally if nothing matches we have a failed assert (should never happen)
        thunk_body.append(ast.Assert(NameConstant(False), Str('Match was not exhaustive')))

        thunk_name = self.generate_function_name('_match_thunk')
        thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
        return (self.create_call(thunk_name, []), [thunk_def])


    # these are both handled in the "call" case
    def visit_constructor(self, _):
        pass
    def visit_op(self, _):
        pass


def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
    """Converts the given Relay expression into a Python script (as a Python AST object).
    For easiest debugging, import the astor package and use to_source()."""
    mod = mod if mod is not None else relay.Module()
    converter = PythonConverter(mod, target)
    return converter.convert(expr)


def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
    """Converts the given Relay expression into a Python script and
    executes it."""
    mod = mod if mod is not None else relay.Module()
    py_ast = to_python(expr, mod, target)
    code = compile(py_ast, '<string>', 'exec')
    var_map = {
        OUTPUT_VAR_NAME : None
    }
    # the same dict is used for globals and locals so generated defs can see each other
    #pylint: disable=exec-used
    exec(code, var_map, var_map)
    return var_map[OUTPUT_VAR_NAME]