1from __future__ import absolute_import
2
3import sys, os, re, inspect
4import imp
5
6try:
7    import hashlib
8except ImportError:
9    import md5 as hashlib
10
11from distutils.core import Distribution, Extension
12from distutils.command.build_ext import build_ext
13
14import Cython
15from ..Compiler.Main import Context, CompilationOptions, default_options
16
17from ..Compiler.ParseTreeTransforms import (CythonTransform,
18        SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform)
19from ..Compiler.TreeFragment import parse_from_strings
20from ..Compiler.StringEncoding import _unicode
21from .Dependencies import strip_string_literals, cythonize, cached_function
22from ..Compiler import Pipeline, Nodes
23from ..Utils import get_cython_cache_dir
24import cython as cython_module
25
IS_PY3 = sys.version_info >= (3, 0)

# to_unicode(): turn user-supplied ASCII byte strings into unicode text.
# Python 3 callers already pass text, so the value is returned untouched;
# on Python 2 a ``bytes`` (``str``) value is decoded as ASCII.
if IS_PY3:
    def to_unicode(x):
        return x
else:
    def to_unicode(s):
        return s.decode('ascii') if isinstance(s, bytes) else s
37
38
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree transform that collects names not resolvable in their scope.

    Calling an instance on a tree walks it and returns the set of
    ``NameNode`` names for which ``current_env().lookup()`` is falsy.
    """

    def __init__(self):
        # Calls CythonTransform's initialiser directly (not via super()),
        # passing None as the compilation context.
        CythonTransform.__init__(self, None)
        self.unbound = set()

    def visit_NameNode(self, node):
        name = node.name
        if self.current_env().lookup(name):
            return node
        self.unbound.add(name)
        return node

    def __call__(self, node):
        # Run the normal transform machinery only for its side effect of
        # populating self.unbound; the transformed tree is discarded.
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
50
51
@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of names that *code* uses but does not define.

    The code is parsed as a tree fragment and pushed through the 'pyx'
    pipeline up to and including AnalyseDeclarationsTransform (None phases
    are skipped), then UnboundSymbols collects unresolved names.  Names of
    Python builtins are removed from the result.  The function is wrapped in
    ``cached_function`` (presumably memoised on its arguments).

    :param code: Cython source text (ASCII bytes are decoded on Python 2).
    :param context: compilation Context; a fresh default one when None.
    """
    code = to_unicode(code)
    if context is None:
        context = Context([], default_options)
    tree = parse_from_strings('(tree fragment)', code)
    for stage in Pipeline.create_pipeline(context, 'pyx'):
        if stage is not None:
            tree = stage(tree)
            # Declarations are analysed at this point; nothing later is needed.
            if isinstance(stage, AnalyseDeclarationsTransform):
                break
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    builtin_names = set(dir(builtins))
    return tuple(UnboundSymbols()(tree) - builtin_names)
70
71
def unsafe_type(arg, context=None):
    """Like safe_type(), except that exact ``int`` values map to C 'long'.

    Presumably "unsafe" because a Python int may not fit in a C long.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'
78
79
def safe_type(arg, context=None):
    """Return a Cython type declaration string describing *arg*.

    Exact list/tuple/dict/str map to their own names; complex, float and
    bool map to 'double complex', 'double' and 'bint'; a NumPy ndarray (only
    if numpy is already imported) maps to a typed ``numpy.ndarray`` buffer.
    Any other value maps to ``module.Class`` for the first class in its MRO
    that *context* resolves as a Cython type, or to 'object' if a builtin
    class is reached first or nothing resolves.

    :param arg: runtime value whose type should be described.
    :param context: Cython compilation Context used to look up extension
        types.  When None, no lookup is attempted and 'object' is returned
        for non-builtin types (instead of raising AttributeError).
    :returns: a Cython type string.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    elif py_type is complex:
        return 'double complex'
    elif py_type is float:
        return 'double'
    elif py_type is bool:
        return 'bint'
    elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)

    if context is None:
        # Nothing to look extension types up in; fall back to a generic object.
        return 'object'
    for base_type in py_type.__mro__:
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # Guard against a missing entry (None) before reading is_type.
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'
102
103
def _get_build_extension():
    """Create and finalize a distutils ``build_ext`` command object.

    The distutils configuration files are located and parsed first so that
    settings from them are honoured by the returned command.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
113
114
@cached_function
def _create_context(cython_include_dirs):
    """Build a Cython compilation Context searching *cython_include_dirs*.

    Callers pass a tuple; presumably it must be hashable because the
    function is wrapped in ``cached_function``.
    """
    include_path = list(cython_include_dirs)
    return Context(include_path, default_options)


# In-process cache used by cython_inline: source code -> unbound symbols,
# and tuple keys built from the code and argument signatures -> compiled
# ``__invoke`` functions.
_cython_inline_cache = dict()
# Context used by cython_inline when no cython_include_dirs are given.
_cython_inline_default_context = _create_context(('.',))
122
123def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
124    for symbol in unbound_symbols:
125        if symbol not in kwds:
126            if locals is None or globals is None:
127                calling_frame = inspect.currentframe().f_back.f_back.f_back
128                if locals is None:
129                    locals = calling_frame.f_locals
130                if globals is None:
131                    globals = calling_frame.f_globals
132            if symbol in locals:
133                kwds[symbol] = locals[symbol]
134            elif symbol in globals:
135                kwds[symbol] = globals[symbol]
136            else:
137                print("Couldn't find %r" % symbol)
138
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile *code* as Cython (once per signature) and run it.

    The snippet becomes the body of a generated ``__invoke`` function.  Its
    parameters are the keyword arguments in *kwds* plus any unbound names
    found in *code*, which are resolved from *locals*/*globals*.  Those
    default to the namespaces of the frame two levels above cython_inline
    (presumably user code calling through a wrapper such as Cython's
    ``inline`` - confirm).  Each parameter is typed via
    ``get_type(value, context)``.  The compiled module is cached on disk in
    *lib_dir* and in memory for the current process.

    :param code: Cython source of the snippet.
    :param get_type: callable ``(value, context) -> Cython type string``;
        None types every argument as 'object'.
    :param lib_dir: directory receiving the generated .pyx and extension.
    :param cython_include_dirs: extra entries for the Cython include path.
    :param cython_compiler_directives: mapping of Cython compiler directives.
    :param force: rebuild the extension even if it already exists on disk.
    :param quiet: suppress the parse-failure message; forwarded to cythonize.
    :param locals: namespace searched first for unbound names.
    :param globals: namespace searched second for unbound names.
    :param language_level: value for the ``language_level`` directive.
    :returns: whatever the generated ``__invoke`` returns - ``locals()`` of
        the compiled function unless the snippet returns earlier.
    """
    if get_type is None:
        # Must accept the (value, context) pair every get_type is called with.
        get_type = lambda x, context=None: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    # language_level and compiler directives change the compiled module, so
    # both take part in the in-memory and on-disk cache keys.
    user_directives = dict(cython_compiler_directives or {})
    # None when no directives are given, so the on-disk module names of
    # directive-free calls remain exactly what they were before.
    directives_key = repr(sorted(user_directives.items())) if user_directives else None

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        # The slow path turns ``cython`` module arguments into cimports and
        # does not pass them to __invoke; mirror that so the cache key matches.
        args = sorted((name, value) for name, value in kwds.items()
                      if value is not cython_module)
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        invoke = _cython_inline_cache.get((code, arg_sigs, language_level, directives_key))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    # Two levels up from this frame: the caller of whoever called us.
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    cython_compiler_directives = dict(user_directives)
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Arguments bound to the ``cython`` module become cimports instead.
    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
    if directives_key is not None:
        key += (directives_key,)
    module_name = "_cython_inline_" + hashlib.md5(_unicode(key).encode('utf-8')).hexdigest()

    if module_name in sys.modules:
        module = sys.modules[module_name]

    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')

        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if force or not os.path.isfile(module_path):
            import io
            cflags = []
            c_include_dirs = []
            # Qualified type names ("pkg.mod.Type") need their module cimported.
            qualified = re.compile(r'([.\w]+)[.]')
            for arg_type, _ in arg_sigs:
                m = qualified.match(arg_type)
                if m:
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
            """ % {'cimports': '\n'.join(cimports),
                   'module_body': module_body,
                   'params': params,
                   'func_body': func_body }
            # Put back the string literals removed by strip_string_literals.
            for placeholder, literal in literals.items():
                module_code = module_code.replace(placeholder, literal)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Write UTF-8 (Cython's default source encoding) rather than the
            # locale encoding, so non-ASCII snippets round-trip correctly.
            with io.open(pyx_file, 'w', encoding='utf-8') as fh:
                fh.write(_unicode(module_code))
            extension = Extension(
                name = module_name,
                sources = [pyx_file],
                include_dirs = c_include_dirs,
                extra_compile_args = cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib  = lib_dir
            build_extension.run()

        module = imp.load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, language_level, directives_key] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)

# Cached extension-module filename suffix used by cython_inline above.  The
# None placeholder is replaced with the real value on the first invocation.
cython_inline.so_ext = None
261
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the smallest leading-space indentation shared by code lines.

    Blank lines and comment lines are ignored when computing the common
    indent and are kept verbatim, so a comment indented less than the code
    is never sliced into.  Only spaces count as indentation.  The lines are
    re-joined with newline characters.

    :param code: source text, possibly uniformly indented.
    :returns: the dedented text.
    """
    lines = code.splitlines()
    min_indent = None
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or indent < min_indent:
            min_indent = indent
    if min_indent is None:
        # Only blank/comment lines: nothing to strip.
        return '\n'.join(lines)
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Use this line's own indent; the previous code reused a stale value.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
283
284
# Column-0 statements that must live at module level (they are not allowed
# inside a function).  Word boundaries keep identifiers such as ``cimports``
# from being mistaken for the ``cimport`` keyword.
module_statement = re.compile(r'^((cdef +(extern|class)\b)|cimport\b|(from .+ cimport\b)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split inline *code* into a module-level part and a function body.

    Tabs are replaced by single spaces.  Every unindented line starts a new
    statement: it goes to the module part if it matches ``module_statement``,
    otherwise to the function part.  Indented lines follow the preceding
    statement.

    :param code: dedented Cython source.
    :returns: ``(module_code, function_code)``, where every line of
        function_code is prefixed with four spaces.
    """
    module = []
    function = []
    current = function
    code = code.replace('\t', ' ')
    for line in code.split('\n'):
        if not line.startswith(' '):
            current = module if module_statement.match(line) else function
        current.append(line)
    return '\n'.join(module), '    ' + '\n    '.join(function)
300
301
try:
    from inspect import getcallargs
except ImportError:
    # Fallback for Python versions whose inspect module lacks getcallargs().
    def getcallargs(func, *arg_values, **kwd_values):
        """Map the given call arguments onto *func*'s parameter names.

        Minimal stand-in for ``inspect.getcallargs``.  Raises TypeError for a
        keyword that duplicates a positional argument, for unexpected
        keywords when *func* takes no ``**kwargs``, and for missing required
        parameters.  Note: *kwd_values* is consumed as names are bound.
        """
        bound = {}
        arg_names, varargs_name, varkw_name, defaults = inspect.getargspec(func)
        if varargs_name is not None:
            bound[varargs_name] = arg_values[len(arg_names):]
        bound.update(zip(arg_names, arg_values))
        for name in list(kwd_values):
            if name not in arg_names:
                continue
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        if varkw_name is not None:
            bound[varkw_name] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        defaults = defaults or ()
        first_default = len(arg_names) - len(defaults)
        for ix, name in enumerate(arg_names):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound
331
332
def get_body(source):
    """Return the body of a function given its source text.

    For a ``def`` (possibly preceded by decorators), everything after the
    first ':' is returned unchanged.  For a ``lambda`` expression, the
    expression after the ':' is returned prefixed with ``return`` so that it
    can serve as a statement body.

    NOTE(review): the first ':' is assumed to end the header; signatures
    that contain ':' themselves (annotations, dict/slice defaults) are not
    handled.
    """
    ix = source.index(':')
    # Compare the full six-character keyword ('lambda' does not fit in [:5]).
    if source.lstrip().startswith('lambda'):
        return "return %s" % source[ix+1:]
    return source[ix+1:]
339
340
341# Lots to be done here... It would be especially cool if compiled functions
342# could invoke each other quickly.
class RuntimeCompiledFunction(object):
    """Callable that runs a Python function's body through cython_inline.

    The body is extracted once, from the function's source, at construction
    time.  Each call binds its arguments with ``getcallargs`` and passes
    them as keyword arguments.  The function's globals are used as both the
    local and the global namespace for resolving unbound names.
    """

    def __init__(self, f):
        self._f = f
        source = inspect.getsource(f)
        self._body = get_body(source)

    def __call__(self, *args, **kwds):
        bound = getcallargs(self._f, *args, **kwds)
        func = self._f
        namespace = func.__globals__ if IS_PY3 else func.func_globals
        return cython_inline(self._body, locals=namespace, globals=namespace, **bound)
355