1from __future__ import absolute_import 2 3import sys, os, re, inspect 4import imp 5 6try: 7 import hashlib 8except ImportError: 9 import md5 as hashlib 10 11from distutils.core import Distribution, Extension 12from distutils.command.build_ext import build_ext 13 14import Cython 15from ..Compiler.Main import Context, CompilationOptions, default_options 16 17from ..Compiler.ParseTreeTransforms import (CythonTransform, 18 SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform) 19from ..Compiler.TreeFragment import parse_from_strings 20from ..Compiler.StringEncoding import _unicode 21from .Dependencies import strip_string_literals, cythonize, cached_function 22from ..Compiler import Pipeline, Nodes 23from ..Utils import get_cython_cache_dir 24import cython as cython_module 25 26IS_PY3 = sys.version_info >= (3, 0) 27 28# A utility function to convert user-supplied ASCII strings to unicode. 29if sys.version_info[0] < 3: 30 def to_unicode(s): 31 if isinstance(s, bytes): 32 return s.decode('ascii') 33 else: 34 return s 35else: 36 to_unicode = lambda x: x 37 38if sys.version_info < (3, 5): 39 import imp 40 def load_dynamic(name, module_path): 41 return imp.load_dynamic(name, module_path) 42else: 43 import importlib.util as _importlib_util 44 def load_dynamic(name, module_path): 45 spec = _importlib_util.spec_from_file_location(name, module_path) 46 module = _importlib_util.module_from_spec(spec) 47 # sys.modules[name] = module 48 spec.loader.exec_module(module) 49 return module 50 51class UnboundSymbols(EnvTransform, SkipDeclarations): 52 def __init__(self): 53 CythonTransform.__init__(self, None) 54 self.unbound = set() 55 def visit_NameNode(self, node): 56 if not self.current_env().lookup(node.name): 57 self.unbound.add(node.name) 58 return node 59 def __call__(self, node): 60 super(UnboundSymbols, self).__call__(node) 61 return self.unbound 62 63 64@cached_function 65def unbound_symbols(code, context=None): 66 code = to_unicode(code) 67 if context is None: 68 
context = Context([], default_options) 69 from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform 70 tree = parse_from_strings('(tree fragment)', code) 71 for phase in Pipeline.create_pipeline(context, 'pyx'): 72 if phase is None: 73 continue 74 tree = phase(tree) 75 if isinstance(phase, AnalyseDeclarationsTransform): 76 break 77 try: 78 import builtins 79 except ImportError: 80 import __builtin__ as builtins 81 return tuple(UnboundSymbols()(tree) - set(dir(builtins))) 82 83 84def unsafe_type(arg, context=None): 85 py_type = type(arg) 86 if py_type is int: 87 return 'long' 88 else: 89 return safe_type(arg, context) 90 91 92def safe_type(arg, context=None): 93 py_type = type(arg) 94 if py_type in (list, tuple, dict, str): 95 return py_type.__name__ 96 elif py_type is complex: 97 return 'double complex' 98 elif py_type is float: 99 return 'double' 100 elif py_type is bool: 101 return 'bint' 102 elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray): 103 return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) 104 else: 105 for base_type in py_type.__mro__: 106 if base_type.__module__ in ('__builtin__', 'builtins'): 107 return 'object' 108 module = context.find_module(base_type.__module__, need_pxd=False) 109 if module: 110 entry = module.lookup(base_type.__name__) 111 if entry.is_type: 112 return '%s.%s' % (base_type.__module__, base_type.__name__) 113 return 'object' 114 115 116def _get_build_extension(): 117 dist = Distribution() 118 # Ensure the build respects distutils configuration by parsing 119 # the configuration files 120 config_files = dist.find_config_files() 121 dist.parse_config_files(config_files) 122 build_extension = build_ext(dist) 123 build_extension.finalize_options() 124 return build_extension 125 126 127@cached_function 128def _create_context(cython_include_dirs): 129 return Context(list(cython_include_dirs), default_options) 130 131 132_cython_inline_cache = {} 
_cython_inline_default_context = _create_context(('.',))


def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
    """Bind each name in ``unbound_symbols`` that is missing from ``kwds`` (in place).

    Values are looked up in ``locals`` first, then ``globals``.  A namespace
    that is not given is taken from the frame three levels above this
    function, i.e. the caller of whoever called ``cython_inline``.  Names that
    cannot be found are reported with a printed message and left unbound.
    """
    for symbol in unbound_symbols:
        if symbol not in kwds:
            if locals is None or globals is None:
                calling_frame = inspect.currentframe().f_back.f_back.f_back
                if locals is None:
                    locals = calling_frame.f_locals
                if globals is None:
                    globals = calling_frame.f_globals
            if symbol in locals:
                kwds[symbol] = locals[symbol]
            elif symbol in globals:
                kwds[symbol] = globals[symbol]
            else:
                print("Couldn't find %r" % symbol)


def _inline_key(orig_code, arg_sigs, language_level):
    """Return a SHA-1 hex digest identifying one compiled variant of ``orig_code``.

    The key covers the code, the argument signatures, the running Python
    (version and executable), the language level and the Cython version.
    """
    key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
    return hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest()


def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None,
                  **kwds):
    """Compile ``code`` as the body of a Cython function, call it and return the result.

    The function's parameters are the keyword arguments in ``kwds`` plus any
    names used but not bound in ``code``.  Those names are looked up in
    ``locals``/``globals``, which default to the namespaces two frames above
    this function (presumably because it is normally reached through a
    wrapper; confirm against callers).  Arguments bound to the ``cython``
    module become ``cimport cython`` statements instead of parameters.  The
    generated function ends with ``return locals()``, so that is the result
    unless ``code`` returns earlier.

    The compiled extension is written to ``lib_dir`` under a name derived from
    ``_inline_key`` and reused on later calls.

    :param get_type: callable ``(value, context) -> Cython type string``; None
        declares every parameter as 'object'.
    :param cython_include_dirs: extra include path for type lookup and cythonize.
    :param cython_compiler_directives: directives passed to cythonize (copied).
    :param force: rebuild even if the compiled module file already exists.
    :param quiet: passed to cythonize; also silences the parse-failure message.
    :param language_level: defaults to '3str' unless given in the directives.
    """
    if get_type is None:
        # get_type is always called as get_type(value, ctx); accept both args.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives or {})
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        # The slow path turns arguments bound to the ``cython`` module into
        # cimports, so they are not part of the compiled signature.  Leave them
        # out here too, otherwise this key can never match the cached one.
        args = sorted((arg, value) for arg, value in kwds.items()
                      if value is not cython_module)
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.exists(lib_dir):
            try:
                os.makedirs(lib_dir)
            except OSError:
                # Another process may have created it between the check and here.
                if not os.path.isdir(lib_dir):
                    raise
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for arg_type, _ in arg_sigs:
                m = qualified.match(arg_type)
                if m:
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
                        # cflags.append('-Wno-unused')
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
            """ % {'cimports': '\n'.join(cimports),
                   'module_body': module_body,
                   'params': params,
                   'func_body': func_body}
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Cython reads sources as UTF-8 by default, so don't depend on the
            # locale's preferred encoding when writing non-ASCII code.
            import io
            with io.open(pyx_file, 'w', encoding='utf-8') as fh:
                fh.write(module_code)
            extension = Extension(
                name=module_name,
                sources=[pyx_file],
                include_dirs=c_include_dirs,
                extra_compile_args=cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)

# Cached suffix used by cython_inline above. None should get
# overridden with actual value upon the first cython_inline invocation
cython_inline.so_ext = None

_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the smallest indentation shared by all non-blank, non-comment lines.

    Blank lines and comment lines are left untouched and do not affect the
    computed indentation.  Only spaces count as indentation.
    """
    min_indent = None
    lines = code.splitlines()
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or min_indent > indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Check this line's own first non-space character; the old code used
        # the stale ``indent`` from the loop above and could cut into comments.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)


module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split ``code`` into (module-level code, indented function body).

    Unindented lines that match ``module_statement`` (cdef extern/class,
    cimports, star imports), together with the indented lines following
    them, go to the module part.  Everything else forms the function body,
    indented by four spaces to sit under ``def __invoke(...)``.
    """
    module = []
    function = []
    current = function
    code = code.replace('\t', ' ')
    lines = code.split('\n')
    for line in lines:
        if not line.startswith(' '):
            if module_statement.match(line):
                current = module
            else:
                current = function
        current.append(line)
    return '\n'.join(module), '    ' + '\n    '.join(function)


try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for inspect.getcallargs: map ``func``'s parameter names to call values.

        Raises TypeError for duplicate, unexpected or missing arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            bound[varargs] = arg_values[len(args):]
        for name, value in zip(args, arg_values):
            bound[name] = value
        for name, value in list(kwd_values.items()):
            if name in args:
                if name in bound:
                    raise TypeError("Duplicate argument %s" % name)
                bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        if defaults is None:
            defaults = ()
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name not in bound:
                if ix >= first_default:
                    bound[name] = defaults[ix - first_default]
                else:
                    raise TypeError("Missing argument: %s" % name)
        return bound


def get_body(source):
    """Return the text after the first ':' of ``source``.

    For lambda source the expression is turned into a ``return`` statement.
    """
    ix = source.index(':')
    # 'lambda' is six characters; the old ``source[:5] == 'lambda'`` never matched.
    if source.startswith('lambda'):
        return "return %s" % source[ix+1:]
    else:
        return source[ix+1:]


# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs ``f``'s body through cython_inline.

    cython_inline caches one compiled variant per combination of argument
    types, keyed on the body source.
    """

    def __init__(self, f):
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        call_args = getcallargs(self._f, *args, **kwds)
        # Unbound names in the body resolve against f's module globals,
        # which are passed as both namespaces.
        f_globals = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=f_globals, globals=f_globals, **call_args)