"""
Decoding of CPython bytecode into tables of ByteCodeInst objects,
plus the FunctionIdentity metadata record.

From NumbaPro
"""

from collections import namedtuple, OrderedDict
import dis
import inspect
import itertools
from types import CodeType, ModuleType

from numba.core import errors, utils, serialize


opcode_info = namedtuple('opcode_info', ['argsize'])

# The following offset is used as a hack to inject a NOP at the start of the
# bytecode, so that a function starting with `while True` will not have
# block-0 as a jump target. The Lowerer puts argument initialization at
# block-0. The value is the width of one instruction (see CODE_LEN/ARG_LEN).
_FIXED_OFFSET = 2


def get_function_object(obj):
    """
    Objects that wrap a function should provide a "__numba__" magic attribute
    containing the name of the attribute that holds the actual Python
    function object. Return that object, or *obj* itself when no (truthy)
    "__numba__" attribute is present.
    """
    attr = getattr(obj, "__numba__", None)
    if attr:
        return getattr(obj, attr)
    return obj


def get_code_object(obj):
    """
    Return ``obj.__code__``, falling back to the legacy ``func_code``
    attribute, or None if neither exists.

    Shamelessly borrowed from llpython.
    """
    return getattr(obj, '__code__', getattr(obj, 'func_code', None))


def _as_opcodes(seq):
    """
    Map opcode names to opcode integers, silently skipping names that the
    running interpreter's ``dis.opmap`` does not define.
    """
    return [dis.opmap[name] for name in seq if name in dis.opmap]


JREL_OPS = frozenset(dis.hasjrel)
JABS_OPS = frozenset(dis.hasjabs)
JUMP_OPS = JREL_OPS | JABS_OPS
TERM_OPS = frozenset(_as_opcodes(['RETURN_VALUE', 'RAISE_VARARGS']))
EXTENDED_ARG = dis.EXTENDED_ARG
HAVE_ARGUMENT = dis.HAVE_ARGUMENT


class ByteCodeInst(object):
    '''
    A single decoded bytecode instruction.

    Attributes
    ----------
    - offset:
        byte offset of opcode
    - next:
        byte offset of the following instruction
    - opcode:
        opcode integer value
    - opname:
        opcode name, as given by ``dis.opname``
    - arg:
        instruction arg; the decoder in this module passes None for
        opcodes below ``HAVE_ARGUMENT``
    - lineno:
        -1 means unknown
    '''
    __slots__ = 'offset', 'next', 'opcode', 'opname', 'arg', 'lineno'

    def __init__(self, offset, opcode, arg, nextoffset):
        self.offset = offset
        self.next = nextoffset
        self.opcode = opcode
        self.opname = dis.opname[opcode]
        self.arg = arg
        self.lineno = -1  # unknown line number

    @property
    def is_jump(self):
        """True if the opcode is a relative or absolute jump."""
        return self.opcode in JUMP_OPS

    @property
    def is_terminator(self):
        """True for RETURN_VALUE / RAISE_VARARGS (when defined)."""
        return self.opcode in TERM_OPS

    def get_jump_target(self):
        """
        Return the byte offset targeted by this jump instruction.

        Relative jumps are measured from the next instruction; absolute
        jumps use ``arg`` directly (``_patched_opargs`` has already shifted
        it by ``_FIXED_OFFSET``).
        """
        assert self.is_jump
        if self.opcode in JREL_OPS:
            return self.next + self.arg
        else:
            assert self.opcode in JABS_OPS
            return self.arg

    def __repr__(self):
        return '%s(arg=%s, lineno=%d)' % (self.opname, self.arg, self.lineno)

    @property
    def block_effect(self):
        """Effect of the block stack
        Returns +1 (push) for SETUP_* opcodes, -1 (pop) for POP_BLOCK,
        and 0 (none) otherwise.
        """
        if self.opname.startswith('SETUP_'):
            return 1
        elif self.opname == 'POP_BLOCK':
            return -1
        else:
            return 0


# Every instruction is decoded as one opcode byte plus one argument byte;
# opcodes without an argument still skip over the argument byte.
CODE_LEN = 1
ARG_LEN = 1
NO_ARG_LEN = 1

OPCODE_NOP = dis.opname.index('NOP')


# Adapted from Lib/dis.py
def _unpack_opargs(code):
    """
    Returns a 4-int-tuple of
    (bytecode offset, opcode, argument, offset of next bytecode).

    EXTENDED_ARG prefixes are folded into the argument of the instruction
    they precede and are not yielded themselves; the yielded offset is
    that of the first prefix.
    """
    extended_arg = 0
    n = len(code)
    offset = i = 0
    while i < n:
        op = code[i]
        i += CODE_LEN
        if op >= HAVE_ARGUMENT:
            # Low bits come from the argument byte(s), high bits from any
            # preceding EXTENDED_ARG prefixes.
            arg = extended_arg
            for j in range(ARG_LEN):
                arg |= code[i + j] << (8 * j)
            i += ARG_LEN
            if op == EXTENDED_ARG:
                extended_arg = arg << 8 * ARG_LEN
                continue
        else:
            arg = None
            i += NO_ARG_LEN

        extended_arg = 0
        yield (offset, op, arg, i)
        offset = i  # Mark inst offset at first extended


def _patched_opargs(bc_stream):
    """Patch the bytecode stream.

    - Adds a NOP bytecode at the start to avoid jump target being at the entry.
    - Shifts every offset (and every absolute jump argument) by
      ``_FIXED_OFFSET`` to make room for it.
    """
    # Injected NOP
    yield (0, OPCODE_NOP, None, _FIXED_OFFSET)
    # Adjust bytecode offset for the rest of the stream
    for offset, opcode, arg, nextoffset in bc_stream:
        # If the opcode has an absolute jump target, adjust it.
        if opcode in JABS_OPS:
            arg += _FIXED_OFFSET
        yield offset + _FIXED_OFFSET, opcode, arg, nextoffset + _FIXED_OFFSET


class ByteCodeIter(object):
    """
    Iterator over ``(offset, ByteCodeInst)`` pairs of a code object,
    including the injected leading NOP.
    """

    def __init__(self, code):
        self.code = code
        self.iter = iter(_patched_opargs(_unpack_opargs(self.code.co_code)))

    def __iter__(self):
        return self

    def _fetch_opcode(self):
        return next(self.iter)

    def next(self):
        offset, opcode, arg, nextoffset = self._fetch_opcode()
        return offset, ByteCodeInst(offset=offset, opcode=opcode, arg=arg,
                                    nextoffset=nextoffset)

    __next__ = next

    def read_arg(self, size):
        # NOTE(review): self.iter yields 4-tuples (see _patched_opargs), so
        # the 2-name unpack below raises ValueError whenever this runs. It
        # looks like a leftover from pre-wordcode decoding -- confirm whether
        # any caller still uses it before removing or rewriting it.
        buf = 0
        for i in range(size):
            _offset, byte = next(self.iter)
            buf |= byte << (8 * i)
        return buf


class ByteCode(object):
    """
    The decoded bytecode of a function, and related information.

    - table: OrderedDict of {offset: ByteCodeInst}, in bytecode order.
    - labels: sorted jump-target offsets (shifted by _FIXED_OFFSET), plus 0.
    """
    __slots__ = ('func_id', 'co_names', 'co_varnames', 'co_consts',
                 'co_cellvars', 'co_freevars', 'table', 'labels')

    def __init__(self, func_id):
        code = func_id.code

        labels = set(x + _FIXED_OFFSET for x in dis.findlabels(code.co_code))
        labels.add(0)

        # A map of {offset: ByteCodeInst}
        table = OrderedDict(ByteCodeIter(code))
        self._compute_lineno(table, code)

        self.func_id = func_id
        self.co_names = code.co_names
        self.co_varnames = code.co_varnames
        self.co_consts = code.co_consts
        self.co_cellvars = code.co_cellvars
        self.co_freevars = code.co_freevars
        self.table = table
        self.labels = sorted(labels)

    @classmethod
    def _compute_lineno(cls, table, code):
        """
        Compute the line numbers for all bytecode instructions.

        Instructions that start a source line get that line number; every
        other instruction inherits the most recent known line number.
        Mutates the instructions in *table* in place and returns *table*.
        """
        for offset, lineno in dis.findlinestarts(code):
            if lineno is None:
                # Newer interpreters may report None for bytecode that maps
                # to no source line; keep such instructions "unknown".
                continue
            adj_offset = offset + _FIXED_OFFSET
            if adj_offset in table:
                table[adj_offset].lineno = lineno
        # Assign unfilled lineno
        # Start with first (non-injected) bytecode's lineno, if there is one.
        first = table.get(_FIXED_OFFSET)
        known = first.lineno if first is not None else -1
        for inst in table.values():
            if inst.lineno >= 0:
                known = inst.lineno
            else:
                inst.lineno = known
        return table

    def __iter__(self):
        return utils.itervalues(self.table)

    def __getitem__(self, offset):
        return self.table[offset]

    def __contains__(self, offset):
        return offset in self.table

    def dump(self):
        """
        Return a text listing of the instruction table, one instruction per
        line; offsets that are jump labels are marked with '>'.
        """
        def label_marker(i):
            if i[1].offset in self.labels:
                return '>'
            else:
                return ' '

        return '\n'.join('%s %10s\t%s' % ((label_marker(i),) + i)
                         for i in utils.iteritems(self.table))

    @classmethod
    def _compute_used_globals(cls, func, table, co_consts, co_names):
        """
        Compute the globals used by the function with the given
        bytecode table.

        Names referenced by LOAD_GLOBAL are resolved in ``func.__globals__``
        first and then in its builtins; nested code objects found in
        *co_consts* are scanned recursively. Raises KeyError if a name is
        found in neither mapping.
        """
        d = {}
        globs = func.__globals__
        builtins = globs.get('__builtins__', utils.builtins)
        if isinstance(builtins, ModuleType):
            builtins = builtins.__dict__
        # Look for LOAD_GLOBALs in the bytecode
        for inst in table.values():
            if inst.opname == 'LOAD_GLOBAL':
                name = co_names[inst.arg]
                if name not in d:
                    try:
                        value = globs[name]
                    except KeyError:
                        value = builtins[name]
                    d[name] = value
        # Add globals used by any nested code object
        for co in co_consts:
            if isinstance(co, CodeType):
                subtable = OrderedDict(ByteCodeIter(co))
                d.update(cls._compute_used_globals(func, subtable,
                                                   co.co_consts, co.co_names))
        return d

    def get_used_globals(self):
        """
        Get a {name: value} map of the globals used by this code
        object and any nested code objects.
        """
        return self._compute_used_globals(self.func_id.func, self.table,
                                          self.co_consts, self.co_names)


class FunctionIdentity(serialize.ReduceMixin):
    """
    A function's identity and metadata.

    Note this typically represents a function whose bytecode is
    being compiled, not necessarily the top-level user function
    (the two might be distinct, e.g. in the `@generated_jit` case).
    """
    _unique_ids = itertools.count(1)

    @classmethod
    def from_function(cls, pyfunc):
        """
        Create the FunctionIdentity of the given function.

        *pyfunc* is first unwrapped with get_function_object().
        Raises errors.ByteCodeSupportError if the unwrapped object does not
        provide a code object.
        """
        func = get_function_object(pyfunc)
        code = get_code_object(func)
        # Check for bytecode before computing the signature, so objects
        # without bytecode fail with this descriptive error rather than with
        # whatever the signature computation might raise.
        if not code:
            raise errors.ByteCodeSupportError(
                "%s does not provide its bytecode" % func)
        pysig = utils.pysignature(func)

        try:
            func_qualname = func.__qualname__
        except AttributeError:
            func_qualname = func.__name__

        self = cls()
        self.func = func
        self.func_qualname = func_qualname
        self.func_name = func_qualname.split('.')[-1]
        self.code = code
        self.module = inspect.getmodule(func)
        self.modname = (utils._dynamic_modname
                        if self.module is None
                        else self.module.__name__)
        self.is_generator = inspect.isgeneratorfunction(func)
        self.pysig = pysig
        self.filename = code.co_filename
        self.firstlineno = code.co_firstlineno
        self.arg_count = len(pysig.parameters)
        self.arg_names = list(pysig.parameters)

        # Even the same function definition can be compiled into
        # several different function objects with distinct closure
        # variables, so we make sure to disambiguate using an unique id.
        uid = next(cls._unique_ids)
        self.unique_name = '{}${}'.format(self.func_qualname, uid)

        return self

    def derive(self):
        """Copy the object and increment the unique counter.
        """
        return self.from_function(self.func)

    def _reduce_states(self):
        """
        NOTE: part of ReduceMixin protocol
        """
        return dict(pyfunc=self.func)

    @classmethod
    def _rebuild(cls, pyfunc):
        """
        NOTE: part of ReduceMixin protocol
        """
        return cls.from_function(pyfunc)