1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Utility for converting Relay code into a Python script with equivalent semantics"""
18import ast
19from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
20import re
21
22import tvm
23from tvm import relay
24from tvm.relay.adt import Pattern
25from tvm.relay.backend import compile_engine
26from tvm.relay.expr import Expr, GlobalVar, Var
27from tvm.relay.function import Function
28from tvm.relay.expr_functor import ExprFunctor
29
# Name of the variable the generated program assigns its final result to;
# run_as_python reads it back out of the exec namespace.
OUTPUT_VAR_NAME = "_py_out"

# Import statements placed at the top of every generated program.
# Corresponds to:
#     import numpy
#     import tvm
#     from tvm import relay
#     from tvm import nd
#     from tvm.runtime import container as _container
#     from tvm.relay.backend.interpreter import RefValue, ConstructorValue
PROLOGUE = [
    ast.Import([alias("numpy", None)]),
    ast.Import([alias("tvm", None)]),
    ast.ImportFrom("tvm", [alias("relay", None)], 0),
    ast.ImportFrom("tvm", [alias("nd", None)], 0),
    ast.ImportFrom("tvm.runtime", [alias("container", "_container")], 0),
    ast.ImportFrom(
        "tvm.relay.backend.interpreter",
        [alias("RefValue", None), alias("ConstructorValue", None)],
        0,
    ),
]
51
52
class PythonConverter(ExprFunctor):
    """Functor for translating Relay programs into Python ASTs.

    Every ``visit_*`` method returns a pair ``(expr, defs)``: ``expr`` is a single
    Python *expression* AST equivalent to the visited Relay expression, and ``defs``
    is a list of function-definition *statements* (thunks and helpers) that ``expr``
    refers to and that must appear in the program before ``expr`` is evaluated.
    """

    def __init__(self, mod, target) -> None:
        super().__init__()
        # module whose global functions / type definitions accompany the program
        self.mod = mod
        # target handed to the compile engine when lowering primitive functions
        self.tgt = target
        self.engine = compile_engine.get()
        # counters appended to generated names to keep them unique
        self.fun_no = 0
        self.var_no = 0
        # memoized mapping: Relay var -> generated Python name
        self.var_map = {}

    def convert(self, prog: Expr):
        """This method converts the passed Relay expression into a Python
        AST object with equivalent semantics.

        The Python AST can be executed using exec(); it can be turned
        into text and inspected using astor. Executing it assigns the
        program's result to the variable named by ``OUTPUT_VAR_NAME``.
        """
        optimized = self.optimize(prog)

        # start with conversion prelude (imports) and convert global defs
        body = []
        body += PROLOGUE
        body += self.convert_module()

        prog_body, extra_defs = self.visit(optimized)
        body += extra_defs

        # we finally must assign the final expression to the output var
        # so it can be read after running EXEC
        body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))

        # `type_ignores` is a required Module field since Python 3.8: compile()
        # rejects a Module without it. Older versions store it as a plain attribute.
        return ast.fix_missing_locations(ast.Module(body=body, type_ignores=[]))

    def optimize(self, prog: Expr):
        """Performs optimizations necessary to be able to generate code for prog.

        Returns the optimized function if ``prog`` (after unwrapping) is a
        Function, otherwise the body of the optimized ``main`` function.
        """
        # unwrap tuple wrappers (some op calls produce them)
        unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
        assert relay.analysis.well_formed(unwrapped)
        mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)

        # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
        # and fusion (to get primitive functions)
        opts = tvm.transform.Sequential(
            [relay.transform.SimplifyInference(), relay.transform.FuseOps(fuse_opt_level=0)]
        )
        mod = opts(mod)
        optimized = mod["main"]
        return optimized if isinstance(unwrapped, Function) else optimized.body

    def sanitize(self, name: str) -> str:
        """Removes any invalid characters (only underscores, numbers, and letters permitted)
        from the given name. Since we append a number and underscore to var names anyway,
        it doesn't matter if the name is the empty string."""
        return re.sub(r"\W", "", name)

    def generate_var_name(self, name_hint: str) -> str:
        """Generates a unique variable name starting from the hint."""
        name = "{}_var_{}".format(self.sanitize(name_hint), self.var_no)
        self.var_no += 1
        return name

    def generate_function_name(self, name_hint: str) -> str:
        """Generates a unique function name starting from the hint."""
        name = "{}_fun_{}".format(self.sanitize(name_hint), self.fun_no)
        self.fun_no += 1
        return name

    def get_var_name(self, var: Expr) -> str:
        """Returns the var name for the given Relay variable, generating (and
        memoizing) a fresh unique name the first time the var is seen."""
        if var in self.var_map:
            return self.var_map[var]
        name = self.generate_var_name(var.name_hint)
        self.var_map[var] = name
        return name

    def include_var(self, var: Expr, assign=False):
        """Returns a variable AST node for the given Relay var depending on
        whether it must appear in an assignment or not."""
        name = self.get_var_name(var)
        return Name(name, Store() if assign else Load())

    def parse_name(self, name: str):
        """Given the name of a Python method with dots (e.g., 'relay.var'),
        returns an appropriate AST object corresponding to that name."""
        head, *attributes = name.split(".")
        ret = Name(head, Load())
        for attribute in attributes:
            ret = ast.Attribute(ret, attribute, Load())
        return ret

    def parse_numpy_array(self, arr):
        """Given a Numpy array, produces an appropriate Python array
        or numerical literal representing its contents."""
        parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)
        if arr.ndim == 0:
            return parse_single(arr.item())
        if arr.ndim == 1:
            return ast.List([parse_single(i.item()) for i in arr], Load())

        # higher ranks: nested lists, one level per dimension
        elts = []
        for row in arr:
            elts.append(self.parse_numpy_array(row))
        return ast.List(elts, Load())

    def convert_fields(self, fields: [Expr]):
        """Given a list of call args or tuple fields, converts
        each and returns their ASTs and their defs lists (in order)."""
        bodies = []
        defs = []
        for field in fields:
            member_body, member_defs = self.visit(field)
            bodies.append(member_body)
            defs += member_defs
        return (bodies, defs)

    def convert_to_thunk(self, name_hint: str, expr: Expr):
        """Wraps the passed expression in a thunk.

        Returns ``(thunk_def, thunk_name)``."""
        body, defs = self.visit(expr)
        thunk_name = self.generate_function_name(name_hint)
        thunk = self.create_def(thunk_name, [], defs + [Return(body)])
        return (thunk, thunk_name)

    def convert_func_node(self, func: Function, name_var=None):
        """Converts the given Relay function into a Python function, with
        special handling for named functions (locally or globally).

        ``name_var`` may be None (anonymous function, a fresh name is generated),
        a GlobalVar (its name hint is used as-is), or a local Var (so the
        function can refer to itself recursively). Returns ``(def, name)``.

        Raises ValueError for any other kind of ``name_var``.
        """
        if name_var is None:
            func_name = self.generate_function_name("_anon_func")
        elif isinstance(name_var, GlobalVar):
            func_name = str(name_var.name_hint)
        elif isinstance(name_var, Var):
            func_name = self.get_var_name(name_var)
        else:
            raise ValueError(
                "Unsupported function name binding of type {}".format(type(name_var).__name__)
            )

        var_names = [self.get_var_name(var) for var in func.params]
        body, defs = self.visit(func.body)
        ret = self.create_def(func_name, var_names, defs + [Return(body)])
        return (ret, func_name)

    def convert_module(self):
        """Converts all the global functions defined in the module and returns
        them as a list of definitions"""
        defs = []
        for var, func in self.mod.functions.items():
            # optimize the definition so any operators used are lowered
            opt_func = self.optimize(func)
            try:
                converted_func, _ = self.convert_func_node(opt_func, var)
                defs.append(converted_func)
            except TypeError:
                # TODO(wweic): fix conversion for Any
                # Functions that fail to convert with TypeError are deliberately
                # skipped (best effort) rather than aborting the whole conversion.
                pass
        return defs

    def create_call(self, func_name: str, arguments):
        """Creates a simple function call."""
        return ast.Call(self.parse_name(func_name), arguments, [])

    def create_def(self, func_name: str, arguments: [str], body):
        """Wrapper over function definition AST node, whose constructor is inconvenient.

        Creates ``def func_name(*arguments): body`` with plain positional parameters.
        """
        # Fields are passed by keyword: Python 3.8 inserted `posonlyargs` as the
        # first field of ast.arguments, so positional construction shifts every
        # field and yields a node that compile() rejects. Older interpreters simply
        # store the unknown `posonlyargs` keyword as an attribute.
        params = ast.arguments(
            posonlyargs=[],
            args=[ast.arg(argument, None) for argument in arguments],
            vararg=None,
            kwonlyargs=[],
            kw_defaults=[],
            kwarg=None,
            defaults=[],
        )
        return ast.FunctionDef(
            name=func_name, args=params, body=body, decorator_list=[], returns=None
        )

    def create_op_call(self, op: Function, relay_args, py_args):
        """Lowers the passed primitive function, registers it in TVM's
        global compiler, and produces a call to the lowered function in
        the generated Python code.

        Returns ``(wrapper_def, call_expr)``: the definition of a generated wrapper
        function, and an expression calling that wrapper on ``py_args``.
        """

        # compile the function and register globally; the registry name is derived
        # from the structural hash, and registration is skipped if it already exists
        cc_key = compile_engine.CCacheKey(op, self.tgt)
        func_hash = tvm.ir.structural_hash(op)
        op_name = "_lowered_op_{}".format(func_hash)
        if not tvm.get_global_func(op_name, allow_missing=True):
            jitted = self.engine.jit(cc_key, self.tgt)
            tvm.register_func(op_name, jitted)

        def convert_input(py_input, arg_type):
            """Use the types of the function arguments to determine whether we expect
            a tensor or tuple (returns list of inputs to the lowered op call)"""
            # equivalent: input.data
            if isinstance(arg_type, relay.TensorType):
                return [py_input]
            assert isinstance(arg_type, relay.TupleType)
            # tuples are flattened: convert each input[i] recursively
            ret = []
            for i, field_type in enumerate(arg_type.fields):
                ret += convert_input(
                    ast.Subscript(py_input, ast.Index(Num(i)), Load()), field_type
                )
            return ret

        def convert_output(ret_type):
            """Use the function return type to produce auxiliary variables to store outputs.
            Returns ([assignments of output vars], [extra arguments to pass to op call],
            expression collecting output)"""
            if isinstance(ret_type, relay.TensorType):
                output_var_name = self.generate_var_name("_out")
                output_var = Name(output_var_name, Load())
                shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
                # create a new NDArray of the right shape and dtype
                assign_output = Assign(
                    [Name(output_var_name, Store())],
                    self.create_call(
                        "nd.array", [self.create_call("numpy.empty", [shape, Str(ret_type.dtype)])]
                    ),
                )
                return ([assign_output], [output_var], output_var)
            assert isinstance(ret_type, relay.TupleType)
            assignments = []
            extra_args = []
            fields = []
            for t in ret_type.fields:
                inner_assignments, inner_args, inner_output = convert_output(t)
                assignments += inner_assignments
                extra_args += inner_args
                fields.append(inner_output)
            fields = [ast.List(fields, Load())]
            return (assignments, extra_args, self.create_call("_container.tuple_object", fields))

        # create a function to wrap the call of the lowered op and return
        # a call to that function
        wrap_name = self.generate_function_name("_{}_wrapper".format(op_name))
        wrap_args = [self.generate_var_name("_arg_{}".format(i)) for i in range(len(py_args))]

        inner_call_args = []
        for wrap_arg, relay_arg in zip(wrap_args, relay_args):
            inner_call_args += convert_input(Name(wrap_arg, Load()), relay_arg.checked_type)
        output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
        # equiv: _op = tvm.get_global_func(op_name)
        op_var = self.generate_var_name("_op")
        op_call = self.create_call("tvm.get_global_func", [Str(op_name)])
        op_assign = Assign([Name(op_var, Store())], op_call)
        # equiv: _op(args); outputs are written into the preallocated aux args
        inner_call = self.create_call(op_var, inner_call_args + aux_args)
        body = output_assignments + [op_assign, ast.Expr(inner_call), Return(output)]
        wrap_def = self.create_def(wrap_name, wrap_args, body)
        return wrap_def, self.create_call(wrap_name, py_args)

    def create_match_check(self, pattern: Pattern, data):
        """Given an ADT match pattern and a (Python) expression pointing to
        an ADT value, this generates a Python expression that checks if the
        ADT value matches the given pattern (returning True or False)."""

        # wildcard or var match everything
        if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
            return NameConstant(True)

        assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
        conds = []

        if isinstance(pattern, relay.PatternConstructor):
            # constructor patterns check whether the constructors match
            # and also the matches of any nested patterns

            # equiv: (arg.tag == pattern_constructor.tag)
            conds.append(
                ast.Compare(
                    ast.Attribute(data, "tag", Load()),
                    [ast.Eq()],
                    [ast.Num(pattern.constructor.tag)],
                )
            )

        # now check for any nested patterns
        for i, nested_pat in enumerate(pattern.patterns):
            # var or wildcard subpatterns never cause a check to fail; nested
            # constructor *and* tuple patterns must be checked, because a nested
            # tuple may itself contain constructor patterns
            if isinstance(nested_pat, (relay.PatternWildcard, relay.PatternVar)):
                continue

            # index into the value corresponding to the subpattern
            # NOTE(review): tuple values are built with _container.tuple_object;
            # confirm they expose `.fields` the same way ConstructorValue does.
            field_index = ast.Subscript(
                ast.Attribute(data, "fields", Load()), ast.Index(Num(i)), Load()
            )
            conds.append(self.create_match_check(nested_pat, field_index))

        # a tuple pattern whose subpatterns are all vars/wildcards always matches
        # (and an ast.BoolOp with no operands would not compile)
        if not conds:
            return NameConstant(True)
        # if we do not need to check nested pattern, just return the single check
        if len(conds) == 1:
            return conds[0]
        # otherwise AND together any nested checks
        return ast.BoolOp(ast.And(), conds)

    def create_match_clause_body(self, pattern: Pattern, body: Expr):
        """Given a match clause pattern and a clause body,
        generates a Python function that when called with an ADT
        that matches the pattern, returns the result of evaluating
        the clause body. This function returns a function definition
        and the name of the generated function."""

        def collect_var_assignments(pat, val):
            """This helper function ensures that the pattern is used to
            properly assign all subfields of the given AST for use
            in the clause body

            E.g., for PatternConstructor(A, PatternVar(v), PatternWildcard(),
            PatternConstructor(B, PatternVar(w)))
            we would want to have
            v = a.fields[0]
            w = a.fields[2].fields[0]
            """
            if isinstance(pat, relay.PatternWildcard):
                return []
            if isinstance(pat, relay.PatternVar):
                return [Assign([self.include_var(pat.var, assign=True)], val)]
            # constructor (or tuple) pattern: assign each field of the value
            # based on subpatterns
            assignments = []
            for i, sub_pat in enumerate(pat.patterns):
                # we want the assignments for val.fields[i]
                field = ast.Subscript(
                    ast.Attribute(val, "fields", Load()), ast.Index(Num(i)), Load()
                )
                assignments += collect_var_assignments(sub_pat, field)
            return assignments

        func_name = self.generate_function_name("_match_clause_body")
        arg_name = self.generate_var_name("_match_clause_body")

        clause_body, defs = self.visit(body)
        assignments = collect_var_assignments(pattern, Name(arg_name, Load()))

        func_def = self.create_def(
            func_name, [arg_name], defs + assignments + [Return(clause_body)]
        )
        return (func_def, func_name)

    # Convention for the expr visitor: Each visit function returns a tuple of two members.
    #
    # The first is a Python AST comprised of a single *expression* that evaluates to an equivalent
    # result to the desired Relay expression (and executes all effects in the right order).
    #
    # The second is a list of function definition *statements* defining thunks and other
    # auxiliary functions needed in the translated AST object. The defs in the second object
    # will always have unique names and will never perform any effects, so as long as they
    # appear in the Python program before the first statement is executed, there should not
    # be any problems.

    def visit_var(self, var: Expr):
        """A local var becomes a load of its generated Python name."""
        return (self.include_var(var, assign=False), [])

    def visit_global_var(self, gvar: Expr):
        """A global var becomes a load of its (unmodified) name hint."""
        # we don't need to add numbers to global var names because
        # the *names* are checked for uniqueness in the mod
        return (Name(str(gvar.name_hint), Load()), [])

    def visit_let(self, letexp: Expr):
        # To properly account for scoping and ensure that the entire node produces an expression,
        # we translate the let binding as a function that we call with the value we intend to bind.
        # Yes, this is somewhat ugly.
        """
        let var = value in body
        =======================
        def let_thunk(var):
            return body
        let_thunk(value)
        """
        bind_body, bind_defs = self.visit(letexp.body)

        func_name = self.generate_function_name("_let_func")
        binding_func = self.create_def(
            func_name, [self.get_var_name(letexp.var)], bind_defs + [Return(bind_body)]
        )

        # we call the binding func with the intended value for the bound variable

        # special case: if the value is a function literal, we must ensure it can be
        # recursive by naming it after the var
        if isinstance(letexp.value, Function):
            value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
            return (
                self.create_call(func_name, [Name(value_name, Load())]),
                [value_def, binding_func],
            )

        value_body, value_defs = self.visit(letexp.value)
        value_defs.append(binding_func)
        binding_call = self.create_call(func_name, [value_body])
        return (binding_call, value_defs)

    def visit_tuple(self, tup: Expr):
        """A tuple becomes a call to _container.tuple_object on a list of its fields."""
        fields, ret_defs = self.convert_fields(tup.fields)
        fields = [ast.List(fields, Load())]
        return (self.create_call("_container.tuple_object", fields), ret_defs)

    def visit_tuple_getitem(self, tgi: Expr):
        """A tuple projection becomes a Python subscript."""
        tup, tup_defs = self.visit(tgi.tuple_value)
        ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
        return (ret, tup_defs)

    def visit_if(self, if_block: Expr):
        """An if becomes a Python conditional expression on the condition's value."""
        cond_body, cond_defs = self.visit(if_block.cond)
        true_body, true_defs = self.visit(if_block.true_branch)
        false_body, false_defs = self.visit(if_block.false_branch)

        # need to get the value out of a NDArray to check the condition
        # equivalent to: val.asnumpy()
        cond_check = ast.Call(ast.Attribute(cond_body, "asnumpy", Load()), [], [])
        ret = ast.IfExp(cond_check, true_body, false_body)
        return (ret, cond_defs + true_defs + false_defs)

    def visit_constant(self, constant: Expr):
        """Proceeds by converting constant value to a numpy array
        and converting it to the appropriate value in the generated
        code (whether it be a Python scalar or a Numpy array)"""
        value = constant.data.asnumpy()
        const_expr = ast.Call(
            ast.Attribute(Name("numpy", Load()), "array", Load()),
            [self.parse_numpy_array(value)],
            [ast.keyword("dtype", Str(constant.checked_type.dtype))],
        )
        return (self.create_call("nd.array", [const_expr]), [])

    def visit_function(self, func: Expr):
        """A function literal becomes a named def plus a load of that name."""
        # Python's lambdas are very restrictive, so we do "name" inline functions
        converted_func, func_name = self.convert_func_node(func)
        return (Name(func_name, Load()), [converted_func])

    def visit_call(self, call: Expr):
        """For calls, we must distinguish between ordinary functions,
        operators, and constructor calls."""
        func = call.op
        fields, field_defs = self.convert_fields(call.args)

        if isinstance(func, tvm.ir.Op):
            raise Exception("Operators should have been lowered and eliminated")

        if isinstance(func, relay.Constructor):
            # produce a constructor value
            return (
                self.create_call(
                    "ConstructorValue",
                    [ast.Num(func.tag), ast.List(fields, Load()), NameConstant(None)],
                ),
                field_defs,
            )

        # lowered operator: generate a call to a function that gets the PackedFunc
        # from TVM's registry
        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
            op_call_def, op_call = self.create_op_call(func, call.args, fields)
            return (op_call, field_defs + [op_call_def])

        # ordinary function
        converted_func, defs = self.visit(func)
        defs += field_defs
        return (ast.Call(converted_func, fields, []), defs)

    def visit_ref_create(self, ref: Expr):
        """A ref creation becomes a RefValue wrapping the converted value."""
        val, defs = self.visit(ref.value)
        return (self.create_call("RefValue", [val]), defs)

    def visit_ref_read(self, read: Expr):
        """A ref read becomes a load of the RefValue's ``value`` attribute."""
        ref, defs = self.visit(read.ref)
        return (ast.Attribute(ref, "value", Load()), defs)

    def visit_ref_write(self, write: Expr):
        """For writing refs, we wrap the update in a thunk
        (returning an empty tuple to match Relay's semantics)
        that we execute at the right time. This ensures such assignments
        can be properly nested, since assignments are statements
        in Python but expressions in Relay"""
        ref, ref_defs = self.visit(write.ref)
        val, val_defs = self.visit(write.value)
        thunk_name = self.generate_function_name("_ref_write_thunk")
        thunk = self.create_def(
            thunk_name,
            [],
            ref_defs
            + val_defs
            + [
                Assign([ast.Attribute(ref, "value", Store())], val),
                Return(self.create_call("_container.tuple_object", [])),
            ],
        )
        return (self.create_call(thunk_name, []), [thunk])

    def visit_match(self, match: Expr):
        """For matches, we wrap the entire expression in a thunk
        because it is easiest to implement them using if statements.
        For each clause, we generate a function that checks if the
        pattern matches. If yes, we call a function that assigns
        the variables appropriately and invokes the clause body."""
        data, defs = self.visit(match.data)
        data_var = self.generate_var_name("_match_data")

        # must ensure the data clause is executed exactly once
        thunk_body = [Assign([Name(data_var, Store())], data)]
        for clause in match.clauses:
            check_expr = self.create_match_check(clause.lhs, Name(data_var, Load()))
            body_def, body_name = self.create_match_clause_body(clause.lhs, clause.rhs)
            defs.append(body_def)

            # equiv: if check(data): return body(data)
            thunk_body.append(
                ast.If(
                    check_expr, [Return(self.create_call(body_name, [Name(data_var, Load())]))], []
                )
            )

        # finally if nothing matches we have a failed assert (should never happen)
        thunk_body.append(ast.Assert(NameConstant(False), Str("Match was not exhaustive")))

        thunk_name = self.generate_function_name("_match_thunk")
        thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
        return (self.create_call(thunk_name, []), [thunk_def])

    # these are both handled in the "call" case; visiting a bare constructor or
    # op returns None (no (expr, defs) pair), so it cannot be converted standalone
    def visit_constructor(self, _):
        pass

    def visit_op(self, _):
        pass
573
574
def to_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
    """Translate a Relay expression into an equivalent Python program (an AST object).

    ``mod`` supplies global definitions; when omitted, an empty ``tvm.IRModule``
    is used. ``target`` is the target primitive operators are compiled for.
    For easiest debugging, import the astor package and use to_source().
    """
    if mod is None:
        mod = tvm.IRModule()
    return PythonConverter(mod, target).convert(expr)
581
582
def run_as_python(expr: Expr, mod=None, target=tvm.target.Target("llvm")):
    """Translate a Relay expression into Python, execute it, and return its result.

    The generated program runs in a fresh namespace, which serves as both its globals
    and its locals. The value it assigns to ``OUTPUT_VAR_NAME`` is returned.
    """
    if mod is None:
        mod = tvm.IRModule()
    compiled = compile(to_python(expr, mod, target), "<string>", "exec")
    namespace = {OUTPUT_VAR_NAME: None}
    # pylint: disable=exec-used
    exec(compiled, namespace, namespace)
    return namespace[OUTPUT_VAR_NAME]
593