1from __future__ import absolute_import 2 3import sys, os, re, inspect 4import imp 5 6try: 7 import hashlib 8except ImportError: 9 import md5 as hashlib 10 11from distutils.core import Distribution, Extension 12from distutils.command.build_ext import build_ext 13 14import Cython 15from ..Compiler.Main import Context, CompilationOptions, default_options 16 17from ..Compiler.ParseTreeTransforms import (CythonTransform, 18 SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform) 19from ..Compiler.TreeFragment import parse_from_strings 20from ..Compiler.StringEncoding import _unicode 21from .Dependencies import strip_string_literals, cythonize, cached_function 22from ..Compiler import Pipeline, Nodes 23from ..Utils import get_cython_cache_dir 24import cython as cython_module 25 26IS_PY3 = sys.version_info >= (3, 0) 27 28# A utility function to convert user-supplied ASCII strings to unicode. 29if sys.version_info[0] < 3: 30 def to_unicode(s): 31 if isinstance(s, bytes): 32 return s.decode('ascii') 33 else: 34 return s 35else: 36 to_unicode = lambda x: x 37 38 39class UnboundSymbols(EnvTransform, SkipDeclarations): 40 def __init__(self): 41 CythonTransform.__init__(self, None) 42 self.unbound = set() 43 def visit_NameNode(self, node): 44 if not self.current_env().lookup(node.name): 45 self.unbound.add(node.name) 46 return node 47 def __call__(self, node): 48 super(UnboundSymbols, self).__call__(node) 49 return self.unbound 50 51 52@cached_function 53def unbound_symbols(code, context=None): 54 code = to_unicode(code) 55 if context is None: 56 context = Context([], default_options) 57 from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform 58 tree = parse_from_strings('(tree fragment)', code) 59 for phase in Pipeline.create_pipeline(context, 'pyx'): 60 if phase is None: 61 continue 62 tree = phase(tree) 63 if isinstance(phase, AnalyseDeclarationsTransform): 64 break 65 try: 66 import builtins 67 except ImportError: 68 import __builtin__ as builtins 69 return tuple(UnboundSymbols()(tree) - set(dir(builtins))) 70 71 72def unsafe_type(arg, context=None): 73 py_type = type(arg) 74 if py_type is int: 75 return 'long' 76 else: 77 return safe_type(arg, context) 78 79 80def safe_type(arg, context=None): 81 py_type = type(arg) 82 if py_type in (list, tuple, dict, str): 83 return py_type.__name__ 84 elif py_type is complex: 85 return 'double complex' 86 elif py_type is float: 87 return 'double' 88 elif py_type is bool: 89 return 'bint' 90 elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray): 91 return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim) 92 else: 93 for base_type in py_type.__mro__: 94 if base_type.__module__ in ('__builtin__', 'builtins'): 95 return 'object' 96 module = context.find_module(base_type.__module__, need_pxd=False) 97 if module: 98 entry = module.lookup(base_type.__name__) 99 if entry.is_type: 100 return '%s.%s' % (base_type.__module__, base_type.__name__) 101 return 'object' 102 103 104def _get_build_extension(): 105 dist = Distribution() 106 # Ensure the build respects distutils configuration by parsing 107 # the configuration files 108 config_files = dist.find_config_files() 109 dist.parse_config_files(config_files) 110 build_extension = build_ext(dist) 111 build_extension.finalize_options() 112 return build_extension 113 114 115@cached_function 116def _create_context(cython_include_dirs): 117 return Context(list(cython_include_dirs), default_options) 118 119 120_cython_inline_cache = {} 121_cython_inline_default_context = _create_context(('.',)) 122 123def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None): 124 for symbol in unbound_symbols: 125 if symbol not in kwds: 126 if locals is None or globals is None: 127 calling_frame = inspect.currentframe().f_back.f_back.f_back 128 if locals is None: 129 locals = calling_frame.f_locals 130 if globals is None: 131 globals = calling_frame.f_globals 132 if symbol in locals: 133 kwds[symbol] = locals[symbol] 134 elif symbol in globals: 135 kwds[symbol] = globals[symbol] 136 else: 137 print("Couldn't find %r" % symbol) 138 139def cython_inline(code, get_type=unsafe_type, 140 lib_dir=os.path.join(get_cython_cache_dir(), 'inline'), 141 cython_include_dirs=None, cython_compiler_directives=None, 142 force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds): 143 144 if get_type is None: 145 get_type = lambda x: 'object' 146 ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context 147 148 # Fast path if this has been called in this session. 149 _unbound_symbols = _cython_inline_cache.get(code) 150 if _unbound_symbols is not None: 151 _populate_unbound(kwds, _unbound_symbols, locals, globals) 152 args = sorted(kwds.items()) 153 arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args]) 154 invoke = _cython_inline_cache.get((code, arg_sigs)) 155 if invoke is not None: 156 arg_list = [arg[1] for arg in args] 157 return invoke(*arg_list) 158 159 orig_code = code 160 code = to_unicode(code) 161 code, literals = strip_string_literals(code) 162 code = strip_common_indent(code) 163 if locals is None: 164 locals = inspect.currentframe().f_back.f_back.f_locals 165 if globals is None: 166 globals = inspect.currentframe().f_back.f_back.f_globals 167 try: 168 _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code) 169 _populate_unbound(kwds, _unbound_symbols, locals, globals) 170 except AssertionError: 171 if not quiet: 172 # Parsing from strings not fully supported (e.g. cimports). 173 print("Could not parse code as a string (to extract unbound symbols).") 174 175 cython_compiler_directives = dict(cython_compiler_directives or {}) 176 if language_level is not None: 177 cython_compiler_directives['language_level'] = language_level 178 179 cimports = [] 180 for name, arg in list(kwds.items()): 181 if arg is cython_module: 182 cimports.append('\ncimport cython as %s' % name) 183 del kwds[name] 184 arg_names = sorted(kwds) 185 arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) 186 key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__ 187 module_name = "_cython_inline_" + hashlib.md5(_unicode(key).encode('utf-8')).hexdigest() 188 189 if module_name in sys.modules: 190 module = sys.modules[module_name] 191 192 else: 193 build_extension = None 194 if cython_inline.so_ext is None: 195 # Figure out and cache current extension suffix 196 build_extension = _get_build_extension() 197 cython_inline.so_ext = build_extension.get_ext_filename('') 198 199 module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext) 200 201 if not os.path.exists(lib_dir): 202 os.makedirs(lib_dir) 203 if force or not os.path.isfile(module_path): 204 cflags = [] 205 c_include_dirs = [] 206 qualified = re.compile(r'([.\w]+)[.]') 207 for type, _ in arg_sigs: 208 m = qualified.match(type) 209 if m: 210 cimports.append('\ncimport %s' % m.groups()[0]) 211 # one special case 212 if m.groups()[0] == 'numpy': 213 import numpy 214 c_include_dirs.append(numpy.get_include()) 215 # cflags.append('-Wno-unused') 216 module_body, func_body = extract_func_code(code) 217 params = ', '.join(['%s %s' % a for a in arg_sigs]) 218 module_code = """ 219%(module_body)s 220%(cimports)s 221def __invoke(%(params)s): 222%(func_body)s 223 return locals() 224 """ % {'cimports': '\n'.join(cimports), 225 'module_body': module_body, 226 'params': params, 227 'func_body': func_body } 228 for key, value in literals.items(): 229 module_code = module_code.replace(key, value) 230 pyx_file = os.path.join(lib_dir, module_name + '.pyx') 231 fh = open(pyx_file, 'w') 232 try: 233 fh.write(module_code) 234 finally: 235 fh.close() 236 extension = Extension( 237 name = module_name, 238 sources = [pyx_file], 239 include_dirs = c_include_dirs, 240 extra_compile_args = cflags) 241 if build_extension is None: 242 build_extension = _get_build_extension() 243 build_extension.extensions = cythonize( 244 [extension], 245 include_path=cython_include_dirs or ['.'], 246 compiler_directives=cython_compiler_directives, 247 quiet=quiet) 248 build_extension.build_temp = os.path.dirname(pyx_file) 249 build_extension.build_lib = lib_dir 250 build_extension.run() 251 252 module = imp.load_dynamic(module_name, module_path) 253 254 _cython_inline_cache[orig_code, arg_sigs] = module.__invoke 255 arg_list = [kwds[arg] for arg in arg_names] 256 return module.__invoke(*arg_list) 257 258# Cached suffix used by cython_inline above. None should get 259# overridden with actual value upon the first cython_inline invocation 260cython_inline.so_ext = None 261 262_find_non_space = re.compile('[^ ]').search 263 264 265def strip_common_indent(code): 266 min_indent = None 267 lines = code.splitlines() 268 for line in lines: 269 match = _find_non_space(line) 270 if not match: 271 continue # blank 272 indent = match.start() 273 if line[indent] == '#': 274 continue # comment 275 if min_indent is None or min_indent > indent: 276 min_indent = indent 277 for ix, line in enumerate(lines): 278 match = _find_non_space(line) 279 if not match or not line or line[indent:indent+1] == '#': 280 continue 281 lines[ix] = line[min_indent:] 282 return '\n'.join(lines) 283 284 285module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))') 286def extract_func_code(code): 287 module = [] 288 function = [] 289 current = function 290 code = code.replace('\t', ' ') 291 lines = code.split('\n') 292 for line in lines: 293 if not line.startswith(' '): 294 if module_statement.match(line): 295 current = module 296 else: 297 current = function 298 current.append(line) 299 return '\n'.join(module), ' ' + '\n '.join(function) 300 301 302try: 303 from inspect import getcallargs 304except ImportError: 305 def getcallargs(func, *arg_values, **kwd_values): 306 all = {} 307 args, varargs, kwds, defaults = inspect.getargspec(func) 308 if varargs is not None: 309 all[varargs] = arg_values[len(args):] 310 for name, value in zip(args, arg_values): 311 all[name] = value 312 for name, value in list(kwd_values.items()): 313 if name in args: 314 if name in all: 315 raise TypeError("Duplicate argument %s" % name) 316 all[name] = kwd_values.pop(name) 317 if kwds is not None: 318 all[kwds] = kwd_values 319 elif kwd_values: 320 raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values)) 321 if defaults is None: 322 defaults = () 323 first_default = len(args) - len(defaults) 324 for ix, name in enumerate(args): 325 if name not in all: 326 if ix >= first_default: 327 all[name] = defaults[ix - first_default] 328 else: 329 raise TypeError("Missing argument: %s" % name) 330 return all 331 332 333def get_body(source): 334 ix = source.index(':') 335 if source[:5] == 'lambda': 336 return "return %s" % source[ix+1:] 337 else: 338 return source[ix+1:] 339 340 341# Lots to be done here... It would be especially cool if compiled functions 342# could invoke each other quickly. 343class RuntimeCompiledFunction(object): 344 345 def __init__(self, f): 346 self._f = f 347 self._body = get_body(inspect.getsource(f)) 348 349 def __call__(self, *args, **kwds): 350 all = getcallargs(self._f, *args, **kwds) 351 if IS_PY3: 352 return cython_inline(self._body, locals=self._f.__globals__, globals=self._f.__globals__, **all) 353 else: 354 return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all) 355