1from numba.core import errors, ir
2from numba.core.rewrites import register_rewrite, Rewrite
3
4
class Macro(object):
    '''
    Descriptor for a macro whose expansion is a function call.

    Args
    ----
    name: str
        Identifier of the macro.
    func: function or str
        When ``callable`` is True, a Python callable that is evaluated at
        expansion time and returns the IR node to be called.
        When ``callable`` is False, the name of the backend-specific
        function to call at runtime.
    callable: bool
        Whether the macro can be invoked from Python code (defaults to
        False).
    argnames: list
        For a callable macro, the names of the function's arguments
        (defaults to None).
    '''

    __slots__ = ('name', 'func', 'callable', 'argnames')

    def __init__(self, name, func, callable=False, argnames=None):
        self.name, self.func = name, func
        self.callable, self.argnames = callable, argnames

    def __repr__(self):
        # ``!s`` mirrors the ``%s`` conversion: both fields are rendered via str().
        return '<macro {0!s} -> {1!s}>'.format(self.name, self.func)
36
37
@register_rewrite('before-inference')
class ExpandMacros(Rewrite):
    """
    Expand lookups and calls of Macro objects.

    Two IR shapes are rewritten:

    * a ``call`` expression whose callee infers to a callable ``Macro``:
      ``macro.func`` is evaluated at compile time with the constant values
      of the call's arguments, and the call is redirected to the IR node it
      returns;
    * a ``getattr`` expression whose result infers to a non-callable
      ``Macro``: the lookup is replaced by a zero-argument call to an
      ``ir.Intrinsic`` built from the macro's name and ``func``.
    """

    def match(self, func_ir, block, typemap, calltypes):
        """
        Look for potential macros to expand and store their expansions.

        Scans the ``ir.Assign`` statements of *block* and records in
        ``self.rewrites`` (keyed by the original expression) the expression
        that should replace it.  *typemap* and *calltypes* are not used.

        Returns True if at least one expansion was recorded.
        """
        self.block = block
        self.rewrites = rewrites = {}

        for inst in block.body:
            if isinstance(inst, ir.Assign):
                rhs = inst.value
                if (isinstance(rhs, ir.Expr) and rhs.op == 'call'
                    and isinstance(rhs.func, ir.Var)):
                    # Is it a callable macro?
                    try:
                        const = func_ir.infer_constant(rhs.func)
                    except errors.ConstantInferenceError:
                        # Callee is not a compile-time constant: not a macro.
                        continue
                    if isinstance(const, Macro):
                        # Only callable macros are expected to be called.
                        assert const.callable
                        new_expr = self._expand_callable_macro(func_ir, rhs,
                                                               const, rhs.loc)
                        rewrites[rhs] = new_expr

                elif isinstance(rhs, ir.Expr) and rhs.op == 'getattr':
                    # Is it a non-callable macro looked up as a constant
                    # attribute?
                    try:
                        const = func_ir.infer_constant(inst.target)
                    except errors.ConstantInferenceError:
                        continue
                    if isinstance(const, Macro) and not const.callable:
                        new_expr = self._expand_non_callable_macro(const,
                                                                   rhs.loc)
                        rewrites[rhs] = new_expr

        return len(rewrites) > 0

    def _expand_non_callable_macro(self, macro, loc):
        """
        Return the IR expression of expanding the non-callable macro.

        The result is a zero-argument ``call`` of an ``ir.Intrinsic`` named
        ``macro.name`` whose target is ``macro.func``.
        """
        intr = ir.Intrinsic(macro.name, macro.func, args=())
        new_expr = ir.Expr.call(func=intr, args=(),
                                kws=(), loc=loc)
        return new_expr

    @staticmethod
    def _infer_macro_argument(func_ir, var, name, loc):
        """
        Return the compile-time constant value of macro argument *var*.

        *name* identifies the argument in the error message: a keyword name,
        a declared argument name, or a positional index.

        Raises ValueError if *var* cannot be inferred as a constant.
        """
        try:
            return func_ir.infer_constant(var)
        except errors.ConstantInferenceError as e:
            msg = "Argument {name!r} must be a " \
                  "constant at {loc}".format(name=name,
                                             loc=loc)
            raise ValueError(msg) from e

    def _expand_callable_macro(self, func_ir, call, macro, loc):
        """
        Return the IR expression of expanding the macro call.

        Every argument of *call* must infer to a compile-time constant.
        ``macro.func`` is invoked with those constant values and must
        return an IR node.  That node becomes the callee of a new ``call``
        expression that keeps the original argument variables.

        Raises
        ------
        ValueError
            If a positional or keyword argument is not a constant.
        errors.MacroError
            If ``macro.func`` raises, or returns None.
        """
        assert macro.callable

        # Resolve all macro arguments as constants, or fail.  Positional and
        # keyword arguments report the same error, so a non-constant
        # positional argument does not surface a raw ConstantInferenceError.
        argnames = macro.argnames or ()
        args = []
        for i, arg in enumerate(call.args):
            name = argnames[i] if i < len(argnames) else i
            args.append(self._infer_macro_argument(func_ir, arg, name, loc))
        kws = {}
        for k, v in call.kws:
            kws[k] = self._infer_macro_argument(func_ir, v, k, loc)

        head = "Macro expansion failed at {line}".format(line=loc)
        try:
            result = macro.func(*args, **kws)
        except Exception as e:
            newmsg = "{0}:\n{1}".format(head, str(e))
            raise errors.MacroError(newmsg) from e

        # Explicit check rather than ``assert``: asserts vanish under -O, and
        # a None result would then fail obscurely on ``result.loc`` below.
        if result is None:
            raise errors.MacroError(
                "{0}:\nmacro {1!r} returned None instead of an IR "
                "node".format(head, macro.name))

        result.loc = call.loc
        new_expr = ir.Expr.call(func=result, args=call.args,
                                kws=call.kws, loc=loc)
        return new_expr

    def apply(self):
        """
        Apply the expansions computed in .match().

        Replaces, in place, the value of every assignment in ``self.block``
        whose value was recorded in ``self.rewrites``, then returns the
        block.
        """
        block = self.block
        rewrites = self.rewrites
        for inst in block.body:
            if isinstance(inst, ir.Assign) and inst.value in rewrites:
                inst.value = rewrites[inst.value]
        return block
132