1from __future__ import absolute_import
2
3import sys, os, re, inspect
4import imp
5
6try:
7    import hashlib
8except ImportError:
9    import md5 as hashlib
10
11from distutils.core import Distribution, Extension
12from distutils.command.build_ext import build_ext
13
14import Cython
15from ..Compiler.Main import Context, CompilationOptions, default_options
16
17from ..Compiler.ParseTreeTransforms import (CythonTransform,
18        SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform)
19from ..Compiler.TreeFragment import parse_from_strings
20from ..Compiler.StringEncoding import _unicode
21from .Dependencies import strip_string_literals, cythonize, cached_function
22from ..Compiler import Pipeline, Nodes
23from ..Utils import get_cython_cache_dir
24import cython as cython_module
25
IS_PY3 = sys.version_info[0] >= 3

# Helper that turns user-supplied ASCII source strings into unicode text.
if IS_PY3:
    def to_unicode(x):
        """Return *x* unchanged; no conversion is done on Python 3."""
        return x
else:
    def to_unicode(s):
        """Decode byte strings as ASCII; hand back anything else untouched."""
        return s.decode('ascii') if isinstance(s, bytes) else s
37
if sys.version_info >= (3, 5):
    import importlib.util as _importlib_util

    def load_dynamic(name, module_path):
        """Load the file at *module_path* as a module called *name*.

        The module object is created from a file-location spec, executed
        by the spec's loader and returned.
        """
        module_spec = _importlib_util.spec_from_file_location(name, module_path)
        loaded = _importlib_util.module_from_spec(module_spec)
        # Note: this helper itself does not insert the module into sys.modules.
        module_spec.loader.exec_module(loaded)
        return loaded
else:
    import imp

    def load_dynamic(name, module_path):
        """Load the extension at *module_path* as *name* via ``imp``."""
        return imp.load_dynamic(name, module_path)
50
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that gathers names the current scope cannot resolve.

    Calling an instance on a tree returns the set of collected names
    rather than the tree itself.
    """

    def __init__(self):
        CythonTransform.__init__(self, None)
        # Names for which the scope lookup came back empty.
        self.unbound = set()

    def visit_NameNode(self, node):
        """Record ``node.name`` when the current environment has no entry for it."""
        name = node.name
        if self.current_env().lookup(name):
            return node
        self.unbound.add(name)
        return node

    def __call__(self, node):
        """Traverse *node* and return the set of unbound names found."""
        # The traversal result is discarded; only the collected names matter.
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
62
63
@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of the names used in *code* that are bound nowhere.

    The code is parsed as a tree fragment and pushed through the 'pyx'
    pipeline up to and including declaration analysis.  Names that are
    attributes of the builtins module are excluded from the result.

    A fresh ``Context([], default_options)`` is used when *context* is None.
    """
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform

    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)

    tree = parse_from_strings('(tree fragment)', code)
    for phase in Pipeline.create_pipeline(context, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            # Stop once declarations have been analysed; later phases are not needed.
            if isinstance(phase, AnalyseDeclarationsTransform):
                break

    collected = UnboundSymbols()(tree)
    return tuple(collected - set(dir(builtins)))
82
83
def unsafe_type(arg, context=None):
    """Return the Cython type string for *arg*, typing plain ints as 'long'.

    Only objects whose exact type is ``int`` get 'long'; everything else
    is delegated to ``safe_type(arg, context)``.
    """
    if type(arg) is int:
        return 'long'
    return safe_type(arg, context)
90
91
def safe_type(arg, context=None):
    """Return a conservative Cython type string for the value *arg*.

    Exact ``list``/``tuple``/``dict``/``str`` map to their own name,
    ``complex`` to 'double complex', ``float`` to 'double' and ``bool``
    to 'bint'.  If numpy has already been imported and *arg* is an
    ndarray, a typed ``numpy.ndarray[...]`` declaration is returned.

    For any other value the type's MRO is walked: the first base living
    in the builtins module ends the search with 'object'; otherwise the
    base's module is looked up through ``context.find_module`` and, if
    that module declares the class as a type, ``'<module>.<name>'`` is
    returned.

    *context* may be None, in which case no module lookup is possible
    and such values are typed as 'object'.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)

    if context is None:
        # Without a compilation context extension types cannot be resolved;
        # previously this path crashed on ``None.find_module``.
        return 'object'

    for base_type in py_type.__mro__:
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # lookup() may find nothing for a plain Python class.
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'
114
115
def _get_build_extension():
    """Create a finalized distutils ``build_ext`` command.

    The distribution's configuration files are located and parsed first
    so that the resulting build honours the distutils configuration.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
125
126
@cached_function
def _create_context(cython_include_dirs):
    """Return a ``Context`` for the given include directories.

    *cython_include_dirs* is converted to a list before being handed to
    ``Context`` together with ``default_options``.
    """
    include_dirs = list(cython_include_dirs)
    return Context(include_dirs, default_options)


# Per-session cache used by cython_inline.  It is keyed both by the raw code
# string (-> unbound symbol names) and by (code, arg_sigs, key_hash)
# (-> the compiled module's __invoke callable).
_cython_inline_cache = {}
# Context used by cython_inline when no include directories are passed.
_cython_inline_default_context = _create_context(('.',))
134
def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
    """Fill *kwds* in place with values for the names in *unbound_symbols*.

    Names already present in *kwds* are left alone.  Every other name is
    looked up in *locals* first and then in *globals*; a message is
    printed for names found in neither.  When either mapping is None it
    is taken from the frame three levels above this function, and only
    once a lookup is actually needed.
    """
    for name in unbound_symbols:
        if name in kwds:
            continue
        if locals is None or globals is None:
            # Three hops up the stack from here; presumably the code that
            # invoked the inline wrapper -- depends on the call chain.
            caller = inspect.currentframe().f_back.f_back.f_back
            if locals is None:
                locals = caller.f_locals
            if globals is None:
                globals = caller.f_globals
        if name in locals:
            kwds[name] = locals[name]
        elif name in globals:
            kwds[name] = globals[name]
        else:
            print("Couldn't find %r" % name)
150
def _inline_key(orig_code, arg_sigs, language_level):
    """Return the SHA-1 hex digest that identifies one inline compilation.

    The digest covers the code, the argument signatures, the language
    level, the interpreter version and executable, and the Cython version.
    """
    key_parts = (orig_code, arg_sigs, sys.version_info, sys.executable,
                 language_level, Cython.__version__)
    digest = hashlib.sha1(_unicode(key_parts).encode('utf-8'))
    return digest.hexdigest()
154
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile the snippet *code* into an extension module and run it.

    The snippet becomes the body of a generated ``__invoke`` function whose
    parameters are the snippet's unbound names.  A ``return locals()`` is
    appended after the snippet, so the call yields the snippet's own return
    value or, if it falls through, the dict of its local variables.

    Parameters:
        code: source text of the snippet.
        get_type: callable ``(value, context) -> type string`` used to
            declare each argument; None declares every argument as 'object'.
        lib_dir: directory receiving the generated .pyx file and the built
            module.  It is created if missing.
        cython_include_dirs: include path for Cython; ``['.']`` when empty.
        cython_compiler_directives: mapping of compiler directives.  It is
            copied, and 'language_level' is set in the copy.
        force: rebuild even when the module file already exists.
        quiet: forwarded to ``cythonize``; also silences the parse warning.
        locals, globals: mappings used to resolve unbound names.  When None
            they are read from the frame two levels above this function.
        language_level: defaults to '3str' unless the directives already
            name one.
        **kwds: explicit values for names used by the snippet.

    Compiled entry points are cached per session in ``_cython_inline_cache``.
    """
    if get_type is None:
        # get_type is always called as get_type(value, ctx), so the fallback
        # has to accept the context argument as well.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives or {})
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        args = sorted(kwds.items())
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    # NOTE: the frame lookups below must stay in this function body; their
    # depth is counted from this frame.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    # Arguments bound to the cython module itself become cimports, not parameters.
    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.isdir(lib_dir):
            try:
                os.makedirs(lib_dir)
            except OSError:
                # Another process may have created the directory meanwhile.
                if not os.path.isdir(lib_dir):
                    raise
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for arg_type, _ in arg_sigs:
                m = qualified.match(arg_type)
                if m:
                    # A dotted type name needs its module cimported.
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
            """ % {'cimports': '\n'.join(cimports),
                   'module_body': module_body,
                   'params': params,
                   'func_body': func_body }
            # Put the string literals removed by strip_string_literals back.
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            with open(pyx_file, 'w') as fh:
                fh.write(module_code)
            extension = Extension(
                name = module_name,
                sources = [pyx_file],
                include_dirs = c_include_dirs,
                extra_compile_args = cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib  = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)

# Cached suffix used by cython_inline above.  None should get
# overridden with actual value upon the first cython_inline invocation
cython_inline.so_ext = None
280
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the indentation shared by all code lines of *code*.

    The common indent is the smallest number of leading spaces over the
    lines that contain code.  Blank lines and lines whose first non-space
    character is ``#`` are ignored when measuring it and are left
    untouched in the result.  Only spaces count as indentation.

    Returns the lines re-joined with ``'\\n'``; a trailing newline of the
    input is not preserved.
    """
    lines = code.splitlines()

    # First pass: find the smallest indent among lines that carry code.
    min_indent = None
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or indent < min_indent:
            min_indent = indent

    if not min_indent:
        # No code lines at all, or the code is already flush left.
        return '\n'.join(lines)

    # Second pass: dedent code lines.  The comment test uses each line's own
    # first non-space column, not a value left over from the first pass.
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
302
303
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split *code* into module-level text and function-body text.

    Tabs are replaced by single spaces first.  Each unindented line picks
    the destination: lines matching ``module_statement`` go to the module
    part, all others to the function part.  Indented lines follow the
    most recent unindented line.

    Returns ``(module_text, function_text)`` where every function line is
    prefixed with four spaces.
    """
    module_lines = []
    function_lines = []
    target = function_lines
    for text in code.replace('\t', ' ').split('\n'):
        if text[:1] != ' ':
            # An unindented line (including an empty one) re-selects the destination.
            target = module_lines if module_statement.match(text) else function_lines
        target.append(text)
    module_text = '\n'.join(module_lines)
    function_text = '    ' + '\n    '.join(function_lines)
    return module_text, function_text
319
320
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for ``inspect.getcallargs``.

        Maps the given positional and keyword values onto the parameter
        names of *func* and returns the resulting dict.  Raises TypeError
        for duplicate, unexpected or missing arguments.
        """
        arg_names, varargs_name, varkw_name, defaults = inspect.getargspec(func)
        bound = {}
        if varargs_name is not None:
            # Surplus positional values go to the *args parameter.
            bound[varargs_name] = arg_values[len(arg_names):]
        bound.update(zip(arg_names, arg_values))

        for name in list(kwd_values):
            if name not in arg_names:
                continue
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)

        if varkw_name is not None:
            bound[varkw_name] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))

        defaults = defaults or ()
        first_default = len(arg_names) - len(defaults)
        for ix, name in enumerate(arg_names):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound
350
351
def get_body(source):
    """Return the body part of a function's source text.

    Everything after the first ``':'`` in *source* is treated as the body.
    When *source* starts with ``lambda`` the expression is turned into a
    ``return`` statement.

    Limitation: the first colon is assumed to end the header, so headers
    that themselves contain a colon (for example annotations) are not
    handled.

    Raises ValueError if *source* contains no colon.
    """
    ix = source.index(':')
    # 'lambda' is six characters long; a five-character slice can never match.
    if source.startswith('lambda'):
        return "return %s" % source[ix+1:]
    return source[ix+1:]
358
359
360# Lots to be done here... It would be especially cool if compiled functions
361# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable wrapper that runs a function's body through ``cython_inline``."""

    def __init__(self, f):
        # Keep the original function for its signature and globals, plus
        # the body text extracted from its source.
        self._f = f
        self._body = get_body(inspect.getsource(f))

    def __call__(self, *args, **kwds):
        """Bind the call arguments to parameter names and run the compiled body."""
        bound_args = getcallargs(self._f, *args, **kwds)
        # The function's globals serve as both the locals and the globals mapping.
        func_globals = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=func_globals, globals=func_globals, **bound_args)
374