1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Utility for converting Relay code into a Python script with equivalent semantics"""
18import ast
19from ast import alias, Assign, Load, Name, NameConstant, Num, Return, Store, Str
20import re
21
22import tvm
23from tvm import relay
24from tvm.relay.adt import Pattern
25from tvm.relay.backend import compile_engine
26from tvm.relay.expr import Expr, Function, GlobalVar, Var
27from tvm.relay.expr_functor import ExprFunctor
28
# Name of the variable in the generated script that receives the value of the
# converted program, so it can be read back after the script has been executed.
OUTPUT_VAR_NAME = '_py_out'

# Import statements placed at the top of every generated script. Equivalent to:
#     import numpy
#     import tvm
#     from tvm import relay
#     from tvm.relay.backend.interpreter import RefValue, TupleValue, TensorValue, ConstructorValue
PROLOGUE = [
    ast.Import(names=[alias(name='numpy', asname=None)]),
    ast.Import(names=[alias(name='tvm', asname=None)]),
    ast.ImportFrom(module='tvm', names=[alias(name='relay', asname=None)], level=0),
    ast.ImportFrom(
        module='tvm.relay.backend.interpreter',
        names=[alias(name=value_class, asname=None)
               for value_class in ('RefValue', 'TupleValue', 'TensorValue', 'ConstructorValue')],
        level=0),
]
47
48class PythonConverter(ExprFunctor):
49    """Functor for translating Relay programs into Python ASTs."""
50
51    def __init__(self, mod, target) -> None:
52        super().__init__()
53        self.mod = mod
54        self.tgt = target
55        self.engine = compile_engine.get()
56        self.fun_no = 0
57        self.var_no = 0
58        self.var_map = {}
59
60
61    def convert(self, prog: Expr):
62        """This method converts the passed Relay expression into a Python
63        AST object with equivalent semantics.
64
65        The Python AST can be executed using exec(); it can be turned
66        into text and inspected using astor.
67        """
68        optimized = self.optimize(prog)
69
70        # start with conversion prelude (imports) and convert global defs
71        body = []
72        body += PROLOGUE
73        body += self.convert_module()
74
75        prog_body, extra_defs = self.visit(optimized)
76        body += extra_defs
77
78        # we finally must assign the final expression to the output var
79        # so it can be read after running EXEC
80        body.append(Assign([Name(OUTPUT_VAR_NAME, Store())], prog_body))
81
82        return ast.fix_missing_locations(ast.Module(body=body))
83
84
85    def optimize(self, prog: Expr):
86        """Performs optimizations necessary to be able to generate code for prog."""
87        # unwrap tuple wrappers (some op calls produce them)
88        unwrapped = prog.astuple() if isinstance(prog, relay.TupleWrapper) else prog
89        assert relay.analysis.well_formed(unwrapped)
90        mod = self.mod.from_expr(unwrapped, self.mod.functions, self.mod.type_definitions)
91
92        # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
93        # and fusion (to get primitive functions)
94        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
95                                           relay.transform.FuseOps(fuse_opt_level=0)])
96        mod = opts(mod)
97        optimized = mod['main']
98        return optimized if isinstance(unwrapped, Function) else optimized.body
99
100
101    def sanitize(self, name: str) -> str:
102        """Removes any invalid characters (only underscores, numbers, and letters permitted)
103        from the given name. Since we append a number and underscore to var names anyway,
104        it doesn't matter if the name is the empty string."""
105        return re.sub(r'\W', '', name)
106
107
108    def generate_var_name(self, name_hint: str) -> str:
109        """Generates a unique variable name starting from the hint."""
110        name = '{}_var_{}'.format(self.sanitize(name_hint), self.var_no)
111        self.var_no += 1
112        return name
113
114
115    def generate_function_name(self, name_hint: str) -> str:
116        """Generates a unique function name starting from the hint."""
117        name = '{}_fun_{}'.format(self.sanitize(name_hint), self.fun_no)
118        self.fun_no += 1
119        return name
120
121
122    def get_var_name(self, var: Expr) -> str:
123        """Returns the var name for the given Realy variable."""
124        if var in self.var_map:
125            return self.var_map[var]
126        name = self.generate_var_name(var.name_hint)
127        self.var_map[var] = name
128        return name
129
130
131    def include_var(self, var: Expr, assign=False):
132        """Returns a variable AST node for the given Relay var depending on
133        whether it must appear in an assignment or not."""
134        name = self.get_var_name(var)
135        return Name(name, Store() if assign else Load())
136
137
138    def parse_name(self, name: str):
139        """Given the name of a Python method with dots (e.g., 'relay.var'),
140        returns an appropriate AST object corresponding to that name."""
141        attributes = name.split('.')
142        ret = Name(attributes[0], Load())
143        for i in range(len(attributes) - 1):
144            ret = ast.Attribute(ret, attributes[i+1], Load())
145        return ret
146
147
148    def parse_numpy_array(self, arr):
149        """Given a Numpy array, produces an appropriate Python array
150        or numerical literal representing its contents."""
151        parse_single = lambda i: NameConstant(i) if isinstance(i, bool) else Num(i)
152        if arr.ndim == 0:
153            return parse_single(arr.item())
154        if arr.ndim == 1:
155            return ast.List([parse_single(i.item()) for i in arr], Load())
156
157        elts = []
158        for row in arr:
159            elts.append(self.parse_numpy_array(row))
160        return ast.List(elts, Load())
161
162
163    def convert_fields(self, fields: [Expr]):
164        """Given a list of call args or tuple fields, converts
165        each and returns their ASTs and their defs lists (in order)."""
166        bodies = []
167        defs = []
168        for field in fields:
169            member_body, member_defs = self.visit(field)
170            bodies.append(member_body)
171            defs += member_defs
172        return (bodies, defs)
173
174
175    def convert_to_thunk(self, name_hint: str, expr: Expr):
176        """Wraps the passed expression in a thunk."""
177        body, defs = self.visit(expr)
178        thunk_name = self.generate_function_name(name_hint)
179        thunk = self.create_def(thunk_name, [], defs + [Return(body)])
180        return (thunk, thunk_name)
181
182
183    def convert_func_node(self, func: Function, name_var=None):
184        """Converts the given Relay function into a Python function, with
185        special for named functions (locally or globally)"""
186        if name_var is None:
187            func_name = self.generate_function_name('_anon_func')
188        if isinstance(name_var, GlobalVar):
189            func_name = name_var.name_hint
190        if isinstance(name_var, Var):
191            func_name = self.get_var_name(name_var)
192
193        var_names = [self.get_var_name(var) for var in func.params]
194        body, defs = self.visit(func.body)
195        ret = self.create_def(func_name, var_names, defs + [Return(body)])
196        return (ret, func_name)
197
198
199    def convert_module(self):
200        """Converts all the global functions defined in the module and returns
201        them as a list of definitions"""
202        defs = []
203        for var, func in self.mod.functions.items():
204            # optimize the definition so any operators used are lowered
205            opt_func = self.optimize(func)
206            try:
207                converted_func, _ = self.convert_func_node(opt_func, var)
208                defs.append(converted_func)
209            except TypeError:
210                # TODO(wweic): fix conversion for Any
211                pass
212        return defs
213
214
215    def create_call(self, func_name: str, arguments):
216        """Creates a simple function call."""
217        return ast.Call(self.parse_name(func_name), arguments, [])
218
219
220    def create_def(self, func_name: str, arguments: [str], body):
221        """Wrapper over function definition AST node, whose constructor is inconvenient."""
222        return ast.FunctionDef(
223            func_name,
224            ast.arguments([ast.arg(argument, None)
225                           for argument in arguments],
226                          None, [], [], None, []),
227            body, [], None)
228
229
230    def create_op_call(self, op: Function, relay_args, py_args):
231        """Lowers the passed primitive function, registers it in TVM's
232        global compiler, and produces a call to the lowered function in
233        the generated Python code."""
234
235        # compile the function and register globally
236        cc_key = compile_engine.CCacheKey(op, self.tgt)
237        func_hash = relay.analysis.structural_hash(op)
238        op_name = '_lowered_op_{}'.format(func_hash)
239        if not tvm.get_global_func(op_name, allow_missing=True):
240            jitted = self.engine.jit(cc_key, self.tgt)
241            tvm.register_func(op_name, jitted)
242
243        def convert_input(py_input, arg_type):
244            """Use the types of the function arguments to determine whether we expect
245               a tensor or tuple (returns list of inputs to the lowered op call)"""
246            # equivalent: input.data
247            if isinstance(arg_type, relay.TensorType):
248                return [ast.Attribute(py_input, 'data', Load())]
249            assert isinstance(arg_type, relay.TupleType)
250            # convert each input.fields[i]
251            ret = []
252            for i in range(len(arg_type.fields)):
253                ret += convert_input(
254                    ast.Subscript(
255                        ast.Attribute(py_input, 'fields', Load()),
256                        ast.Index(Num(i)), Load()),
257                    arg_type.fields[i])
258            return ret
259
260        def convert_output(ret_type):
261            """Use the function return type to produce auxiliary variables to store outputs.
262            Returns ([assignments of output vars], [extra arguments to pass to op call],
263            expression collecting output)"""
264            if isinstance(ret_type, relay.TensorType):
265                output_var_name = self.generate_var_name('_out')
266                output_var = Name(output_var_name, Load())
267                shape = ast.Tuple([Num(dim) for dim in ret_type.concrete_shape], Load())
268                # create a new TensorValue of the right shape and dtype
269                assign_output = Assign(
270                    [Name(output_var_name, Store())],
271                    self.create_call('TensorValue', [
272                        self.create_call('numpy.empty', [shape, Str(ret_type.dtype)])
273                    ]))
274                # we pass the data field as an argument
275                extra_arg = ast.Attribute(output_var, 'data', Load())
276                return ([assign_output], [extra_arg], output_var)
277            assert isinstance(ret_type, relay.TupleType)
278            assignments = []
279            extra_args = []
280            fields = []
281            for t in ret_type.fields:
282                inner_assignments, inner_args, inner_output = convert_output(t)
283                assignments += inner_assignments
284                extra_args += inner_args
285                fields.append(inner_output)
286            return (assignments, extra_args, self.create_call('TupleValue', fields))
287
288        # create a function to wrap the call of the lowered op and return
289        # a call to that function
290        wrap_name = self.generate_function_name('_{}_wrapper'.format(op_name))
291        wrap_args = [self.generate_var_name('_arg_{}'.format(i)) for i in range(len(py_args))]
292
293        inner_call_args = []
294        for i in range(len(py_args)):
295            inner_call_args += convert_input(Name(wrap_args[i], Load()),
296                                             relay_args[i].checked_type)
297        output_assignments, aux_args, output = convert_output(op.checked_type.ret_type)
298        # equiv: _op = tvm.get_global_func(op_name)
299        op_var = self.generate_var_name('_op')
300        op_call = self.create_call('tvm.get_global_func', [Str(op_name)])
301        op_assign = Assign([Name(op_var, Store())], op_call)
302        # equiv: _op(args)
303        inner_call = self.create_call(op_var, inner_call_args + aux_args)
304        body = output_assignments + [op_assign, ast.Expr(inner_call), Return(output)]
305        wrap_def = self.create_def(wrap_name, wrap_args, body)
306        return wrap_def, self.create_call(wrap_name, py_args)
307
308
309    def create_match_check(self, pattern: Pattern, data):
310        """Given an ADT match pattern and a (Python) expression pointing to
311        an ADT value, this generates a Python expression that checks if the
312        ADT value matches the given pattern (returning True or False)."""
313
314        # wildcard or var match everything
315        if isinstance(pattern, (relay.PatternWildcard, relay.PatternVar)):
316            return NameConstant(True)
317
318        conds = []
319
320        if isinstance(pattern, relay.PatternConstructor):
321            # constructor patterns check whether the constructors match
322            # and also the matches of any nested patterns
323
324            # equiv: (arg.tag == patern_constructor.tag)
325            conds.append(ast.Compare(ast.Attribute(data, 'tag', Load()),
326                                     [ast.Eq()],
327                                     [ast.Num(pattern.constructor.tag)]))
328
329        assert isinstance(pattern, (relay.PatternConstructor, relay.PatternTuple))
330        # now check for any nested patterns
331        for i in range(len(pattern.patterns)):
332            nested_pat = pattern.patterns[i]
333            # can safely skip var or wildcard patterns: they will
334            # never cause a check to fail
335            if not isinstance(nested_pat, relay.PatternConstructor):
336                continue
337
338            # index into the value corresponding to the subpattern
339            field_index = ast.Subscript(ast.Attribute(data, 'fields', Load()),
340                                        ast.Index(Num(i)), Load())
341            conds.append(self.create_match_check(nested_pat, field_index))
342
343        # if we do not need to check nested pattern, just return the single check
344        if len(conds) == 1:
345            return conds[0]
346        # otherwise AND together any nested checks
347        return ast.BoolOp(ast.And(), conds)
348
349
350    def create_match_clause_body(self, pattern: Pattern, body: Expr):
351        """Given a match clause pattern and a clause body,
352        generates a Python function that when called with an ADT
353        that matches the pattern, returns the result of evaluating
354        the clause body. This function returns a function definition
355        and the name of the generated function."""
356
357        def collect_var_assignments(pat, val):
358            """This helper function ensures that the pattern is used to
359            properly assign all subfields of the given AST for use
360            in the clause body
361
362            E.g., for PatternConstructor(A, PatternVar(v), PatternWildcard(),
363            PatternConstructor(B, PatternVar(w)))
364            we would want to have
365            v = a.fields[0]
366            w = a.fields[2].fields[0]
367            """
368            if isinstance(pat, relay.PatternWildcard):
369                return []
370            if isinstance(pat, relay.PatternVar):
371                return [Assign([self.include_var(pat.var, assign=True)], val)]
372            # constructor pattern: assign each field of the value
373            # based on subpatterns
374            assignments = []
375            for i in range(len(pat.patterns)):
376                # we want the assignments for val.fields[i]
377                field = ast.Subscript(ast.Attribute(val, 'fields', Load()),
378                                      ast.Index(Num(i)), Load())
379                assignments += collect_var_assignments(pat.patterns[i], field)
380            return assignments
381
382        func_name = self.generate_function_name('_match_clause_body')
383        arg_name = self.generate_var_name('_match_clause_body')
384
385        clause_body, defs = self.visit(body)
386        assignments = collect_var_assignments(pattern, Name(arg_name, Load()))
387
388        func_def = self.create_def(func_name, [arg_name],
389                                   defs + assignments + [Return(clause_body)])
390        return (func_def, func_name)
391
392
393    # Convention for the expr visitor: Each visit function returns a tuple of two members.
394    #
395    # The first is a Python AST comprised of a single *expression* that evaluates to an equivalent
396    # result to the desired Relay expression (and executes all effects in the right order).
397    #
398    # The second is a list of function definition *statements* defining thunks and other
399    # auxiliary functions needed in the translated AST object. The defs in the second object
400    # will always have unique names and will never perform any effects, so as long as they
401    # appear in the Python program before the first statement is executed, there should not
402    # be any problems.
403
404    def visit_var(self, var: Expr):
405        return (self.include_var(var, assign=False), [])
406
407
408    def visit_global_var(self, gvar: Expr):
409        # we don't need to add numbers to global var names because
410        # the *names* are checked for uniqueness in the mod
411        return (Name(gvar.name_hint, Load()), [])
412
413
414    def visit_let(self, letexp: Expr):
415        # To properly account for scoping and ensure that the entire node produces an expression,
416        # we translate the let binding as a function that we call with the value we intend to bind.
417        # Yes, this is somewhat ugly.
418        """
419        let var = value in body
420        =======================
421        def let_thunk(var):
422            return body
423        let_thunk(value)
424        """
425        bind_body, bind_defs = self.visit(letexp.body)
426
427        func_name = self.generate_function_name('_let_func')
428        binding_func = self.create_def(func_name, [self.get_var_name(letexp.var)],
429                                       bind_defs + [Return(bind_body)])
430
431        # we call the binding func with the intended value for the bound variable
432
433        # special case: if the value is a function literal, we must ensure it can be
434        # recursive by naming it after the var
435        if isinstance(letexp.value, Function):
436            value_def, value_name = self.convert_func_node(letexp.value, letexp.var)
437            return (self.create_call(func_name, [Name(value_name, Load())]),
438                    [value_def, binding_func])
439
440        value_body, value_defs = self.visit(letexp.value)
441        value_defs.append(binding_func)
442        binding_call = self.create_call(func_name, [value_body])
443        return (binding_call, value_defs)
444
445
446    def visit_tuple(self, tup: Expr):
447        fields, ret_defs = self.convert_fields(tup.fields)
448        return (self.create_call('TupleValue', fields), ret_defs)
449
450
451    def visit_tuple_getitem(self, tgi: Expr):
452        tup, tup_defs = self.visit(tgi.tuple_value)
453        ret = ast.Subscript(tup, ast.Index(Num(tgi.index)), Load())
454        return (ret, tup_defs)
455
456
457    def visit_if(self, if_block: Expr):
458        cond_body, cond_defs = self.visit(if_block.cond)
459        true_body, true_defs = self.visit(if_block.true_branch)
460        false_body, false_defs = self.visit(if_block.false_branch)
461
462        # need to get the value out of a TensorValue to check the condition
463        # equvialent to: val.asnumpy()
464        cond_check = ast.Call(ast.Attribute(cond_body, 'asnumpy', Load()), [], [])
465        ret = ast.IfExp(cond_check, true_body, false_body)
466        return (ret, cond_defs + true_defs + false_defs)
467
468
469    def visit_constant(self, constant: Expr):
470        """Proceeds by converting constant value to a numpy array
471        and converting it to the appropriate value in the generated
472        code (whether it be a Python scalar or a Numpy array)"""
473        value = constant.data.asnumpy()
474        const_expr = ast.Call(ast.Attribute(Name('numpy', Load()), 'array', Load()),
475                              [self.parse_numpy_array(value)],
476                              [ast.keyword('dtype', Str(constant.checked_type.dtype))])
477        return (self.create_call('TensorValue', [const_expr]), [])
478
479
480    def visit_function(self, func: Expr):
481        # Python's lambdas are very restrictive, so we do "name" inline functions
482        converted_func, func_name = self.convert_func_node(func)
483        return (Name(func_name, Load()), [converted_func])
484
485
486    def visit_call(self, call: Expr):
487        """For calls, we must distinguish between ordinary functions,
488        operators, and constructor calls."""
489        func = call.op
490        fields, field_defs = self.convert_fields(call.args)
491
492        if isinstance(func, relay.Op):
493            raise Exception('Operators should have been lowered and eliminated')
494
495        if isinstance(func, relay.Constructor):
496            # produce a constructor value
497            return (self.create_call('ConstructorValue',
498                                     [ast.Num(func.tag),
499                                      ast.List(fields, Load()),
500                                      NameConstant(None)]),
501                    field_defs)
502
503        # lowered operator: generate a call to a function that gets the PackedFunc
504        # from TVM's registry
505        if isinstance(func, Function) and func.attrs and func.attrs.Primitive.value == 1:
506            op_call_def, op_call = self.create_op_call(func, call.args, fields)
507            return (op_call, field_defs + [op_call_def])
508
509        # ordinary function
510        converted_func, defs = self.visit(func)
511        defs += field_defs
512        return (ast.Call(converted_func, fields, []), defs)
513
514
515    def visit_ref_create(self, ref: Expr):
516        val, defs = self.visit(ref.value)
517        return (self.create_call('RefValue', [val]), defs)
518
519
520    def visit_ref_read(self, read: Expr):
521        ref, defs = self.visit(read.ref)
522        return (ast.Attribute(ref, 'value', Load()), defs)
523
524
525    def visit_ref_write(self, write: Expr):
526        """For writing refs, we wrap the update in a thunk
527        (returning an empty tuple to match Relay's semantics)
528        that we execute at the right time. This ensures such assignments
529        can be properly nested, since assignments are statements
530        in Python but expressions in Relay"""
531        ref, ref_defs = self.visit(write.ref)
532        val, val_defs = self.visit(write.value)
533        thunk_name = self.generate_function_name('_ref_write_thunk')
534        thunk = self.create_def(
535            thunk_name, [],
536            ref_defs + val_defs + [
537                Assign([ast.Attribute(ref, 'value', Store())], val),
538                Return(self.create_call('TupleValue', []))
539            ])
540        return (self.create_call(thunk_name, []), [thunk])
541
542
543    def visit_match(self, match: Expr):
544        """For matches, we wrap the entire expression in a thunk
545        because it is easiest to implement them using if statements.
546        For each clause, we generate a function that checks if the
547        pattern matches. If yes, we call a function that assigns
548        the variables appropriately and invokes the clause body."""
549        data, defs = self.visit(match.data)
550        data_var = self.generate_var_name('_match_data')
551
552        # must ensure the data clause is executed exactly once
553        thunk_body = [Assign([Name(data_var, Store())], data)]
554        for clause in match.clauses:
555            check_expr = self.create_match_check(clause.lhs, Name(data_var, Load()))
556            body_def, body_name = self.create_match_clause_body(clause.lhs, clause.rhs)
557            defs.append(body_def)
558
559            # equiv: if check(data): return body(data)
560            thunk_body.append(ast.If(
561                check_expr,
562                [Return(self.create_call(body_name, [Name(data_var, Load())]))],
563                []
564            ))
565
566        # finally if nothing matches we have a failed assert (should never happen)
567        thunk_body.append(ast.Assert(NameConstant(False), Str('Match was not exhaustive')))
568
569        thunk_name = self.generate_function_name('_match_thunk')
570        thunk_def = self.create_def(thunk_name, [], defs + thunk_body)
571        return (self.create_call(thunk_name, []), [thunk_def])
572
573
574    # these are both handled in the "call" case
575    def visit_constructor(self, _):
576        pass
577    def visit_op(self, _):
578        pass
579
580
def to_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
    """Convert the given Relay expression into a Python script, returned as
    a Python AST object. A fresh, empty module is used when `mod` is None.

    For easiest debugging, import the astor package and use to_source()."""
    if mod is None:
        mod = relay.Module()
    return PythonConverter(mod, target).convert(expr)
587
588
def run_as_python(expr: Expr, mod=None, target=tvm.target.create('llvm')):
    """Convert the given Relay expression into a Python script, execute it,
    and return the value the script assigned to its output variable."""
    if mod is None:
        mod = relay.Module()
    program = compile(to_python(expr, mod, target), '<string>', 'exec')
    # the same dict serves as globals and locals of the script; the output
    # variable is pre-populated so that it can always be looked up afterwards
    namespace = {OUTPUT_VAR_NAME: None}
    #pylint: disable=exec-used
    exec(program, namespace, namespace)
    return namespace[OUTPUT_VAR_NAME]
601