1"""
2Entry point into the adaptation system.
3"""
4
5# Copyright (C) 2020-2021 The Psycopg Team
6
7from abc import ABC, abstractmethod
8from typing import Any, Optional, Type, Tuple, Union, TYPE_CHECKING
9
10from . import pq, abc
11from . import _adapters_map
12from ._enums import PyFormat as PyFormat
13from ._cmodule import _psycopg
14
15if TYPE_CHECKING:
16    from .connection import BaseConnection
17
# Module-level aliases for names defined elsewhere in the package,
# presumably kept so that they can also be imported from this module.
# Note: `abc` here is the package's own `abc` module (`from . import abc`),
# not the standard library one.
AdaptersMap = _adapters_map.AdaptersMap
Buffer = abc.Buffer
20
# Integer value of the backslash byte: membership tests with an int work on
# bytes and memoryview alike, whereas `b"\\" in memoryview` does not.
ORD_BS = b"\\"[0]
22
23
class Dumper(abc.Dumper, ABC):
    """
    Convert Python object of the type *cls* to PostgreSQL representation.
    """

    oid: int = 0
    """The oid to pass to the server, if known."""

    format: pq.Format = pq.Format.TEXT
    """The format of the data dumped."""

    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
        self.cls = cls
        # Remember the context's connection (if any): `quote()` uses it to
        # pick connection-aware escaping.
        self.connection: Optional["BaseConnection[Any]"] = (
            context.connection if context else None
        )

    def __repr__(self) -> str:
        return (
            f"<{type(self).__module__}.{type(self).__qualname__}"
            f" (oid={self.oid}) at 0x{id(self):x}>"
        )

    @abstractmethod
    def dump(self, obj: Any) -> Buffer:
        """Convert *obj* to its PostgreSQL representation."""
        ...

    def quote(self, obj: Any) -> Buffer:
        """
        By default return the `dump()` value quoted and sanitised, so
        that the result can be used to build a SQL string. This works well
        for most types and you won't likely have to implement this method in a
        subclass.

        If the dumper has a connection, quoting and escaping are delegated to
        `pq.Escaping.escape_literal()` on that connection. Otherwise the value
        is escaped without a connection and wrapped in quotes here, using the
        ``E'...'`` syntax if it contains backslashes.
        """
        value = self.dump(obj)

        if self.connection:
            esc = pq.Escaping(self.connection.pgconn)
            # escaping and quoting
            return esc.escape_literal(value)

        # This path is taken when quote is asked without a connection,
        # usually it means by psycopg.sql.quote() or by
        # 'Composable.as_string(None)'. More often than not this is done by
        # someone generating a SQL file to consume elsewhere.

        # escape_string() doesn't add the surrounding quotes: it only escapes
        # quotes and, depending on libpq state, backslashes (see below).
        esc = pq.Escaping()
        out = esc.escape_string(value)

        # b"\\" in memoryview doesn't work so search for the ascii value
        if ORD_BS not in out:
            # If the string has no backslash, the result is correct and we
            # don't need to bother with standard_conforming_strings.
            return b"'" + out + b"'"

        # The libpq has a crazy behaviour: PQescapeString uses the last
        # standard_conforming_strings setting seen on a connection. This
        # means that backslashes might be escaped or might not.
        #
        # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
        # if scs is off, '\\' raises a warning and '\' is an error.
        #
        # Check what the libpq does, and if it doesn't escape the backslash
        # let's do it on our own. Never mind the race condition.
        #
        # NOTE(review): the space before E looks deliberate (it appears to
        # mirror libpq's PQescapeLiteral output) - confirm before removing.
        rv: bytes = b" E'" + out + b"'"
        if esc.escape_string(b"\\") == b"\\":
            rv = rv.replace(b"\\", b"\\\\")
        return rv

    def get_key(
        self, obj: Any, format: PyFormat
    ) -> Union[type, Tuple[type, ...]]:
        """
        Implementation of the `~psycopg.abc.Dumper.get_key()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.

        This implementation returns the *cls* passed in the constructor.
        Subclasses needing to specialise the PostgreSQL type according to the
        *value* of the object dumped (not only according to its type)
        should override this method.

        """
        return self.cls

    def upgrade(self, obj: Any, format: PyFormat) -> "Dumper":
        """
        Implementation of the `~psycopg.abc.Dumper.upgrade()` member of the
        `~psycopg.abc.Dumper` protocol. Look at its definition for details.

        This implementation just returns *self*. If a subclass implements
        `get_key()` it should probably override `!upgrade()` too.
        """
        return self
118
119
class Loader(ABC):
    """
    Base class converting PostgreSQL values of type OID *oid* to Python
    objects.

    Concrete subclasses must implement `load()`.
    """

    #: The format of the data loaded.
    format: pq.Format = pq.Format.TEXT

    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
        self.oid = oid
        # Take the connection from the context, when a context is given.
        self.connection: Optional["BaseConnection[Any]"] = None
        if context:
            self.connection = context.connection

    @abstractmethod
    def load(self, data: Buffer) -> Any:
        """Convert a PostgreSQL value to a Python object."""
        ...
138
139
Transformer: Type["abc.Transformer"]

# Use the accelerated implementation from `_cmodule` when it is available,
# otherwise fall back on the pure Python one.
if not _psycopg:
    from . import _transform

    Transformer = _transform.Transformer
else:
    Transformer = _psycopg.Transformer
149
150
class RecursiveDumper(Dumper):
    """
    Dumper owning a `Transformer`, useful to dump types whose values
    contain other values to be dumped in turn.
    """

    def __init__(self, cls: type, context: Optional[abc.AdaptContext] = None):
        super().__init__(cls, context)
        # The helper transformer is built from the same context received
        # by this dumper.
        tx = Transformer(context)
        self._tx = tx
157
158
class RecursiveLoader(Loader):
    """
    Loader owning a `Transformer`, useful to load types whose values
    contain other values to be loaded in turn.
    """

    def __init__(self, oid: int, context: Optional[abc.AdaptContext] = None):
        super().__init__(oid, context)
        # The helper transformer is built from the same context received
        # by this loader.
        tx = Transformer(context)
        self._tx = tx
165