1"""
2Cython -- Things that don't belong
3          anywhere else in particular
4"""
5
6from __future__ import absolute_import
7
# Compatibility shim: take 'basestring' from the '__builtin__' module where that
# module can be imported, otherwise fall back to plain 'str'.
try:
    from __builtin__ import basestring
except ImportError:
    basestring = str

# Compatibility shim: where 'FileNotFoundError' is not a known name,
# alias it to 'OSError' so that the rest of this module can always raise it.
try:
    FileNotFoundError
except NameError:
    FileNotFoundError = OSError
17
18import os
19import sys
20import re
21import io
22import codecs
23import glob
24import shutil
25import tempfile
26from contextlib import contextmanager
27
28from . import __version__ as cython_version
29
# File names whose presence marks a directory as a package (see contains_init()).
PACKAGE_FILES = ("__init__.py", "__init__.pyc", "__init__.pyx", "__init__.pxd")

# Naming scheme of the cache attribute that cached_method() stores on instances:
# "__<method name>_cache".  The pattern extracts the method name from such an attribute name.
_build_cache_name = "__{0}_cache".format
_CACHE_NAME_PATTERN = re.compile(r"^__(.+)_cache$")

# Alias: returns the modification time of the file at a given path.
modification_time = os.path.getmtime

# Every cache dict created by cached_function(); emptied by clear_function_caches().
_function_caches = []
38
39
def clear_function_caches():
    """Empty the result cache of every function decorated with cached_function()."""
    for function_cache in _function_caches:
        function_cache.clear()
43
44
def cached_function(f):
    """
    Decorator that memoises the results of 'f', keyed by its positional arguments.

    The arguments must be hashable; keyword arguments are not supported.
    The cache is registered in '_function_caches' so that clear_function_caches()
    can reset it.  The undecorated function stays available as 'wrapper.uncached'.
    """
    from functools import wraps

    cache = {}
    _function_caches.append(cache)
    # Unique sentinel, so that results like None or False can be cached as well.
    uncomputed = object()

    @wraps(f)  # keep name and docstring of the wrapped function
    def wrapper(*args):
        res = cache.get(args, uncomputed)
        if res is uncomputed:
            res = cache[args] = f(*args)
        return res

    wrapper.uncached = f
    return wrapper
58
59
def _find_cache_attributes(obj):
    """Yield '(cache attribute name, method name)' pairs for 'obj'.

    Every attribute whose name follows the cache naming scheme is reported,
    together with the method name encoded in it.  A method of that name
    does not necessarily exist on the object.
    """
    for name in dir(obj):
        found = _CACHE_NAME_PATTERN.match(name)
        if found is None:
            continue
        yield name, found.group(1)
69
70
def clear_method_caches(obj):
    """Delete all method caches from 'obj'.

    A cache attribute is only removed if the object also has an attribute
    named like the method that the cache name refers to.
    """
    for cache_attr, owner_name in _find_cache_attributes(obj):
        if not hasattr(obj, owner_name):
            # Without a corresponding method, we assume that this
            # attribute was not created by our cached method.
            continue
        delattr(obj, cache_attr)
80
81
def cached_method(f):
    """
    Decorator that memoises the results of the method 'f' per instance.

    Results are keyed by the (hashable) positional arguments and stored in a dict
    on the instance, in an attribute named after the method ("__<name>_cache").
    Keyword arguments are not supported.
    """
    from functools import wraps

    cache_name = _build_cache_name(f.__name__)

    @wraps(f)  # keep name and docstring of the wrapped method
    def wrapper(self, *args):
        cache = getattr(self, cache_name, None)
        if cache is None:
            # First call on this instance (or the cache was cleared): start a fresh cache.
            cache = {}
            setattr(self, cache_name, cache)
        if args in cache:
            return cache[args]
        res = cache[args] = f(self, *args)
        return res

    return wrapper
96
97
def replace_suffix(path, newsuf):
    """Return 'path' with its file extension (if any) replaced by 'newsuf'."""
    stem = os.path.splitext(path)[0]
    return stem + newsuf
101
102
def open_new_file(path):
    """
    Open 'path' for writing as a newly created file and return the file object.

    Any existing directory entry at 'path' is removed first instead of being
    truncated, so that other hard links to the old file are left untouched.
    """
    import errno

    # Make sure to create a new file here so we can
    # safely hard link the output files.
    # Unlinking unconditionally (instead of checking for existence first) avoids
    # a race with concurrent deletion and also replaces dangling symlinks.
    try:
        os.unlink(path)
    except OSError as exc:
        # Nothing to remove => fine.  Let open() below report unusable paths.
        if exc.errno not in (errno.ENOENT, errno.ENOTDIR):
            raise

    # we use the ISO-8859-1 encoding here because we only write pure
    # ASCII strings or (e.g. for file names) byte encoded strings as
    # Unicode, so we need a direct mapping from the first 256 Unicode
    # characters to a byte sequence, which ISO-8859-1 provides

    # note: can't use io.open() in Py2 as we may be writing str objects
    return codecs.open(path, "w", encoding="ISO-8859-1")
116
117
def castrate_file(path, st):
    """
    Remove junk contents from an output file after a failed compilation.

    The file is replaced by a single '#error' line.  Nothing is done if the file
    does not exist or does not look like a file written by Cython, and failing
    to (re-)create the file is silently ignored.

    If 'st' (a stat struct) is given, the access time is set back to 'st.st_atime'
    and the modification time to one second before 'st.st_mtime'.
    """
    if not is_cython_generated_file(path, allow_failed=True, if_not_found=False):
        return

    try:
        f = open_new_file(path)
    except EnvironmentError:
        # Best effort only: we cannot replace the file, so leave it alone.
        return

    # Close the file even if writing to it fails.
    with f:
        f.write(
            "#error Do not use this file, it is the result of a failed Cython compilation.\n")
    if st:
        os.utime(path, (st.st_atime, st.st_mtime-1))
136
137
def is_cython_generated_file(path, allow_failed=False, if_not_found=True):
    """
    Check whether the file at 'path' looks like an output file written by Cython.

    Returns 'if_not_found' if the file does not exist or cannot be read.
    Otherwise, returns true for files that start with the Cython header comment,
    for empty files, and - only if 'allow_failed' is set - for files that start
    with the marker left behind by a failed compilation.
    """
    failure_marker = b"#error Do not use this file, it is the result of a failed Cython compilation."

    head = None
    if os.path.exists(path):
        try:
            with open(path, "rb") as f:
                head = f.read(len(failure_marker))
        except (OSError, IOError):
            pass  # Probably just doesn't exist any more

    if head is None:
        # file does not exist (yet)
        return if_not_found

    # Cython C file?
    if head.startswith(b"/* Generated by Cython "):
        return True
    # Cython output file after previous failures?
    if allow_failed and head == failure_marker:
        return True
    # Let's allow overwriting empty files as well. They might have resulted from previous failures.
    return not head
160
161
def file_newer_than(path, time):
    """Return True if the file at 'path' was modified after the timestamp 'time'."""
    return modification_time(path) > time
165
166
def safe_makedirs(path):
    """Create the directory 'path' (including parents), accepting that it already exists.

    An OSError from the creation is only re-raised if 'path' is not a directory afterwards.
    """
    try:
        os.makedirs(path)
    except OSError:
        if os.path.isdir(path):
            return
        raise
173
174
def copy_file_to_dir_if_newer(sourcefile, destdir):
    """
    Copy the file 'sourcefile' into the directory 'destdir', keeping its metadata.

    The directory is created if needed.  Nothing is copied if the destination
    file already exists and is at least as recent as the source file.
    """
    destfile = os.path.join(destdir, os.path.basename(sourcefile))
    try:
        desttime = modification_time(destfile)
    except OSError:
        # No destination file yet; the directory itself may or may not exist.
        safe_makedirs(destdir)
        up_to_date = False
    else:
        # Destination file exists: only copy if the source is more recent.
        up_to_date = not file_newer_than(sourcefile, desttime)

    if not up_to_date:
        shutil.copy2(sourcefile, destfile)
192
193
@cached_function
def find_root_package_dir(file_path):
    """Return the first directory above 'file_path' that is not a package directory.

    The search walks up the directory tree and stops at the filesystem root.
    """
    parent = os.path.dirname(file_path)
    if file_path != parent and is_package_dir(parent):
        # Still inside a package: continue one level further up.
        return find_root_package_dir(parent)
    return parent
203
204
@cached_function
def check_package_dir(dir_path, package_names):
    """Resolve the nested package directory below 'dir_path'.

    Returns a tuple '(path, namespace)', where 'path' is 'dir_path' joined with all
    'package_names' and 'namespace' is True only if none of the directories along
    the way contains a package init file.
    """
    is_namespace = True
    current_dir = dir_path
    for package_name in package_names:
        current_dir = os.path.join(current_dir, package_name)
        if contains_init(current_dir):
            is_namespace = False
    return current_dir, is_namespace
214
215
@cached_function
def contains_init(dir_path):
    """Return 1 if 'dir_path' contains one of the PACKAGE_FILES, None otherwise."""
    init_paths = (os.path.join(dir_path, init_name) for init_name in PACKAGE_FILES)
    for init_path in init_paths:
        if path_exists(init_path):
            return 1
    return None
222
223
def is_package_dir(dir_path):
    """Return 1 if 'dir_path' is a package directory (see contains_init()), None otherwise."""
    return 1 if contains_init(dir_path) else None
227
228
@cached_function
def path_exists(path):
    """Return True if 'path' exists on the filesystem or inside the archive of this module's loader."""
    # try on the filesystem first
    if os.path.exists(path):
        return True

    # figure out if a PEP 302 loader is around
    try:
        loader = __loader__
        # XXX the code below assumes a 'zipimport.zipimporter' instance
        # XXX should be easy to generalize, but too lazy right now to write it
        archive_path = getattr(loader, 'archive', None)
        if not archive_path:
            return False
        normpath = os.path.normpath(path)
        if not normpath.startswith(archive_path):
            return False
        # Strip the archive path and the following separator to get the archive member name.
        arcname = normpath[len(archive_path)+1:]
        try:
            loader.get_data(arcname)
        except IOError:
            return False
        return True
    except NameError:
        # No '__loader__' available in this module.
        return False
252
253
# Extracts the digits of the version tag from a path like "<name>.cython-<digits>.<ext>".
# Being a bound 'findall', it returns a list of the matched digit strings (empty if there is no tag).
_parse_file_version = re.compile(r".*[.]cython-([0-9]+)[.][^./\\]+$").findall
255
256
@cached_function
def find_versioned_file(directory, filename, suffix,
                        _current_version=int(re.sub(r"^([0-9]+)[.]([0-9]+).*", r"\1\2", cython_version))):
    """
    Search a directory for versioned pxd files, e.g. "lib.cython-30.pxd" for a Cython 3.0+ version.

    Of all versioned files, the one with the highest version that does not exceed
    the current Cython version is chosen.  The unversioned file "filename + suffix"
    is only used if no versioned file qualifies.

    @param directory: the directory to search
    @param filename: the filename without suffix
    @param suffix: the filename extension including the dot, e.g. ".pxd"
    @return: the file path if found, or None
    """
    assert not suffix or suffix[:1] == '.'
    path_prefix = os.path.join(directory, filename)

    candidates = glob.glob(path_prefix + ".cython-*" + suffix)

    # last resort, if we do not have versioned .pxd files
    unversioned_path = path_prefix + suffix
    best_version = -1
    best_path = unversioned_path if os.path.exists(unversioned_path) else None

    for candidate in candidates:
        versions = _parse_file_version(candidate)
        if not versions:
            continue
        candidate_version = int(versions[0])
        # Let's assume no duplicates.
        if best_version < candidate_version <= _current_version:
            best_version, best_path = candidate_version, candidate
    return best_path
285
286
287# file name encodings
288
def decode_filename(filename):
    """Decode a byte string file name using the filesystem encoding.

    Anything that is not a bytes object is returned as is, and so are
    byte strings that cannot be decoded.
    """
    if not isinstance(filename, bytes):
        return filename
    try:
        fs_encoding = sys.getfilesystemencoding()
        if fs_encoding is None:
            fs_encoding = sys.getdefaultencoding()
        return filename.decode(fs_encoding)
    except UnicodeDecodeError:
        return filename
299
300
301# support for source file encoding detection
302
# Finds an "...coding: <name>" / "...coding=<name>" declaration in a line of bytes.
# Group 1 is the complete word ending in "coding", group 2 is the declared encoding name.
_match_file_encoding = re.compile(br"(\w*coding)[:=]\s*([-\w.]+)").search


def detect_opened_file_encoding(f, default='UTF-8'):
    """
    Detect the source encoding declared in the first two lines of the binary file 'f'.

    Reads from the current position of 'f' and does not rewind it afterwards.
    Returns the declared encoding name as a string, or 'default' if neither
    of the first two lines declares one.  The Cython directive 'c_string_encoding'
    is not a source encoding declaration and is ignored on both lines.
    """
    # PEPs 263 and 3120
    # Most of the time the first two lines fall in the first couple of hundred chars,
    # and this bulk read/split is much faster.
    lines = ()
    start = b''
    while len(lines) < 3:
        data = f.read(500)
        start += data
        lines = start.split(b"\n")
        if not data:
            break

    for line in lines[:2]:
        m = _match_file_encoding(line)
        if m and m.group(1) != b'c_string_encoding':
            return m.group(2).decode('iso8859-1')
    return default
327
328
def skip_bom(f):
    """
    Advance the text stream 'f' past a leading BOM character, if there is one.

    Doing this here is *substantially* easier than teaching the scanner about it.
    """
    first_char = f.read(1)
    if first_char == u'\uFEFF':
        return
    # Not a BOM: go back to the start so that nothing gets lost.
    f.seek(0)
337
338
def open_source_file(source_filename, encoding=None, error_handling=None):
    """
    Open a source file for reading as text and return the stream, positioned after any BOM.

    If 'encoding' is None, it is detected from the file's encoding declaration.
    'error_handling' is passed on as the 'errors' argument of the text stream.
    If the file cannot be opened and does not exist on the filesystem, it is looked up
    in the archive of this module's loader instead.

    Raises FileNotFoundError if the file is found in neither place.  Errors from
    opening or reading an existing file are propagated.
    """
    stream = None
    try:
        if encoding is None:
            # Most of the time the encoding is not specified, so try hard to open the file only once.
            f = io.open(source_filename, 'rb')
            try:
                encoding = detect_opened_file_encoding(f)
                f.seek(0)
                stream = io.TextIOWrapper(f, encoding=encoding, errors=error_handling)
            except BaseException:
                # Don't leak the file handle, e.g. when the declared encoding is unknown.
                f.close()
                raise
        else:
            stream = io.open(source_filename, encoding=encoding, errors=error_handling)

    except (OSError, IOError):
        if os.path.exists(source_filename):
            raise  # File is there, but something went wrong reading from it.
        # Allow source files to be in zip files etc.
        try:
            loader = __loader__
            if source_filename.startswith(loader.archive):
                stream = open_source_from_loader(
                    loader, source_filename,
                    encoding, error_handling)
        except (NameError, AttributeError):
            # No loader available, or it does not have an archive.
            pass

    if stream is None:
        raise FileNotFoundError(source_filename)
    try:
        skip_bom(stream)
    except BaseException:
        # Reading the first character failed: close the stream before giving up.
        stream.close()
        raise
    return stream
368
369
def open_source_from_loader(loader,
                            source_filename,
                            encoding=None, error_handling=None):
    """Return a text stream over the data of 'source_filename' as read from the loader's archive.

    The archive member name is the normalised file path with the archive path
    and the following separator stripped off.
    """
    normalised_path = os.path.normpath(source_filename)
    member_start = len(loader.archive) + 1
    arcname = normalised_path[member_start:]
    raw_stream = io.BytesIO(loader.get_data(arcname))
    return io.TextIOWrapper(raw_stream, encoding=encoding, errors=error_handling)
379
380
def str_to_number(value):
    """Convert the text of an integer literal into an int.

    note: this expects a string as input that was accepted by the
    parser already, with an optional "-" sign in front
    """
    negative = value[:1] == '-'
    if negative:
        value = value[1:]

    if len(value) >= 2 and value[0] == '0':
        # The character after the leading zero selects the notation:
        # hex ('0x1AF'), Py3 octal ('0o136') or Py3 binary ('0b101').
        prefix_bases = {'x': 16, 'X': 16, 'o': 8, 'O': 8, 'b': 2, 'B': 2}
        base = prefix_bases.get(value[1])
        if base is None:
            # Py2 octal notation ('0136')
            number = int(value, 8)
        else:
            number = int(value[2:], base)
    else:
        number = int(value, 0)

    return -number if negative else number
407
408
def long_literal(value):
    """Return True if 'value' (a number or the text of an integer literal) lies outside the signed 32 bit range."""
    number = str_to_number(value) if isinstance(value, basestring) else value
    fits_in_int32 = -2**31 <= number < 2**31
    return not fits_in_int32
413
414
@cached_function
def get_cython_cache_dir():
    r"""
    Return the base directory containing Cython's caches.

    Priority:

    1. The value of the CYTHON_CACHE_DIR environment variable, if set
    2. (OS X): ~/Library/Caches/cython, if ~/Library/Caches is a directory
       (posix not OS X): XDG_CACHE_HOME/cython, if XDG_CACHE_HOME is set to a directory
    3. ~/.cython

    """
    try:
        return os.environ['CYTHON_CACHE_DIR']
    except KeyError:
        pass

    parent = None
    if os.name == 'posix':
        if sys.platform == 'darwin':
            parent = os.path.expanduser('~/Library/Caches')
        else:
            # this could fallback on ~/.cache
            parent = os.environ.get('XDG_CACHE_HOME')

    if not parent or not os.path.isdir(parent):
        # last fallback: ~/.cython
        return os.path.expanduser(os.path.join('~', '.cython'))
    return os.path.join(parent, 'cython')
444
445
@contextmanager
def captured_fd(stream=2, encoding=None):
    """
    Context manager that redirects the file descriptor 'stream' into a temporary file.

    Yields a function that returns everything written to the descriptor so far,
    as bytes, or decoded with 'encoding' if one is given.  The function can still
    be called after leaving the context and then returns the final output.
    On exit, the descriptor is pointed back to where it was before.
    """
    orig_stream = os.dup(stream)  # keep copy of original stream
    try:
        with tempfile.TemporaryFile(mode="a+b") as temp_file:
            # The mutable default argument is used on purpose: it stores the last
            # output that was read, so that it stays available once the temp file is closed.
            def read_output(_output=[b'']):
                if not temp_file.closed:
                    temp_file.seek(0)
                    _output[0] = temp_file.read()
                return _output[0]

            os.dup2(temp_file.fileno(), stream)  # make the stream's descriptor refer to the temp file
            try:
                def get_output():
                    result = read_output()
                    return result.decode(encoding) if encoding else result

                yield get_output
            finally:
                os.dup2(orig_stream, stream)  # restore original stream
                read_output()  # keep the output in case it's used after closing the context manager
    finally:
        # The duplicate of the original descriptor is not needed any more.
        os.close(orig_stream)
469
470
def get_encoding_candidates():
    """Return a duplicate-free list of encodings to try: the default encoding first,
    followed by the encodings of the current and the original stdout/stdin streams.
    """
    candidates = [sys.getdefaultencoding()]
    streams = (sys.stdout, sys.stdin, sys.__stdout__, sys.__stdin__)
    for stream_encoding in (getattr(stream, 'encoding', None) for stream in streams):
        # encoding might be None (e.g. somebody redirects stdout):
        if stream_encoding is None or stream_encoding in candidates:
            continue
        candidates.append(stream_encoding)
    return candidates
479
480
def prepare_captured(captured):
    """Strip and decode captured output bytes.

    Returns None if nothing but whitespace was captured.  Otherwise, the bytes are
    decoded with the first of the candidate encodings that succeeds.
    """
    stripped = captured.strip()
    if not stripped:
        return None

    for candidate in get_encoding_candidates():
        try:
            text = stripped.decode(candidate)
        except UnicodeDecodeError:
            continue
        return text

    # last resort: print at least the readable ascii parts correctly.
    return stripped.decode('latin-1')
492
493
def print_captured(captured, output, header_line=None):
    """Write the decoded captured output to 'output', preceded by 'header_line' if given.

    Nothing at all is written if there is no output to show.
    """
    text = prepare_captured(captured)
    if not text:
        return
    if header_line:
        output.write(header_line)
    output.write(text)
500
501
def print_bytes(s, header_text=None, end=b'\n', file=None, flush=True):
    """
    Write the byte string 's' to a text stream, followed by 'end' (bytes) unless that is empty.

    If 'header_text' (text, not bytes) is given, it is written first.
    'file' defaults to the current 'sys.stdout', which is looked up at call time
    so that a replaced or redirected 'sys.stdout' is respected.
    The stream is flushed at the end unless 'flush' is false.
    """
    if file is None:
        file = sys.stdout
    if header_text:
        file.write(header_text)  # note: text! => file.write() instead of out.write()
    # Flush pending text before writing bytes to the underlying buffer, to keep the order.
    file.flush()
    try:
        out = file.buffer  # Py3
    except AttributeError:
        out = file         # Py2
    out.write(s)
    if end:
        out.write(end)
    if flush:
        out.flush()
515
516
class LazyStr:
    """String stand-in that calls 'callback' to produce its value each time it is needed."""

    def __init__(self, callback):
        # 'callback' is a function without arguments that returns the string value.
        self.callback = callback

    def __str__(self):
        text = self.callback()
        return text

    def __repr__(self):
        text = self.callback()
        return text

    def __add__(self, right):
        text = self.callback()
        return text + right

    def __radd__(self, left):
        text = self.callback()
        return left + text
532
533
class OrderedSet(object):
    """Minimal set that iterates over its elements in the order in which they were first added."""

    def __init__(self, elements=()):
        # '_list' keeps the insertion order, '_set' provides the fast membership test.
        self._list = []
        self._set = set()
        self.update(elements)

    def __iter__(self):
        return iter(self._list)

    def update(self, elements):
        """Add all 'elements' in their iteration order."""
        for element in elements:
            self.add(element)

    def add(self, e):
        """Add 'e' at the end, unless it is already contained."""
        if e in self._set:
            return
        self._list.append(e)
        self._set.add(e)
551
552
553# Class decorator that adds a metaclass and recreates the class with it.
554# Copied from 'six'.
def add_metaclass(metaclass):
    """Class decorator for creating a class with a metaclass."""
    def wrapper(cls):
        namespace = dict(cls.__dict__)
        slots = namespace.get('__slots__')
        if slots is not None:
            # Drop the slot descriptors of the old class; the new class gets its own
            # ones from '__slots__'.
            slot_names = [slots] if isinstance(slots, str) else slots
            for slot_name in slot_names:
                namespace.pop(slot_name)
        # These are created per class and must not be carried over.
        for per_class_attr in ('__dict__', '__weakref__'):
            namespace.pop(per_class_attr, None)
        return metaclass(cls.__name__, cls.__bases__, namespace)
    return wrapper
569
570
def raise_error_if_module_name_forbidden(full_module_name):
    """Raise ValueError if the module is named 'cython' or lives in a top-level package of that name."""
    # it is bad idea to call the pyx-file cython.pyx, so fail early
    top_level_name = full_module_name.split('.', 1)[0]
    if top_level_name == 'cython':
        raise ValueError('cython is a special module, cannot be used as a module name')
575
576
def build_hex_version(version_string):
    """
    Parse and translate '4.3a1' into the readable hex representation '0x040300A1' (like PY_VERSION_HEX).

    Up to three numeric release parts are used, followed by an optional pre-release
    marker ('a', 'b' or 'rc') with its number.  Final releases get the status 0xF0.
    Trailing '.dev', '.pre' and '.post' segments are ignored.

    Returns the hex version as a string with 8 hex digits.
    Raises ValueError for version strings that contain anything else.
    """
    # First, parse '4.12a1' into [4, 12, 0, 0xA1].
    digits = []
    release_status = 0xF0
    for segment in re.split(r'(\D+)', version_string):
        if segment in ('a', 'b', 'rc'):
            release_status = {'a': 0xA0, 'b': 0xB0, 'rc': 0xC0}[segment]
            digits = (digits + [0, 0])[:3]  # 1.2a1 -> 1.2.0a1
        elif segment in ('.dev', '.pre', '.post'):
            # These can only come last and have no place in the hex version.
            break
        elif segment != '.':
            digits.append(int(segment))
    digits = (digits + [0] * 3)[:4]
    digits[3] += release_status

    # Then, build a single hex value, two hex digits per version part.
    hexversion = 0
    for digit in digits:
        hexversion = (hexversion << 8) + digit

    return '0x%08X' % hexversion
599