1#
2#   Cython -- Things that don't belong
3#            anywhere else in particular
4#
5
6from __future__ import absolute_import
7
8try:
9    from __builtin__ import basestring
10except ImportError:
11    basestring = str
12
13try:
14    FileNotFoundError
15except NameError:
16    FileNotFoundError = OSError
17
18import os
19import sys
20import re
21import io
22import codecs
23import shutil
24from contextlib import contextmanager
25
# Module-level alias for os.path.getmtime; the file timestamp helpers in
# this module (e.g. file_newer_than) look up modification times through it.
modification_time = os.path.getmtime
27
# Result caches of all functions wrapped by cached_function(), kept here so
# that clear_function_caches() can reset every one of them in a single call.
_function_caches = []


def clear_function_caches():
    """Empty the result cache of every function wrapped by cached_function()."""
    for cache in _function_caches:
        cache.clear()


def cached_function(f):
    """
    Decorator that memoises the results of 'f' by its positional arguments.

    The arguments must be hashable since the argument tuple is used as the
    cache key.  Keyword arguments are not supported by the wrapper.
    The undecorated function stays reachable as 'wrapper.uncached', and the
    cache is registered so that clear_function_caches() can reset it.
    """
    from functools import wraps

    cache = {}
    _function_caches.append(cache)
    # Private sentinel, so that results like None or False can be cached, too.
    uncomputed = object()

    @wraps(f)  # keep __name__/__doc__ of the wrapped function for introspection
    def wrapper(*args):
        res = cache.get(args, uncomputed)
        if res is uncomputed:
            res = cache[args] = f(*args)
        return res

    wrapper.uncached = f
    return wrapper
44
def cached_method(f):
    """
    Decorator that memoises the results of a method per instance.

    The cache is a dict stored on the instance itself under the attribute
    name '__<method name>_cache' and is keyed by the tuple of positional
    arguments (which therefore must be hashable).  Keyword arguments are
    not supported by the wrapper.
    """
    from functools import wraps

    cache_name = '__%s_cache' % f.__name__

    @wraps(f)  # keep __name__/__doc__ of the wrapped method for introspection
    def wrapper(self, *args):
        cache = getattr(self, cache_name, None)
        if cache is None:
            # First cached call on this instance: create its private cache.
            cache = {}
            setattr(self, cache_name, cache)
        if args in cache:
            return cache[args]
        res = cache[args] = f(self, *args)
        return res

    return wrapper
57
def replace_suffix(path, newsuf):
    """
    Return 'path' with its file extension (as determined by
    os.path.splitext) replaced by 'newsuf'.
    """
    stem = os.path.splitext(path)[0]
    return stem + newsuf
61
62
def open_new_file(path):
    """
    Open a freshly created file at 'path' for writing and return it.

    Any existing directory entry at 'path' is removed first instead of being
    truncated, so that other hard links to the previous file keep their
    content.  The returned file object encodes written text as ISO-8859-1.

    Raises OSError/IOError if the old entry cannot be removed or the new
    file cannot be created.
    """
    import errno

    # Make sure to create a new file here so we can safely hard link the
    # output files.  Unlinking unconditionally (instead of testing for
    # existence first) avoids a check-then-act race and also replaces
    # dangling symlinks, which os.path.exists() reports as missing.
    try:
        os.unlink(path)
    except OSError as exc:
        if exc.errno != errno.ENOENT:
            raise

    # we use the ISO-8859-1 encoding here because we only write pure
    # ASCII strings or (e.g. for file names) byte encoded strings as
    # Unicode, so we need a direct mapping from the first 256 Unicode
    # characters to a byte sequence, which ISO-8859-1 provides

    # note: can't use io.open() in Py2 as we may be writing str objects
    return codecs.open(path, "w", encoding="ISO-8859-1")
76
77
def castrate_file(path, st):
    """
    Remove junk contents from an output file after a failed compilation.

    The file at 'path' is replaced by a new file that contains only a C
    '#error' line.  If 'st' (a stat result) is given, the access time is
    then reset to st.st_atime and the modification time to one second
    before st.st_mtime.

    This is best effort: if the file cannot be (re-)created, nothing is
    done and no error is raised.
    """
    try:
        f = open_new_file(path)
    except EnvironmentError:
        return

    try:
        f.write(
            "#error Do not use this file, it is the result of a failed Cython compilation.\n")
    finally:
        # Close the file even if writing fails, to not leak the handle.
        f.close()
    if st:
        os.utime(path, (st.st_atime, st.st_mtime-1))
93
def file_newer_than(path, time):
    """
    Tell whether the modification time of the file at 'path' is
    strictly greater than the timestamp 'time'.
    """
    return modification_time(path) > time
97
98
def safe_makedirs(path):
    """
    Create the directory 'path' including missing parent directories.

    An OSError from os.makedirs() is ignored if 'path' is a directory
    afterwards (e.g. because it already existed) and re-raised otherwise.
    """
    try:
        os.makedirs(path)
    except OSError:
        if os.path.isdir(path):
            return
        raise
105
106
def copy_file_to_dir_if_newer(sourcefile, destdir):
    """
    Copy the file 'sourcefile' into the directory 'destdir', keeping its
    base name and preserving metadata (shutil.copy2).

    The directory is created if the destination file cannot be found.
    Nothing is copied if the destination file already exists and the
    source file is not newer than it.
    """
    target = os.path.join(destdir, os.path.basename(sourcefile))
    try:
        target_mtime = modification_time(target)
    except OSError:
        target_found = False
    else:
        target_found = True

    if not target_found:
        # Target file does not exist, destdir may or may not exist.
        safe_makedirs(destdir)
    elif not file_newer_than(sourcefile, target_mtime):
        # Target file exists and is up to date.
        return
    shutil.copy2(sourcefile, target)
124
125
@cached_function
def find_root_package_dir(file_path):
    """
    Walk up from 'file_path' through its parent directories for as long as
    they are package directories (see is_package_dir) and return the first
    directory that is not, or the path itself once os.path.dirname() no
    longer changes it.
    """
    parent = os.path.dirname(file_path)
    if parent != file_path and is_package_dir(parent):
        return find_root_package_dir(parent)
    return parent
135
@cached_function
def check_package_dir(dir, package_names):
    """
    Descend from 'dir' through the nested sub-directories named by
    'package_names' and return the resulting path, or None as soon as one
    of the directories on the way is not a package directory.
    """
    package_path = dir
    for name in package_names:
        package_path = os.path.join(package_path, name)
        if not is_package_dir(package_path):
            return None
    return package_path
143
@cached_function
def is_package_dir(dir_path):
    """
    Return 1 if 'dir_path' contains one of the files '__init__.py',
    '__init__.pyc', '__init__.pyx' or '__init__.pxd' (checked in that
    order using path_exists), and None otherwise.
    """
    init_files = ("__init__.py", "__init__.pyc", "__init__.pyx", "__init__.pxd")
    if any(path_exists(os.path.join(dir_path, name)) for name in init_files):
        return 1
    return None
153
@cached_function
def path_exists(path):
    """
    Return True if 'path' exists on the file system or can be read through
    the module's PEP 302 '__loader__' (if there is one), False otherwise.
    """
    # try on the filesystem first
    if os.path.exists(path):
        return True

    # figure out if a PEP 302 loader is around
    try:
        loader = __loader__
        # XXX the code below assumes a 'zipimport.zipimporter' instance
        # XXX should be easy to generalize, but too lazy right now to write it
        archive_path = getattr(loader, 'archive', None)
        if archive_path:
            normalised = os.path.normpath(path)
            if normalised.startswith(archive_path):
                # Strip the archive path plus the path separator after it.
                member_name = normalised[len(archive_path)+1:]
                try:
                    loader.get_data(member_name)
                except IOError:
                    return False
                else:
                    return True
    except NameError:
        # No '__loader__' defined for this module.
        pass
    return False
177
178# file name encodings
179
def decode_filename(filename):
    """
    Decode a bytes file name using the file system encoding (or the default
    encoding if no file system encoding is set).

    Anything that is not a bytes object is returned unchanged, as are bytes
    that fail to decode.
    """
    if not isinstance(filename, bytes):
        return filename
    try:
        fs_encoding = sys.getfilesystemencoding()
        if fs_encoding is None:
            fs_encoding = sys.getdefaultencoding()
        return filename.decode(fs_encoding)
    except UnicodeDecodeError:
        return filename
190
191# support for source file encoding detection
192
# Finds an encoding declaration like "coding: latin-1" or "coding=utf8" in a
# bytes line.  Group 1 is the word that ends with "coding" (which lets callers
# tell e.g. the "c_string_encoding" directive apart), group 2 the encoding name.
_match_file_encoding = re.compile(br"(\w*coding)[:=]\s*([-\w.]+)").search


def detect_opened_file_encoding(f):
    """
    Detect the source encoding declared in the first two lines of the
    binary file object 'f' (PEPs 263 and 3120).

    Data is read from the current position of 'f', which is not reset
    afterwards.  Returns the declared encoding name as a string, or "UTF-8"
    if neither of the first two lines declares one.  The Cython directive
    'c_string_encoding' is not a source encoding declaration and is ignored
    on both lines.
    """
    # Most of the time the first two lines fall in the first couple of hundred chars,
    # and this bulk read/split is much faster.
    lines = ()
    start = b''
    while len(lines) < 3:
        data = f.read(500)
        start += data
        lines = start.split(b"\n")
        if not data:
            break

    for line in lines[:2]:
        m = _match_file_encoding(line)
        if m and m.group(1) != b'c_string_encoding':
            return m.group(2).decode('iso8859-1')
    return "UTF-8"
216
217
def skip_bom(f):
    """
    Consume a byte order mark at the start of the text stream 'f', if any.

    One character is read; unless it is U+FEFF, the stream is rewound to
    position 0.  Doing this here is *substantially* easier than handling
    the BOM in the scanner.
    """
    first_char = f.read(1)
    if first_char == u'\uFEFF':
        return
    f.seek(0)
226
227
def open_source_file(source_filename, encoding=None, error_handling=None):
    """
    Open a source file as a text stream, positioned after a leading BOM.

    If 'encoding' is None, it is detected from the start of the file (see
    detect_opened_file_encoding).  'error_handling' is passed on as the
    'errors' argument of the text stream.

    If opening fails with an OSError and the file does not exist on the
    file system, the module's '__loader__' (if any) is tried for files
    that live inside its archive.

    Raises FileNotFoundError if the file cannot be found either way, and
    re-raises the OSError if the file exists but could not be opened.
    """
    stream = None
    try:
        if encoding is None:
            # Most of the time the encoding is not specified, so try hard to open the file only once.
            f = io.open(source_filename, 'rb')
            try:
                encoding = detect_opened_file_encoding(f)
                f.seek(0)
                stream = io.TextIOWrapper(f, encoding=encoding, errors=error_handling)
            except BaseException:
                # Don't leak the raw file if detection or wrapping fails,
                # e.g. on an unknown encoding name declared in the file.
                f.close()
                raise
        else:
            stream = io.open(source_filename, encoding=encoding, errors=error_handling)

    except OSError:
        if os.path.exists(source_filename):
            raise  # File is there, but something went wrong reading from it.
        # Allow source files to be in zip files etc.
        try:
            loader = __loader__
            if source_filename.startswith(loader.archive):
                stream = open_source_from_loader(
                    loader, source_filename,
                    encoding, error_handling)
        except (NameError, AttributeError):
            # No loader available, or not one that has an 'archive'.
            pass

    if stream is None:
        raise FileNotFoundError(source_filename)
    try:
        skip_bom(stream)
    except BaseException:
        # Reading the first character can fail (e.g. on a decoding error);
        # the caller never gets the stream then, so close it here.
        stream.close()
        raise
    return stream
257
258
def open_source_from_loader(loader,
                            source_filename,
                            encoding=None, error_handling=None):
    """
    Read 'source_filename' through 'loader' and return it as a text stream.

    The name passed to loader.get_data() is the normalised file path with
    the leading loader.archive path (plus one separator character) cut off.
    'encoding' and 'error_handling' configure the decoding of the data.
    """
    normalised = os.path.normpath(source_filename)
    member_name = normalised[len(loader.archive)+1:]
    raw = io.BytesIO(loader.get_data(member_name))
    return io.TextIOWrapper(raw, encoding=encoding, errors=error_handling)
268
269
def str_to_number(value):
    """
    Convert the text of an integer literal into an integer.

    Understands an optional leading "-", the '0x'/'0o'/'0b' prefixes in
    either case, Py2 style octal literals with a plain leading '0', and
    otherwise whatever int(value, 0) accepts.
    """
    # note: this expects a string as input that was accepted by the
    # parser already, with an optional "-" sign in front
    negative = value[:1] == '-'
    if negative:
        value = value[1:]

    if len(value) >= 2 and value[0] == '0':
        # Map the character after the leading zero to the number base.
        prefix_bases = {
            'x': 16, 'X': 16,  # hex notation ('0x1AF')
            'o': 8, 'O': 8,    # Py3 octal notation ('0o136')
            'b': 2, 'B': 2,    # Py3 binary notation ('0b101')
        }
        base = prefix_bases.get(value[1])
        if base is None:
            # Py2 octal notation ('0136')
            number = int(value, 8)
        else:
            number = int(value[2:], base)
    else:
        number = int(value, 0)
    return -number if negative else number
296
297
def long_literal(value):
    """
    Tell whether 'value' lies outside of the signed 32 bit integer range.
    String values are converted with str_to_number() first.
    """
    if isinstance(value, basestring):
        value = str_to_number(value)
    fits_in_int32 = -2**31 <= value < 2**31
    return not fits_in_int32
302
303
@cached_function
def get_cython_cache_dir():
    r"""
    Return the base directory for Cython's caches.

    Candidates, in order of priority:

    1. the value of the CYTHON_CACHE_DIR environment variable, if set
    2. on posix systems, the 'cython' directory inside ~/Library/Caches
       (OS X) or inside the directory named by XDG_CACHE_HOME (otherwise),
       provided that this parent directory exists
    3. ~/.cython

    """
    explicit_dir = os.environ.get('CYTHON_CACHE_DIR')
    if explicit_dir is not None:
        return explicit_dir

    cache_parent = None
    if os.name == 'posix':
        # this could fallback on ~/.cache when XDG_CACHE_HOME is not set
        cache_parent = (
            os.path.expanduser('~/Library/Caches') if sys.platform == 'darwin'
            else os.environ.get('XDG_CACHE_HOME'))

    if cache_parent and os.path.isdir(cache_parent):
        return os.path.join(cache_parent, 'cython')

    # last fallback: ~/.cython
    return os.path.expanduser(os.path.join('~', '.cython'))
333
334
@contextmanager
def captured_fd(stream=2, encoding=None):
    """
    Context manager that captures everything written to the file
    descriptor 'stream' (default: 2) while the block is executing.

    The descriptor is redirected into a pipe that is drained by a helper
    thread.  The value bound by 'with ... as' is a function that returns
    the data collected so far, as bytes or, if 'encoding' is given, decoded
    to text.  The output is only known to be complete after the block has
    been left.  On exit, the original target of 'stream' is restored.
    """
    from threading import Thread

    chunks = []

    def copy():
        # Runs in the helper thread: drain the pipe until all of its write
        # ends are closed, then release the read end.
        try:
            while True:
                chunk = os.read(pipe_in, 1000)
                if not chunk:
                    break
                chunks.append(chunk)
        finally:
            os.close(pipe_in)

    def get_output():
        output = b''.join(chunks)
        if encoding:
            output = output.decode(encoding)
        return output

    orig_stream = os.dup(stream)  # keep copy of original stream
    try:
        pipe_in, pipe_out = os.pipe()
        try:
            os.dup2(pipe_out, stream)  # replace stream by copy of pipe
        except BaseException:
            # Nothing was redirected, so just give both pipe ends back.
            os.close(pipe_in)
            os.close(pipe_out)
            raise

        reader = None
        reader_started = False
        try:
            os.close(pipe_out)  # close original pipe-out stream
            reader = Thread(target=copy)
            reader.daemon = True  # just in case
            reader.start()
            reader_started = True
            yield get_output
        finally:
            # Restoring the stream closes the last write end of the pipe,
            # which lets the reader thread see EOF and terminate.
            os.dup2(orig_stream, stream)
            if reader_started:
                reader.join()
            else:
                # The thread never ran, so the read end is still ours to close.
                os.close(pipe_in)
    finally:
        os.close(orig_stream)
374
375
def print_bytes(s, header_text=None, end=b'\n', file=None, flush=True):
    """
    Write the bytes 's' to 'file', optionally preceded by a text header.

    'header_text' is written as text to 'file' itself, while 's' and the
    bytes terminator 'end' go to the underlying binary buffer of 'file' if
    it has one (Py3 text streams), or to 'file' directly otherwise.
    'file' defaults to the sys.stdout that is current at call time, so that
    a replaced or redirected sys.stdout is honoured.  With 'flush' set, the
    output is flushed at the end.
    """
    if file is None:
        file = sys.stdout
    if header_text:
        file.write(header_text)  # note: text! => file.write() instead of out.write()
    # Push out buffered text before writing to the binary layer below it,
    # to keep text and bytes output in order.
    file.flush()
    try:
        out = file.buffer  # Py3
    except AttributeError:
        out = file         # Py2
    out.write(s)
    if end:
        out.write(end)
    if flush:
        out.flush()
389
class LazyStr:
    """
    String stand-in whose text is produced by calling 'callback' each time
    it is needed: on str(), repr() and on concatenation from either side.
    """

    def __init__(self, callback):
        self.callback = callback

    def __str__(self):
        return self.callback()

    # repr() gives the very same text as str().
    __repr__ = __str__

    def __add__(self, right):
        text = self.callback()
        return text + right

    def __radd__(self, left):
        text = self.callback()
        return left + text
401
402
class OrderedSet(object):
    """
    Minimal set that iterates over its elements in insertion order.

    Elements must be hashable.  Adding an element that is already present
    is a no-op and keeps its original position.
    """

    def __init__(self, elements=()):
        self._list = []     # elements in insertion order
        self._set = set()   # the same elements, for fast duplicate checks
        self.update(elements)

    def __iter__(self):
        return iter(self._list)

    def __contains__(self, e):
        # Constant time membership test instead of the linear scan that
        # Python falls back to when only __iter__ is defined.
        try:
            return e in self._set
        except TypeError:
            # Unhashable object: compare by equality, as the scan would.
            return e in self._list

    def update(self, elements):
        """Add all 'elements' in their iteration order."""
        for e in elements:
            self.add(e)

    def add(self, e):
        """Append 'e' unless it is already in the set."""
        if e not in self._set:
            self._list.append(e)
            self._set.add(e)
417
418
419# Class decorator that adds a metaclass and recreates the class with it.
420# Copied from 'six'.
def add_metaclass(metaclass):
    """
    Class decorator for creating a class with a metaclass.

    The decorated class is recreated by calling 'metaclass' with the name,
    the bases and a copy of the namespace of the original class.  Entries
    that the class creation generates itself (slot descriptors, '__dict__'
    and '__weakref__') are dropped from the namespace first.
    """
    def wrapper(cls):
        orig_vars = cls.__dict__.copy()
        slots = orig_vars.get('__slots__')
        if slots is not None:
            if isinstance(slots, str):
                slots = [slots]
            for slots_var in slots:
                orig_vars.pop(slots_var)
        orig_vars.pop('__dict__', None)
        orig_vars.pop('__weakref__', None)
        # '__qualname__' is not part of the class dict, so pass it along
        # explicitly to keep the qualified name of nested classes (Py3).
        if hasattr(cls, '__qualname__'):
            orig_vars['__qualname__'] = cls.__qualname__
        return metaclass(cls.__name__, cls.__bases__, orig_vars)
    return wrapper
434    return wrapper
435
436
def raise_error_if_module_name_forbidden(full_module_name):
    """
    Raise ValueError if the dotted module name is 'cython' or lies inside
    the 'cython' package.
    """
    # it is a bad idea to call the pyx-file cython.pyx, so fail early
    top_level_name = full_module_name.partition('.')[0]
    if top_level_name == 'cython':
        raise ValueError('cython is a special module, cannot be used as a module name')
441
442
def build_hex_version(version_string):
    """
    Parse and translate a version like '4.3a1' into the readable hex
    representation '0x040300A1' (like PY_VERSION_HEX).

    The result has two hex digits each for the major, minor and micro
    version, followed by one digit for the release level (A = alpha,
    B = beta, C = release candidate, F = final) and one for the serial.
    Trailing '.dev', '.pre' and '.post' segments are ignored.

    Raises ValueError for version parts that are not numbers.
    """
    # First, parse '4.12a1' into [4, 12, 0, 0xA1].
    release_levels = {'a': 0xA0, 'b': 0xB0, 'rc': 0xC0}
    digits = []
    release_status = 0xF0  # final release, unless a pre-release marker is found
    for segment in re.split(r'(\D+)', version_string):
        if segment in release_levels:
            release_status = release_levels[segment]
            digits = (digits + [0, 0])[:3]  # 1.2a1 -> 1.2.0a1
        elif segment in ('.dev', '.pre', '.post'):
            break  # these come last and have no place in the hex version
        elif segment != '.':
            digits.append(int(segment))
    digits = (digits + [0] * 3)[:4]
    digits[3] += release_status

    # Then, build a single hex value, two hex digits per version part.
    hexversion = 0
    for digit in digits:
        hexversion = (hexversion << 8) + digit

    return '0x%08X' % hexversion
465