1# testutils.py - utility module for psycopg2 testing.
2
3#
4# Copyright (C) 2010-2019 Daniele Varrazzo  <daniele.varrazzo@gmail.com>
5# Copyright (C) 2020-2021 The Psycopg Team
6#
7# psycopg2 is free software: you can redistribute it and/or modify it
8# under the terms of the GNU Lesser General Public License as published
9# by the Free Software Foundation, either version 3 of the License, or
10# (at your option) any later version.
11#
12# In addition, as a special exception, the copyright holders give
13# permission to link this program with the OpenSSL library (or with
14# modified versions of OpenSSL that use the same license as OpenSSL),
15# and distribute linked combinations including the two.
16#
17# You must obey the GNU Lesser General Public License in all respects for
18# all of the code used other than OpenSSL.
19#
20# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
21# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
22# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
23# License for more details.
24
25
26import re
27import os
28import sys
29import types
30import ctypes
31import select
32import operator
33import platform
34import unittest
35from functools import wraps
36from ctypes.util import find_library
37from io import StringIO         # noqa
38from io import TextIOBase       # noqa
39from importlib import reload    # noqa
40
41import psycopg2
42import psycopg2.errors
43import psycopg2.extensions
44
45from .testconfig import green, dsn, repl_dsn
46
47
# Re-create the legacy TestCase assertion aliases (deprecated by the Python
# unittest maintainers and removed in recent versions) so that older tests
# keep working without emitting deprecation warnings.
# https://bugs.python.org/issue9424
if getattr(unittest.TestCase, 'assert_', None) is not unittest.TestCase.assertTrue:
    unittest.TestCase.assert_ = unittest.TestCase.failUnless = (
        unittest.TestCase.assertTrue)
    unittest.TestCase.assertEquals = unittest.TestCase.failUnlessEqual = (
        unittest.TestCase.assertEqual)
58
59
def assertDsnEqual(self, dsn1, dsn2, msg=None):
    """Assert that two conninfo strings hold the same set of tokens.

    The strings are split on whitespace and compared as sets, so the order
    (and repetition) of the tokens is irrelevant.
    """
    tokens1 = set(dsn1.split())
    tokens2 = set(dsn2.split())
    self.assertEqual(tokens1, tokens2, msg)


# Make the check available as a method on every TestCase.
unittest.TestCase.assertDsnEqual = assertDsnEqual
66
67
class ConnectingTestCase(unittest.TestCase):
    """A test case providing connections for tests.

    A connection for the test is always available as `self.conn`. Others can be
    created with `self.connect()`. All are closed on tearDown.

    Subclasses needing to customize setUp and tearDown should remember to call
    the base class implementations.
    """
    def setUp(self):
        # Connections opened through connect(); tearDown() closes them.
        self._conns = []

    def tearDown(self):
        """Close every connection opened by the test that is still open."""
        for conn in self._conns:
            if not conn.closed:
                conn.close()

    def assertQuotedEqual(self, first, second, msg=None):
        """Compare two quoted strings disregarding any E'' quotes.

        *first* and *second* are normalized independently: in a ``str`` or
        ``bytes`` value every ``E'`` at a word boundary becomes a plain
        ``'``; values of other types are compared unchanged.
        """
        def f(s):
            if isinstance(s, str):
                return re.sub(r"\bE'", "'", s)
            elif isinstance(s, bytes):
                # Test the value being normalized itself, not *first*:
                # otherwise a bytes *second* would not be normalized, or a
                # non-bytes *second* would be fed to the bytes regex.
                return re.sub(br"\bE'", b"'", s)
            else:
                return s

        return self.assertEqual(f(first), f(second), msg)

    def connect(self, **kwargs):
        """Open a new connection and register it to be closed on tearDown.

        The connection string is taken from the *dsn* keyword argument if
        given, otherwise from the test configuration's ``dsn``; all the other
        keyword arguments are passed to ``psycopg2.connect()``.

        :raises AttributeError: if ``setUp()`` was not called.
        """
        try:
            self._conns
        except AttributeError as e:
            raise AttributeError(
                f"{e} (did you forget to call ConnectingTestCase.setUp()?)")

        if 'dsn' in kwargs:
            conninfo = kwargs.pop('dsn')
        else:
            conninfo = dsn
        conn = psycopg2.connect(conninfo, **kwargs)
        self._conns.append(conn)
        return conn

    def repl_connect(self, **kwargs):
        """Return a connection set up for replication

        The connection is on "PSYCOPG2_TEST_REPL_DSN" unless overridden by
        a *dsn* kwarg.

        The test is skipped if replication is not configured or the
        connection fails without a server error code; errors carrying a
        ``pgcode`` are re-raised. ``skipTest()`` raises, but its result is
        returned as a guard for old Python versions.
        """
        if repl_dsn is None:
            return self.skipTest("replication tests disabled by default")

        if 'dsn' not in kwargs:
            kwargs['dsn'] = repl_dsn
        try:
            conn = self.connect(**kwargs)
            if conn.async_ == 1:
                self.wait(conn)
        except psycopg2.OperationalError as e:
            # If pgcode is not set it is a genuine connection error
            # Otherwise we tried to run some bad operation in the connection
            # (e.g. bug #482) and we'd rather know that.
            if e.pgcode is None:
                return self.skipTest(f"replication db not configured: {e}")
            else:
                raise

        return conn

    def _get_conn(self):
        # Lazily open the default connection on first access.
        if not hasattr(self, '_the_conn'):
            self._the_conn = self.connect()

        return self._the_conn

    def _set_conn(self, conn):
        self._the_conn = conn

    conn = property(_get_conn, _set_conn)

    # for use with async connections only
    def wait(self, cur_or_conn):
        """Poll a connection (or a cursor's connection) until POLL_OK.

        Blocks in ``select()`` (1 second timeout per round) while the
        connection asks to read or write.

        :raises Exception: if ``poll()`` returns an unexpected state.
        """
        pollable = cur_or_conn
        if not hasattr(pollable, 'poll'):
            pollable = cur_or_conn.connection
        while True:
            state = pollable.poll()
            if state == psycopg2.extensions.POLL_OK:
                break
            elif state == psycopg2.extensions.POLL_READ:
                select.select([pollable], [], [], 1)
            elif state == psycopg2.extensions.POLL_WRITE:
                select.select([], [pollable], [], 1)
            else:
                # Format the state into the message: passing it as a second
                # Exception argument would leave "%r" unformatted.
                raise Exception(f"Unexpected result from poll: {state!r}")

    # Class-level cache shared by all the test cases (see libpq).
    _libpq = None

    @property
    def libpq(self):
        """Return a ctypes wrapper for the libpq library

        The library is loaded through ``ctypes.pydll`` once and cached on
        the class. The test is skipped if the library cannot be loaded.
        """
        if ConnectingTestCase._libpq is not None:
            return ConnectingTestCase._libpq

        libname = find_library('pq')
        if libname is None and platform.system() == 'Windows':
            self.skipTest("can't import libpq on windows")   # raises

        # NOTE(review): elsewhere a None libname is passed to LoadLibrary
        # as-is, presumably to pick up a libpq already loaded in-process.
        try:
            lib = ctypes.pydll.LoadLibrary(libname)
        except OSError as e:
            self.skipTest("couldn't open libpq for testing: %s" % e)  # raises
        ConnectingTestCase._libpq = lib
        return lib
186
187
188def decorate_all_tests(obj, *decorators):
189    """
190    Apply all the *decorators* to all the tests defined in the TestCase *obj*.
191
192    The decorator can also be applied to a decorator: if *obj* is a function,
193    return a new decorator which can be applied either to a method or to a
194    class, in which case it will decorate all the tests.
195    """
196    if isinstance(obj, types.FunctionType):
197        def decorator(func_or_cls):
198            if isinstance(func_or_cls, types.FunctionType):
199                return obj(func_or_cls)
200            else:
201                decorate_all_tests(func_or_cls, obj)
202                return func_or_cls
203
204        return decorator
205
206    for n in dir(obj):
207        if n.startswith('test'):
208            for d in decorators:
209                setattr(obj, n, d(getattr(obj, n)))
210
211
@decorate_all_tests
def skip_if_no_uuid(f):
    """Decorator to skip a test if uuid is not supported by PG.

    The ``pg_type`` catalog is queried on ``self.conn``, which is then
    rolled back whether or not the query succeeded.
    """
    @wraps(f)
    def skip_if_no_uuid_(self):
        try:
            cur = self.conn.cursor()
            cur.execute("select typname from pg_type where typname = 'uuid'")
            row = cur.fetchone()
        finally:
            # presumably so the test starts without a pending transaction
            self.conn.rollback()

        if not row:
            return self.skipTest("uuid type not available on the server")
        return f(self)

    return skip_if_no_uuid_
230
231
@decorate_all_tests
def skip_if_tpc_disabled(f):
    """Skip a test if the server has tpc support disabled.

    The test is also skipped on CockroachDB and on servers where
    ``SHOW max_prepared_transactions`` fails with a ProgrammingError.
    """
    @wraps(f)
    def skip_if_tpc_disabled_(self):
        cnn = self.connect()
        skip_if_crdb("2-phase commit", cnn)

        cur = cnn.cursor()
        try:
            cur.execute("SHOW max_prepared_transactions;")
        except psycopg2.ProgrammingError:
            return self.skipTest(
                "server too old: two phase transactions not supported.")
        max_prepared = int(cur.fetchone()[0])
        cnn.close()

        if max_prepared:
            return f(self)
        return self.skipTest(
            "server not configured for two phase transactions. "
            "set max_prepared_transactions to > 0 to run the test")

    return skip_if_tpc_disabled_
257
258
def skip_before_postgres(*ver):
    """Skip a test on PostgreSQL before a certain version.

    *ver* are the version components (missing ones are taken as 0); an
    optional trailing string is used as the skip reason. The server version
    is read from ``self.conn`` when the test runs.
    """
    reason = None
    if isinstance(ver[-1], str):
        reason = ver[-1]
        ver = ver[:-1]

    ver += (0,) * (3 - len(ver))

    @decorate_all_tests
    def skip_before_postgres_(f):
        @wraps(f)
        def skip_before_postgres__(self):
            server_version = self.conn.info.server_version
            if server_version >= int("%d%02d%02d" % ver):
                return f(self)
            message = reason or (
                "skipped because PostgreSQL %s" % server_version)
            return self.skipTest(message)

        return skip_before_postgres__
    return skip_before_postgres_
280
281
def skip_after_postgres(*ver):
    """Skip a test on PostgreSQL after (including) a certain version.

    *ver* are the version components (missing ones are taken as 0). As in
    ``skip_before_postgres()``, an optional trailing string is used as the
    skip reason instead of the default message. The server version is read
    from ``self.conn`` when the test runs.
    """
    reason = None
    # Guard on ``ver`` so that a call with no arguments keeps working.
    if ver and isinstance(ver[-1], str):
        ver, reason = ver[:-1], ver[-1]

    ver = ver + (0,) * (3 - len(ver))

    @decorate_all_tests
    def skip_after_postgres_(f):
        @wraps(f)
        def skip_after_postgres__(self):
            if self.conn.info.server_version >= int("%d%02d%02d" % ver):
                return self.skipTest(
                    reason or "skipped because PostgreSQL %s"
                    % self.conn.info.server_version)
            else:
                return f(self)

        return skip_after_postgres__
    return skip_after_postgres_
298
299
def libpq_version():
    """Return the libpq version number the tests should assume.

    This is the version psycopg2 was built against; when that is 90100 or
    later it is capped at the value of ``psycopg2.extensions.libpq_version()``.
    (Presumably the runtime query is only available from libpq 9.1.)
    """
    built = psycopg2.__libpq_version__
    if built < 90100:
        return built
    return min(built, psycopg2.extensions.libpq_version())
305
306
def skip_before_libpq(*ver):
    """Skip a test if libpq we're linked to is older than a certain version.

    *ver* components are padded with zeros to three elements. The libpq
    version is checked when the decorator is applied.
    """
    padded = ver + (0,) * (3 - len(ver))

    def skip_before_libpq_(cls):
        current = libpq_version()
        too_old = current < int("%d%02d%02d" % padded)
        skipper = unittest.skipIf(too_old, f"skipped because libpq {current}")
        return skipper(cls)
    return skip_before_libpq_
319
320
def skip_after_libpq(*ver):
    """Skip a test if libpq we're linked to is newer than a certain version.

    The test is skipped when libpq is at or above *ver* (padded with zeros to
    three elements). The libpq version is checked when the decorator is
    applied.
    """
    padded = ver + (0,) * (3 - len(ver))

    def skip_after_libpq_(cls):
        current = libpq_version()
        too_new = current >= int("%d%02d%02d" % padded)
        skipper = unittest.skipIf(too_new, f"skipped because libpq {current}")
        return skipper(cls)
    return skip_after_libpq_
333
334
def skip_before_python(*ver):
    """Skip a test (or test class) on Python before a certain version.

    Only the first ``len(ver)`` components of ``sys.version_info`` are
    compared with *ver*.
    """
    def skip_before_python_(cls):
        current = sys.version_info[:len(ver)]
        shown = '.'.join(str(part) for part in current)
        skipper = unittest.skipIf(current < ver, f"skipped because Python {shown}")
        return skipper(cls)
    return skip_before_python_
344
345
def skip_from_python(*ver):
    """Skip a test (or test class) on Python after (including) a version.

    Only the first ``len(ver)`` components of ``sys.version_info`` are
    compared with *ver*.
    """
    def skip_from_python_(cls):
        current = sys.version_info[:len(ver)]
        shown = '.'.join(str(part) for part in current)
        skipper = unittest.skipIf(current >= ver, f"skipped because Python {shown}")
        return skipper(cls)
    return skip_from_python_
355
356
@decorate_all_tests
def skip_if_no_superuser(f):
    """Skip a test if the database user running the test is not a superuser

    The test is run; if it fails with ``InsufficientPrivilege`` it is
    reported as skipped instead.
    """
    @wraps(f)
    def skip_if_no_superuser_(self):
        try:
            result = f(self)
        except psycopg2.errors.InsufficientPrivilege:
            self.skipTest("skipped because not superuser")
        else:
            return result

    return skip_if_no_superuser_
368
369
def skip_if_green(reason):
    """Return a decorator skipping a test (or class) if ``green`` is true.

    ``green`` comes from the test configuration and is read when the
    returned decorator is applied; *reason* is the skip message.
    """
    def skip_if_green_(cls):
        return unittest.skipIf(green, reason)(cls)
    return skip_if_green_


# Ready-made decorator for the tests that use COPY.
skip_copy_if_green = skip_if_green("copy in async mode currently not supported")
378
379
def skip_if_no_getrefcount(cls):
    """Skip a test (or test class) if ``sys.getrefcount`` is not available."""
    has_getrefcount = hasattr(sys, 'getrefcount')
    return unittest.skipUnless(has_getrefcount, 'no sys.getrefcount()')(cls)
386
387
def skip_if_windows(cls):
    """Skip a test (or test class) when running on Windows."""
    on_windows = platform.system() == 'Windows'
    return unittest.skipIf(on_windows, "Not supported on Windows")(cls)
395
396
def crdb_version(conn, __crdb_version=[]):
    """
    Return the CockroachDB version if that's the db being tested, else None.

    Return the number as an integer similar to PQserverVersion: return
    v20.1.3 as 200103.

    Assume all the connections are on the same db: the first result
    (including None) is stored in the ``__crdb_version`` list, which by
    default is the shared mutable default argument, and returned by all
    the following calls.

    :raises ValueError: if the ``crdb_version`` parameter status is present
        but contains no ``vX.Y.Z`` version.
    """
    # NOTE: the mutable default is deliberate; it acts as the result cache.
    if not __crdb_version:
        sver = conn.info.parameter_status("crdb_version")
        parsed = None
        if sver is not None:
            match = re.search(r"\bv(\d+)\.(\d+)\.(\d+)", sver)
            if match is None:
                raise ValueError(
                    f"can't parse CockroachDB version from {sver}")
            major, minor, patch = (int(group) for group in match.groups())
            parsed = major * 10000 + minor * 100 + patch
        __crdb_version.append(parsed)

    return __crdb_version[0]
424
425
def skip_if_crdb(reason, conn=None, version=None):
    """Skip a test or test class if we are testing against CockroachDB.

    Can be used as a decorator for tests function or classes:

        @skip_if_crdb("my reason")
        class SomeUnitTest(UnitTest):
            # ...

    Or as a normal function if the *conn* argument is passed: in that case
    ``unittest.SkipTest`` is raised immediately when *conn* is on a matching
    CockroachDB. When decorating, each test opens a connection with
    ``self.connect()`` to perform the check.

    If *version* is specified it should be a string such as ">= 20.1", "< 20",
    "== 20.1.3": the test will be skipped only if the version matches.

    :raises TypeError: if *reason* is not a string.
    """
    if not isinstance(reason, str):
        raise TypeError(f"reason should be a string, got {reason!r} instead")

    if conn is not None:
        ver = crdb_version(conn)
        if ver is not None and _crdb_match_version(ver, version):
            why = reason
            if reason in crdb_reasons:
                # Point to the upstream issue tracking the missing feature.
                why = (
                    "%s (https://github.com/cockroachdb/cockroach/issues/%s)"
                    % (reason, crdb_reasons[reason]))
            raise unittest.SkipTest(
                f"not supported on CockroachDB {ver}: {why}")

    @decorate_all_tests
    def skip_if_crdb_(f):
        @wraps(f)
        def skip_if_crdb__(self, *args, **kwargs):
            skip_if_crdb(reason, conn=self.connect(), version=version)
            return f(self, *args, **kwargs)

        return skip_if_crdb__

    return skip_if_crdb_
464
465
# mapping from reason description to ticket number
#
# Keys are the *reason* strings passed to skip_if_crdb(); values are the
# numbers of the github.com/cockroachdb/cockroach issues tracking the
# missing feature. When a reason is found here, skip_if_crdb() appends the
# issue URL to the skip message.
crdb_reasons = {
    "2-phase commit": 22329,
    "backend pid": 35897,
    "cancel": 41335,
    "cast adds tz": 51692,
    "cidr": 18846,
    "composite": 27792,
    "copy": 41608,
    "deferrable": 48307,
    "encoding": 35882,
    "hstore": 41284,
    "infinity date": 41564,
    "interval style": 35807,
    "large objects": 243,
    "named cursor": 41412,
    "nested array": 32552,
    "notify": 41522,
    "password_encryption": 42519,
    "range": 41282,
    "stored procedure": 1751,
}
488
489
490def _crdb_match_version(version, pattern):
491    if pattern is None:
492        return True
493
494    m = re.match(r'^(>|>=|<|<=|==|!=)\s*(\d+)(?:\.(\d+))?(?:\.(\d+))?$', pattern)
495    if m is None:
496        raise ValueError(
497            "bad crdb version pattern %r: should be 'OP MAJOR[.MINOR[.BUGFIX]]'"
498            % pattern)
499
500    ops = {'>': 'gt', '>=': 'ge', '<': 'lt', '<=': 'le', '==': 'eq', '!=': 'ne'}
501    op = getattr(operator, ops[m.group(1)])
502    ref = int(m.group(2)) * 10000 + int(m.group(3) or 0) * 100 + int(m.group(4) or 0)
503    return op(version, ref)
504
505
class raises_typeerror:
    """Context manager checking that its body raises exactly ``TypeError``.

    The ``TypeError`` is swallowed. Any other outcome (no exception, or an
    exception of a different type, subclasses of ``TypeError`` included)
    raises ``AssertionError``.
    """
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc, tb):
        # Check explicitly rather than with ``assert``: assertions are
        # stripped under ``python -O``, which would make this manager swallow
        # every exception and pass even when nothing was raised.
        if exc_type is TypeError:
            return True
        got = "no exception" if exc_type is None else exc_type.__name__
        raise AssertionError(f"expected TypeError, got {got}") from exc
513
514
def slow(f):
    """Decorator to mark slow tests we may want to skip

    The test is skipped when the ``PSYCOPG2_TEST_FAST`` environment variable
    is set to anything other than ``'0'``; it is read when the test runs.

    Note: in order to find slow tests you can run:

    make check 2>&1 | ts -i "%.s" | sort -n
    """
    @wraps(f)
    def slow_(self):
        fast_mode = os.environ.get('PSYCOPG2_TEST_FAST', '0') != '0'
        if not fast_mode:
            return f(self)
        return self.skipTest("slow test")
    return slow_
528
529
def restore_types(f):
    """Decorator to restore the adaptation system after running a test

    Copies of ``psycopg2.extensions.string_types`` and
    ``psycopg2.extensions.adapters`` are taken before the test; afterwards,
    even if the test fails, each registry is cleared and refilled from its
    copy (mutated in place, not rebound).
    """
    @wraps(f)
    def restore_types_(self):
        saved_typecasters = psycopg2.extensions.string_types.copy()
        saved_adapters = psycopg2.extensions.adapters.copy()
        try:
            return f(self)
        finally:
            exts = psycopg2.extensions
            for registry, snapshot in (
                    (exts.string_types, saved_typecasters),
                    (exts.adapters, saved_adapters)):
                registry.clear()
                registry.update(snapshot)

    return restore_types_
545