1"""
2From NumbaPro
3
4"""
5
6from collections import namedtuple, OrderedDict
7import dis
8import inspect
9import itertools
10from types import CodeType, ModuleType
11
12from numba.core import errors, utils, serialize
13
14
opcode_info = namedtuple('opcode_info', 'argsize')

# Shift applied to every real instruction offset. A synthetic NOP of this
# width is injected at the start of the bytecode so that a function whose
# body begins with `while True` does not turn block 0 into a jump target;
# the Lowerer places argument initialization in block 0.
_FIXED_OFFSET = 2
21
22
def get_function_object(obj):
    """
    Return the Python function wrapped by *obj*, or *obj* itself.

    Objects that wrap a function advertise it through a "__numba__" magic
    attribute whose value is the *name* of another attribute holding the
    actual Python function object. If "__numba__" is absent or falsy,
    *obj* is returned unchanged.
    """
    target_attr = getattr(obj, "__numba__", None)
    return getattr(obj, target_attr) if target_attr else obj
33
34
def get_code_object(obj):
    """
    Return the code object of *obj*: ``__code__`` if present, else the
    legacy ``func_code`` attribute, else None. (Borrowed from llpython.)
    """
    legacy = getattr(obj, 'func_code', None)
    return getattr(obj, '__code__', legacy)
38
39
def _as_opcodes(seq):
    """
    Translate opcode names into their numeric values.

    Names that the running interpreter's ``dis.opmap`` does not know are
    skipped, so the result may be shorter than *seq*.
    """
    looked_up = (dis.opmap.get(name) for name in seq)
    return [code for code in looked_up if code is not None]
47
48
# Opcode classification sets.
JREL_OPS = frozenset(dis.hasjrel)   # jumps whose target is relative
JABS_OPS = frozenset(dis.hasjabs)   # jumps whose target is absolute
JUMP_OPS = JREL_OPS.union(JABS_OPS)
# Terminator opcodes; names unknown to this interpreter are dropped.
TERM_OPS = frozenset(_as_opcodes(('RETURN_VALUE', 'RAISE_VARARGS')))
EXTENDED_ARG, HAVE_ARGUMENT = dis.EXTENDED_ARG, dis.HAVE_ARGUMENT
55
56
class ByteCodeInst(object):
    '''
    A single decoded bytecode instruction.

    Attributes
    ----------
    - offset:
        byte offset of opcode
    - next:
        byte offset of the following instruction
    - opcode:
        opcode integer value
    - opname:
        opcode name, looked up in ``dis.opname``
    - arg:
        instruction arg (None when the opcode takes no argument)
    - lineno:
        source line number; -1 means unknown
    '''
    __slots__ = 'offset', 'next', 'opcode', 'opname', 'arg', 'lineno'

    def __init__(self, offset, opcode, arg, nextoffset):
        self.offset = offset
        self.next = nextoffset
        self.opcode = opcode
        self.opname = dis.opname[opcode]
        self.arg = arg
        self.lineno = -1  # unknown line number

    @property
    def is_jump(self):
        """True if the opcode is a relative or absolute jump."""
        return self.opcode in JUMP_OPS

    @property
    def is_terminator(self):
        """True if the opcode is in TERM_OPS (e.g. RETURN_VALUE)."""
        return self.opcode in TERM_OPS

    def get_jump_target(self):
        """
        Return the byte offset this jump instruction transfers control to.

        Relative jumps are measured from the next instruction; absolute
        jumps use ``arg`` directly. Only valid when ``is_jump`` is true.
        """
        assert self.is_jump
        if self.opcode in JREL_OPS:
            return self.next + self.arg
        else:
            assert self.opcode in JABS_OPS
            return self.arg

    def __repr__(self):
        return '%s(arg=%s, lineno=%d)' % (self.opname, self.arg, self.lineno)

    @property
    def block_effect(self):
        """Effect of the block stack
        Returns +1 (push), 0 (none) or -1 (pop)
        """
        # SETUP_ANNOTATIONS shares the SETUP_ prefix but merely ensures an
        # ``__annotations__`` dict exists; it never pushes a block.
        if self.opname == 'SETUP_ANNOTATIONS':
            return 0
        if self.opname.startswith('SETUP_'):
            return 1
        if self.opname == 'POP_BLOCK':
            return -1
        return 0
110
111
# Byte widths used when walking the bytecode: the opcode byte, the argument
# byte, and the (ignored) byte that follows an opcode without an argument.
CODE_LEN = ARG_LEN = NO_ARG_LEN = 1

OPCODE_NOP = dis.opmap['NOP']
117
118
# Adapted from Lib/dis.py
def _unpack_opargs(code):
    """
    Decode raw bytecode into instructions.

    Yields ``(offset, opcode, arg, next_offset)`` tuples. ``arg`` is None
    for opcodes below HAVE_ARGUMENT. EXTENDED_ARG prefixes are folded into
    the argument of the instruction they extend and are not yielded on
    their own; that instruction reports the offset of its first prefix.
    """
    pending_ext = 0
    start = pos = 0
    size = len(code)
    while pos < size:
        opcode = code[pos]
        pos += CODE_LEN
        if opcode < HAVE_ARGUMENT:
            oparg = None
            pos += NO_ARG_LEN
        else:
            oparg = code[pos] | pending_ext
            for k in range(ARG_LEN):
                oparg |= code[pos + k] << (8 * k)
            pos += ARG_LEN
            if opcode == EXTENDED_ARG:
                # Carry the high bits forward; ``start`` stays at this prefix.
                pending_ext = oparg << (8 * ARG_LEN)
                continue
        pending_ext = 0
        yield (start, opcode, oparg, pos)
        start = pos
146
147
def _patched_opargs(bc_stream):
    """
    Prepend a NOP to a decoded bytecode stream and shift everything after it.

    The synthetic NOP occupies offsets ``[0, _FIXED_OFFSET)`` so that the
    entry block can never be a jump target. Every following instruction's
    offset and next-offset, and every absolute jump target, is moved by
    ``_FIXED_OFFSET`` to stay consistent.
    """
    yield (0, OPCODE_NOP, None, _FIXED_OFFSET)
    for offset, opcode, arg, nextoffset in bc_stream:
        shifted_arg = arg + _FIXED_OFFSET if opcode in JABS_OPS else arg
        yield (offset + _FIXED_OFFSET, opcode, shifted_arg,
               nextoffset + _FIXED_OFFSET)
161
162
class ByteCodeIter(object):
    """
    Iterator over the instructions of a code object.

    Each step yields ``(offset, ByteCodeInst)``. The underlying stream comes
    from ``_patched_opargs(_unpack_opargs(...))``, so offsets already include
    the shift for the injected leading NOP.
    """
    def __init__(self, code):
        self.code = code
        self.iter = iter(_patched_opargs(_unpack_opargs(self.code.co_code)))

    def __iter__(self):
        return self

    def _fetch_opcode(self):
        """Return the next ``(offset, opcode, arg, nextoffset)`` tuple."""
        return next(self.iter)

    def next(self):
        offset, opcode, arg, nextoffset = self._fetch_opcode()
        return offset, ByteCodeInst(offset=offset, opcode=opcode, arg=arg,
                                    nextoffset=nextoffset)

    __next__ = next

    def read_arg(self, size):
        """
        Legacy helper that used to read ``size`` raw argument bytes.

        The stream now yields fully decoded ``(offset, opcode, arg,
        nextoffset)`` tuples, so raw argument bytes are not available; use
        the ``arg`` attribute of each ByteCodeInst instead. Previously this
        method consumed one instruction and then failed while unpacking it.
        Now it refuses up front and leaves the stream untouched.

        Returns 0 when ``size`` is not positive (nothing to read).

        Raises
        ------
        NotImplementedError
            For any positive ``size``.
        """
        if size <= 0:
            return 0
        raise NotImplementedError(
            "ByteCodeIter.read_arg is unsupported: the instruction stream "
            "yields decoded arguments; use ByteCodeInst.arg instead")
187
188
class ByteCode(object):
    """
    The decoded bytecode of a function, and related information.

    ``table`` maps each instruction's byte offset (shifted by
    ``_FIXED_OFFSET`` for the injected NOP) to its ByteCodeInst, in bytecode
    order. ``labels`` is the sorted list of jump-target offsets reported by
    ``dis.findlabels`` (shifted the same way), plus offset 0.
    """
    __slots__ = ('func_id', 'co_names', 'co_varnames', 'co_consts',
                 'co_cellvars', 'co_freevars', 'table', 'labels')

    def __init__(self, func_id):
        code = func_id.code

        # dis reports targets in unshifted offsets; move them past the
        # injected NOP so they match the keys of ``table``.
        labels = set(x + _FIXED_OFFSET for x in dis.findlabels(code.co_code))
        labels.add(0)

        # A map of {offset: ByteCodeInst}
        table = OrderedDict(ByteCodeIter(code))
        self._compute_lineno(table, code)

        self.func_id = func_id
        self.co_names = code.co_names
        self.co_varnames = code.co_varnames
        self.co_consts = code.co_consts
        self.co_cellvars = code.co_cellvars
        self.co_freevars = code.co_freevars
        self.table = table
        self.labels = sorted(labels)

    @classmethod
    def _compute_lineno(cls, table, code):
        """
        Compute the line numbers for all bytecode instructions.

        Instructions that start a line according to ``dis.findlinestarts``
        get that line number. Every other instruction inherits the line of
        the nearest preceding instruction that has one. Instructions before
        the first such line (e.g. the injected NOP at offset 0) take the line
        of the instruction at ``_FIXED_OFFSET``. ``table`` is updated in
        place and also returned.
        """
        for offset, lineno in dis.findlinestarts(code):
            adj_offset = offset + _FIXED_OFFSET
            if adj_offset in table:
                table[adj_offset].lineno = lineno
        # Assign unfilled lineno
        # Start with first bytecode's lineno
        known = table[_FIXED_OFFSET].lineno
        for inst in table.values():
            if inst.lineno >= 0:
                known = inst.lineno
            else:
                inst.lineno = known
        return table

    def __iter__(self):
        """Iterate over the ByteCodeInst objects in bytecode order."""
        return iter(self.table.values())

    def __getitem__(self, offset):
        return self.table[offset]

    def __contains__(self, offset):
        return offset in self.table

    def dump(self):
        """
        Return a text listing with one instruction per line.

        Each line holds a marker (``>`` for label offsets), the offset and
        the instruction.
        """
        # ``self.labels`` is a list; a set keeps each lookup O(1).
        label_set = set(self.labels)

        def label_marker(item):
            return '>' if item[1].offset in label_set else ' '

        return '\n'.join('%s %10s\t%s' % ((label_marker(item),) + item)
                         for item in self.table.items())

    @classmethod
    def _compute_used_globals(cls, func, table, co_consts, co_names):
        """
        Compute the globals used by the function with the given
        bytecode table.

        Every LOAD_GLOBAL name is resolved first in ``func.__globals__`` and
        then in the builtins. Code objects found in ``co_consts`` (nested
        functions, comprehensions, ...) are scanned recursively.

        Raises KeyError if a name is found in neither mapping.
        """
        d = {}
        globs = func.__globals__
        builtins = globs.get('__builtins__', utils.builtins)
        # ``__builtins__`` may be the module itself rather than its dict.
        if isinstance(builtins, ModuleType):
            builtins = builtins.__dict__
        # Look for LOAD_GLOBALs in the bytecode
        for inst in table.values():
            if inst.opname == 'LOAD_GLOBAL':
                name = co_names[inst.arg]
                if name not in d:
                    try:
                        value = globs[name]
                    except KeyError:
                        value = builtins[name]
                    d[name] = value
        # Add globals used by any nested code object
        for co in co_consts:
            if isinstance(co, CodeType):
                subtable = OrderedDict(ByteCodeIter(co))
                d.update(cls._compute_used_globals(func, subtable,
                                                   co.co_consts, co.co_names))
        return d

    def get_used_globals(self):
        """
        Get a {name: value} map of the globals used by this code
        object and any nested code objects.
        """
        return self._compute_used_globals(self.func_id.func, self.table,
                                          self.co_consts, self.co_names)
289
290
class FunctionIdentity(serialize.ReduceMixin):
    """
    A function's identity and metadata.

    Note this typically represents a function whose bytecode is
    being compiled, not necessarily the top-level user function
    (the two might be distinct, e.g. in the `@generated_jit` case).
    """
    # Class-level counter shared by all instances; feeds ``unique_name``.
    _unique_ids = itertools.count(1)

    @classmethod
    def from_function(cls, pyfunc):
        """
        Create the FunctionIdentity of the given function.

        ``pyfunc`` is first unwrapped with ``get_function_object``.

        Raises errors.ByteCodeSupportError if the unwrapped object has no
        code object.
        """
        func = get_function_object(pyfunc)
        code = get_code_object(func)
        # Check for bytecode before computing the signature, so objects that
        # lack bytecode report the dedicated error rather than any error the
        # signature introspection might raise for them.
        if not code:
            # Wrap in a 1-tuple so a tuple-valued ``func`` cannot break the
            # %-formatting.
            raise errors.ByteCodeSupportError(
                "%s does not provide its bytecode" % (func,))
        pysig = utils.pysignature(func)

        try:
            func_qualname = func.__qualname__
        except AttributeError:
            func_qualname = func.__name__

        self = cls()
        self.func = func
        self.func_qualname = func_qualname
        self.func_name = func_qualname.split('.')[-1]
        self.code = code
        self.module = inspect.getmodule(func)
        self.modname = (utils._dynamic_modname
                        if self.module is None
                        else self.module.__name__)
        self.is_generator = inspect.isgeneratorfunction(func)
        self.pysig = pysig
        self.filename = code.co_filename
        self.firstlineno = code.co_firstlineno
        self.arg_count = len(pysig.parameters)
        self.arg_names = list(pysig.parameters)

        # Even the same function definition can be compiled into
        # several different function objects with distinct closure
        # variables, so we make sure to disambiguate using an unique id.
        uid = next(cls._unique_ids)
        self.unique_name = '{}${}'.format(self.func_qualname, uid)

        return self

    def derive(self):
        """Return a new FunctionIdentity for the same function.

        The new identity is rebuilt via ``from_function`` and therefore gets
        a fresh unique id (and ``unique_name``).
        """
        return self.from_function(self.func)

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(pyfunc=self.func)

    @classmethod
    def _rebuild(cls, pyfunc):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls.from_function(pyfunc)
359