1# util/compat.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Handle Python version/platform incompatibilities."""
9
10import collections
11import contextlib
12import inspect
13import operator
14import platform
15import sys
16
17
# Interpreter version flags: each ``pyNN`` flag is True when the running
# interpreter is at least that version; ``py2k`` is True on any 2.x release.
py3k = sys.version_info >= (3, 0)
py2k = sys.version_info < (3, 0)
py32 = sys.version_info >= (3, 2)
py33 = sys.version_info >= (3, 3)
py35 = sys.version_info >= (3, 5)
py36 = sys.version_info >= (3, 6)
py265 = sys.version_info >= (2, 6, 5)

# Interpreter implementation flags.
jython = sys.platform.startswith("java")
pypy = hasattr(sys, "pypy_version_info")
cpython = not (pypy or jython)  # TODO: something better for this ?

# Operating system / hardware flags.
win32 = sys.platform.startswith("win")
osx = sys.platform.startswith("darwin")
# only matches machine names containing "aarch" (e.g. "aarch64")
arm = "aarch" in platform.machine().lower()
32
33
# Short module-level aliases for standard-library callables.
contextmanager, dottedgetter, namedtuple = (
    contextlib.contextmanager,
    operator.attrgetter,
    collections.namedtuple,
)
next = next  # noqa
38
# Same field layout as ``inspect.FullArgSpec`` so that results of
# inspect_getfullargspec() can be unpacked positionally.
FullArgSpec = collections.namedtuple(
    "FullArgSpec",
    "args varargs varkw defaults kwonlyargs kwonlydefaults annotations",
)
51
52try:
53    import threading
54except ImportError:
55    import dummy_threading as threading  # noqa
56
57
# work around http://bugs.python.org/issue2646
# On interpreters older than 2.6.5 keyword-argument names are coerced with
# ``str()``; on newer interpreters the name is returned unchanged.
safe_kwarg = (lambda arg: arg) if py265 else str  # noqa
63
64
def inspect_getfullargspec(func):
    """Fully vendored version of getfullargspec from Python 3.3.

    Returns a :class:`FullArgSpec` for a plain Python function or method.
    On Python 2, ``kwonlyargs`` is always empty, ``kwonlydefaults`` is
    ``None`` and ``annotations`` is an empty dict.

    :raises TypeError: if ``func`` is not a Python function (or method
     wrapping one), or its ``__code__`` is not a code object.

    """
    if inspect.ismethod(func):
        # unwrap the method to its underlying function
        func = func.__func__
    if not inspect.isfunction(func):
        raise TypeError("{!r} is not a Python function".format(func))

    code = func.__code__
    if not inspect.iscode(code):
        raise TypeError("{!r} is not a code object".format(code))

    varnames = code.co_varnames
    n_positional = code.co_argcount
    n_kwonly = code.co_kwonlyargcount if py3k else 0

    # the slicing below relies on co_varnames listing the positional
    # parameters first, then keyword-only ones, then the *args name and
    # finally the **kwargs name
    end_kwonly = n_positional + n_kwonly
    positional = list(varnames[:n_positional])
    kwonly = list(varnames[n_positional:end_kwonly])

    cursor = end_kwonly
    star_name = None
    if code.co_flags & inspect.CO_VARARGS:
        star_name = code.co_varnames[cursor]
        cursor += 1

    double_star_name = None
    if code.co_flags & inspect.CO_VARKEYWORDS:
        double_star_name = code.co_varnames[cursor]

    if py3k:
        kwdefaults = func.__kwdefaults__
        annotations = func.__annotations__
    else:
        kwdefaults = None
        annotations = {}

    return FullArgSpec(
        args=positional,
        varargs=star_name,
        varkw=double_star_name,
        defaults=func.__defaults__,
        kwonlyargs=kwonly,
        kwonlydefaults=kwdefaults,
        annotations=annotations,
    )
101
102
if py3k:
    import base64
    import builtins
    import configparser
    import itertools
    import pickle

    from functools import reduce
    from io import BytesIO as byte_buffer
    from io import StringIO
    from itertools import zip_longest
    from urllib.parse import (
        quote_plus,
        unquote_plus,
        parse_qsl,
        quote,
        unquote,
    )

    string_types = (str,)
    binary_types = (bytes,)
    binary_type = bytes
    text_type = str
    int_types = (int,)
    iterbytes = iter

    itertools_filterfalse = itertools.filterfalse
    itertools_filter = filter
    itertools_imap = map

    # getattr() is used because "exec" and "print" are keywords under
    # Python 2 syntax, where "builtins.exec" would not parse.
    exec_ = getattr(builtins, "exec")
    import_ = getattr(builtins, "__import__")
    print_ = getattr(builtins, "print")

    def b(s):
        """Return the string ``s`` encoded to bytes as latin-1."""
        return s.encode("latin-1")

    def b64decode(x):
        """Decode the base64 text ``x`` (ASCII str) into bytes."""
        return base64.b64decode(x.encode("ascii"))

    def b64encode(x):
        """Base64-encode the bytes ``x``, returning an ASCII str."""
        return base64.b64encode(x).decode("ascii")

    def decode_backslashreplace(text, encoding):
        """Decode bytes ``text``, backslash-escaping undecodable bytes."""
        return text.decode(encoding, errors="backslashreplace")

    def cmp(a, b):
        """Stand-in for the Python 2 ``cmp()`` builtin; returns -1, 0 or 1."""
        return (a > b) - (a < b)

    def raise_(
        exception, with_traceback=None, replace_context=None, from_=False
    ):
        r"""implement "raise" with cause support.

        :param exception: exception to raise
        :param with_traceback: will call exception.with_traceback()
        :param replace_context: an as-yet-unsupported feature.  This is
         an exception object which we are "replacing", e.g., it's our
         "cause" but we don't want it printed.    Basically just what
         ``__suppress_context__`` does but we don't want to suppress
         the enclosing context, if any.  So for now we make it the
         cause.
        :param from\_: the cause.  this actually sets the cause and doesn't
         hope to hide it someday.

        """
        if with_traceback is not None:
            exception = exception.with_traceback(with_traceback)

        if from_ is not False:
            exception.__cause__ = from_
        elif replace_context is not None:
            # no good solution here, we would like to have the exception
            # have only the context of replace_context.__context__ so that the
            # intermediary exception does not change, but we can't figure
            # that out.
            exception.__cause__ = replace_context

        try:
            raise exception
        finally:
            # credit to
            # https://cosmicpercolator.com/2016/01/13/exception-leaks-in-python-2-and-3/
            # as the __traceback__ object creates a cycle
            del exception, replace_context, from_, with_traceback

    from typing import TYPE_CHECKING

    def u(s):
        """Return ``s`` unchanged; str literals are already text on py3k."""
        return s

    def ue(s):
        """Return ``s`` unchanged; str literals are already text on py3k."""
        return s

    if py32:
        callable = callable  # noqa
    else:
        # the callable() builtin is absent on Python 3.0 / 3.1
        def callable(fn):  # noqa
            return hasattr(fn, "__call__")


else:
    import base64
    import ConfigParser as configparser  # noqa
    import itertools

    from StringIO import StringIO  # noqa
    from cStringIO import StringIO as byte_buffer  # noqa
    from itertools import izip_longest as zip_longest  # noqa
    from urllib import quote  # noqa
    from urllib import quote_plus  # noqa
    from urllib import unquote  # noqa
    from urllib import unquote_plus  # noqa
    from urlparse import parse_qsl  # noqa

    try:
        import cPickle as pickle
    except ImportError:
        import pickle  # noqa

    string_types = (basestring,)  # noqa
    binary_types = (bytes,)
    binary_type = str
    text_type = unicode  # noqa
    int_types = int, long  # noqa

    callable = callable  # noqa
    cmp = cmp  # noqa
    reduce = reduce  # noqa

    b64encode = base64.b64encode
    b64decode = base64.b64decode

    itertools_filterfalse = itertools.ifilterfalse
    itertools_filter = itertools.ifilter
    itertools_imap = itertools.imap

    def b(s):
        """Return ``s`` unchanged; a py2 str literal is already bytes."""
        return s

    def exec_(func_text, globals_, lcl=None):
        """Execute ``func_text`` in ``globals_`` (and optionally ``lcl``)."""
        # the py2 "exec ... in ..." statement is not valid Python 3 syntax,
        # so it is itself wrapped in a string to keep this module parseable
        if lcl is None:
            exec("exec func_text in globals_")
        else:
            exec("exec func_text in globals_, lcl")

    def iterbytes(buf):
        """Iterate over ``buf`` yielding integer byte values."""
        return (ord(byte) for byte in buf)

    def import_(*args):
        """``__import__`` wrapper that coerces the fromlist entries to str."""
        if len(args) == 4:
            args = args[0:3] + ([str(arg) for arg in args[3]],)
        return __import__(*args)

    def print_(*args, **kwargs):
        """Minimal stand-in for the Python 3 ``print()`` function.

        Honors the ``file``, ``sep``, ``end`` and ``flush`` keyword
        arguments with the same defaults as the Python 3 builtin (``None``
        for ``file``/``sep``/``end`` means the default).  Non-string
        arguments are converted with ``str()``.  If there is no stream to
        write to (``sys.stdout`` is ``None``), nothing is written.
        """
        fp = kwargs.pop("file", None)
        if fp is None:
            fp = sys.stdout
        if fp is None:
            return
        sep = kwargs.pop("sep", None)
        if sep is None:
            sep = " "
        end = kwargs.pop("end", None)
        if end is None:
            end = "\n"
        flush = kwargs.pop("flush", False)

        # iterate the values themselves; the index is only used to decide
        # where separators go
        for index, arg in enumerate(args):
            if index:
                fp.write(sep)
            if not isinstance(arg, basestring):  # noqa
                arg = str(arg)
            fp.write(arg)
        fp.write(end)
        if flush:
            fp.flush()

    def u(s):
        # this differs from what six does, which doesn't support non-ASCII
        # strings - we only use u() with
        # literal source strings, and all our source files with non-ascii
        # in them (all are tests) are utf-8 encoded.
        return unicode(s, "utf-8")  # noqa

    def ue(s):
        """Decode ``s`` as a unicode-escape encoded string literal."""
        return unicode(s, "unicode_escape")  # noqa

    def decode_backslashreplace(text, encoding):
        """Decode ``text``; on failure fall back to its escaped repr()."""
        try:
            return text.decode(encoding)
        except UnicodeDecodeError:
            # regular "backslashreplace" for an incompatible encoding raises:
            # "TypeError: don't know how to handle UnicodeDecodeError in
            # error callback"
            return repr(text)[1:-1].decode()

    def safe_bytestring(text):
        """Return ``text`` as an ASCII byte string, backslash-escaping."""
        # py2k only
        if not isinstance(text, string_types):
            return unicode(text).encode("ascii", errors="backslashreplace")
        elif isinstance(text, unicode):
            return text.encode("ascii", errors="backslashreplace")
        else:
            return text

    # the three-argument "raise" form is py2-only syntax, so raise_ is
    # compiled from a string.  On py2 the replace_context and from_
    # arguments are accepted but ignored.
    exec(
        "def raise_(exception, with_traceback=None, replace_context=None, "
        "from_=False):\n"
        "    if with_traceback:\n"
        "        raise type(exception), exception, with_traceback\n"
        "    else:\n"
        "        raise exception\n"
    )

    # the typing module is not available here; typing-only imports guarded
    # by TYPE_CHECKING are skipped
    TYPE_CHECKING = False
305
if py35:

    def _formatannotation(annotation, base_module=None):
        """Render an annotation for a signature string; vendored from 3.7.

        ``typing`` constructs lose their ``typing.`` prefix; classes from
        ``builtins`` (or ``base_module``) are shown by bare qualified name,
        other classes as ``module.qualname``; anything else via repr().
        """

        if getattr(annotation, "__module__", None) == "typing":
            return repr(annotation).replace("typing.", "")
        if isinstance(annotation, type):
            if annotation.__module__ in ("builtins", base_module):
                return annotation.__qualname__
            return annotation.__module__ + "." + annotation.__qualname__
        return repr(annotation)

    def inspect_formatargspec(
        args,
        varargs=None,
        varkw=None,
        defaults=None,
        kwonlyargs=(),
        kwonlydefaults=None,
        annotations=None,
        formatarg=str,
        formatvarargs=lambda name: "*" + name,
        formatvarkw=lambda name: "**" + name,
        formatvalue=lambda value: "=" + repr(value),
        formatreturns=lambda text: " -> " + text,
        formatannotation=_formatannotation,
    ):
        """Copy formatargspec from python 3.7 standard library.

        Python 3 has deprecated formatargspec and requested that Signature
        be used instead, however this requires a full reimplementation
        of formatargspec() in terms of creating Parameter objects and such.
        Instead of introducing all the object-creation overhead and having
        to reinvent from scratch, just copy their compatibility routine.

        Ultimately we would need to rewrite our "decorator" routine completely
        which is not really worth it right now, until all Python 2.x support
        is dropped.

        Unlike the stdlib copy, ``kwonlydefaults`` and ``annotations``
        default to ``None`` rather than a shared mutable ``{}``.  ``None``
        for ``defaults``, ``kwonlyargs``, ``kwonlydefaults`` or
        ``annotations`` is treated as empty.

        :return: the formatted signature, e.g. ``"(a, b=1, *args, c, **kw)"``.

        """
        if annotations is None:
            annotations = {}
        if kwonlydefaults is None:
            kwonlydefaults = {}
        if not defaults:
            defaults = ()

        def formatargandannotation(arg):
            result = formatarg(arg)
            if arg in annotations:
                result += ": " + formatannotation(annotations[arg])
            return result

        specs = []
        # the trailing len(defaults) positional parameters carry defaults;
        # with no defaults, firstdefault == len(args) and none match
        firstdefault = len(args) - len(defaults)
        for i, arg in enumerate(args):
            spec = formatargandannotation(arg)
            if i >= firstdefault:
                spec = spec + formatvalue(defaults[i - firstdefault])
            specs.append(spec)

        if varargs is not None:
            specs.append(formatvarargs(formatargandannotation(varargs)))
        elif kwonlyargs:
            # keyword-only parameters need a bare "*" when there is no *args
            specs.append("*")

        for kwonlyarg in kwonlyargs or ():
            spec = formatargandannotation(kwonlyarg)
            if kwonlyarg in kwonlydefaults:
                spec += formatvalue(kwonlydefaults[kwonlyarg])
            specs.append(spec)

        if varkw is not None:
            specs.append(formatvarkw(formatargandannotation(varkw)))

        result = "(" + ", ".join(specs) + ")"
        if "return" in annotations:
            result += formatreturns(formatannotation(annotations["return"]))
        return result


elif py2k:
    from inspect import formatargspec as _inspect_formatargspec

    def inspect_formatargspec(*spec, **kw):
        """Format an argspec with the py2 stdlib formatargspec().

        Only the first four fields (args, varargs, varkw, defaults) are
        passed through, so a FullArgSpec from inspect_getfullargspec()
        is accepted.
        """
        # convert for a potential FullArgSpec from compat.getfullargspec()
        return _inspect_formatargspec(*spec[0:4], **kw)  # noqa


else:
    from inspect import formatargspec as inspect_formatargspec  # noqa
395
396
# Work around the deprecation of accessing the ABCs directly from the
# collections module (deprecated since Python 3.3, removed in Python 3.10).
399if py33:
400    import collections.abc as collections_abc
401else:
402    import collections as collections_abc  # noqa
403
404
@contextlib.contextmanager
def nested(*managers):
    """Enter several context managers as one; yields their values as a list.

    A local replacement for ``contextlib.nested``, mostly for unit tests,
    which still need to run on py2.6 and so cannot use the multiple-manager
    ``with`` statement.  ``contextlib.nested`` is removed in py3k and emits
    a deprecation warning in 2.7, so it is provided here for everyone.

    Managers are exited in reverse order of entry.  An exception from the
    body or from any ``__exit__`` is handed to each remaining ``__exit__``;
    an ``__exit__`` returning a true value clears it.  Whatever exception
    is still pending at the end is re-raised.

    """
    no_exception = (None, None, None)
    exit_hooks = []
    entered_values = []
    exc = no_exception
    try:
        for manager in managers:
            # look up __exit__ before calling __enter__, as ``with`` does
            exit_hook = manager.__exit__
            enter_hook = manager.__enter__
            entered_values.append(enter_hook())
            exit_hooks.append(exit_hook)
        yield entered_values
    except:
        # deliberately bare: py2 can raise instances that are not
        # BaseException subclasses (old-style classes)
        exc = sys.exc_info()
    finally:
        for exit_hook in reversed(exit_hooks):
            try:
                if exit_hook(*exc):
                    exc = no_exception
            except:
                exc = sys.exc_info()
        if exc != no_exception:
            reraise(*exc)
438
439
def raise_from_cause(exception, exc_info=None):
    """Raise ``exception``, chaining it to the exception being handled.

    Legacy; use ``raise_()`` instead.

    :param exception: the exception instance to raise.
    :param exc_info: a ``(type, value, traceback)`` triple; defaults to
     ``sys.exc_info()``.  Its traceback is passed on to ``reraise()``, and
     its value is used as the cause unless it is ``exception`` itself.

    """
    _, active_value, active_tb = (
        exc_info if exc_info is not None else sys.exc_info()
    )
    if active_value is exception:
        cause = None
    else:
        cause = active_value
    reraise(type(exception), exception, tb=active_tb, cause=cause)
448
449
def reraise(tp, value, tb=None, cause=None):
    """Raise ``value`` via ``raise_()`` with traceback ``tb`` and ``cause``.

    Legacy; use ``raise_()`` instead.  ``tp`` is accepted only for
    signature compatibility and is not used.

    """
    raise_(value, from_=cause, with_traceback=tb, replace_context=None)
454
455
def with_metaclass(meta, *bases):
    """Return a throwaway base class that applies ``meta`` to subclasses.

    Subclassing the returned class creates the real class with metaclass
    ``meta`` and base classes ``bases``; the temporary class itself is
    dropped from the final class's bases.

    Source: http://lucumr.pocoo.org/2013/5/21/porting-to-python-3-redux/

    """

    class metaclass(meta):
        # bypass any custom __call__/__init__ on ``meta`` for the
        # temporary class
        __call__ = type.__call__
        __init__ = type.__init__

        def __new__(mcs, name, this_bases, namespace):
            if this_bases is not None:
                # a real class statement: build it with the real metaclass
                # and bases, discarding the temporary class
                return meta(name, bases, namespace)
            # the sentinel call below: create the temporary class itself
            return type.__new__(mcs, name, (), namespace)

    return metaclass("temporary_class", None, {})
475
476
if py3k:
    from datetime import timezone
else:
    from datetime import datetime
    from datetime import timedelta
    from datetime import tzinfo

    class timezone(tzinfo):
        """Minimal port of python 3 timezone object.

        A fixed-offset tzinfo.  Only the ``offset`` constructor argument is
        supported (no custom name), and the offset must lie within
        +/- 23:59.
        """

        __slots__ = "_offset"

        def __init__(self, offset):
            if not isinstance(offset, timedelta):
                raise TypeError("offset must be a timedelta")
            if not self._minoffset <= offset <= self._maxoffset:
                raise ValueError(
                    "offset must be a timedelta "
                    "strictly between -timedelta(hours=24) and "
                    "timedelta(hours=24)."
                )
            self._offset = offset

        def __eq__(self, other):
            if type(other) != timezone:
                return False
            return self._offset == other._offset

        def __ne__(self, other):
            # Python 2 does not derive != from __eq__; without this, two
            # timezones with the same offset would still compare != True
            return not self.__eq__(other)

        def __hash__(self):
            return hash(self._offset)

        def __repr__(self):
            return "sqlalchemy.util.%s(%r)" % (
                self.__class__.__name__,
                self._offset,
            )

        def __str__(self):
            return self.tzname(None)

        def utcoffset(self, dt):
            return self._offset

        def tzname(self, dt):
            return self._name_from_offset(self._offset)

        def dst(self, dt):
            return None

        def fromutc(self, dt):
            if isinstance(dt, datetime):
                if dt.tzinfo is not self:
                    raise ValueError("fromutc: dt.tzinfo " "is not self")
                return dt + self._offset
            raise TypeError(
                "fromutc() argument must be a datetime instance" " or None"
            )

        @staticmethod
        def _timedelta_to_microseconds(timedelta):
            """backport of timedelta._to_microseconds()"""
            return (
                timedelta.days * (24 * 3600) + timedelta.seconds
            ) * 1000000 + timedelta.microseconds

        @staticmethod
        def _divmod_timedeltas(a, b):
            """backport of timedelta.__divmod__"""

            q, r = divmod(
                timezone._timedelta_to_microseconds(a),
                timezone._timedelta_to_microseconds(b),
            )
            return q, timedelta(0, 0, r)

        @staticmethod
        def _name_from_offset(delta):
            """Render ``delta`` as "UTC", or "UTC+HH:MM[:SS[.ffffff]]"."""
            if not delta:
                return "UTC"
            if delta < timedelta(0):
                sign = "-"
                delta = -delta
            else:
                sign = "+"
            hours, rest = timezone._divmod_timedeltas(
                delta, timedelta(hours=1)
            )
            minutes, rest = timezone._divmod_timedeltas(
                rest, timedelta(minutes=1)
            )
            result = "UTC%s%02d:%02d" % (sign, hours, minutes)
            if rest.seconds:
                result += ":%02d" % (rest.seconds,)
            if rest.microseconds:
                result += ".%06d" % (rest.microseconds,)
            return result

        _maxoffset = timedelta(hours=23, minutes=59)
        _minoffset = -_maxoffset

    timezone.utc = timezone(timedelta(0))
578