1"""
2SQL composition utility module
3"""
4
5# Copyright (C) 2020-2021 The Psycopg Team
6
7import codecs
8import string
9from abc import ABC, abstractmethod
10from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union
11
12from .pq import Escaping
13from .abc import AdaptContext
14from .adapt import Transformer, PyFormat
15from ._encodings import pgconn_encoding
16
17
def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
    """
    Return *obj* rendered as a quoted SQL literal string.

    The object is wrapped in a `Literal` and converted with its
    `!as_string()` method, passing *context* along.

    Reach for this function only when you really need an SQL quoted literal
    as a Python string, e.g. to generate batch SQL, and no connection will be
    available at the time the string is needed.

    The function is relatively inefficient, because it doesn't cache the
    adaptation rules. If a *context* is passed, its adaptation rules are
    used; otherwise only the global rules apply.

    """
    literal = Literal(obj)
    return literal.as_string(context)
32
33
class Composable(ABC):
    """
    Abstract base class for objects that can be used to compose an SQL string.

    `!Composable` objects can be passed directly to
    `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
    `~psycopg.Cursor.copy()` in place of the query string.

    Two operators are available to combine `!Composable` objects:

    - ``+`` joins two objects, returning a `Composed` instance containing
      both of them;
    - ``*`` takes an integer argument and returns a `!Composed` instance
      containing the left argument repeated as many times as requested.
    """

    def __init__(self, obj: Any):
        # The wrapped value; every subclass decides how to render it.
        self._obj = obj

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"{cls_name}({self._obj!r})"

    @abstractmethod
    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        """
        Return the value of the object as bytes.

        :param context: the context to evaluate the object into.
        :type context: `connection` or `cursor`

        The method is automatically invoked by `~psycopg.Cursor.execute()`,
        `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
        `!Composable` is passed instead of the query string.

        """
        raise NotImplementedError

    def as_string(self, context: Optional[AdaptContext]) -> str:
        """
        Return the value of the object as string.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        The result of `as_bytes()` is decoded using the encoding of the
        connection reachable from *context*, or UTF-8 when there is no
        context or no connection.

        """
        encoding = "utf-8"
        if context:
            conn = context.connection
            if conn:
                encoding = pgconn_encoding(conn.pgconn)

        data = self.as_bytes(context)
        if not isinstance(data, bytes):
            # buffer object: it has no decode() method, go through the codec
            return codecs.lookup(encoding).decode(data)[0]
        return data.decode(encoding)

    def __add__(self, other: "Composable") -> "Composed":
        if not isinstance(other, Composable):
            return NotImplemented

        head = Composed([self])
        if isinstance(other, Composed):
            return head + other
        return head + Composed([other])

    def __mul__(self, n: int) -> "Composed":
        repeated = [self] * n
        return Composed(repeated)

    def __eq__(self, other: Any) -> bool:
        # Objects of different classes never compare equal, even if they
        # wrap the same value.
        if type(self) is not type(other):
            return False
        return self._obj == other._obj

    def __ne__(self, other: Any) -> bool:
        equal = self.__eq__(other)
        return not equal
103
104
class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    Instances are usually the result of `!Composable` operators and methods,
    but a `!Composed` can also be created directly from a sequence of
    objects: the items which are not `!Composable` are wrapped in a
    `Literal`.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("INSERT INTO "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        INSERT INTO "table"

    `!Composed` objects are iterable (so they can be used in `SQL.join` for
    instance).
    """

    _obj: List[Composable]

    def __init__(self, seq: Sequence[Any]):
        items: List[Composable] = []
        for obj in seq:
            if not isinstance(obj, Composable):
                # plain Python values are rendered as SQL literals
                obj = Literal(obj)
            items.append(obj)
        super().__init__(items)

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        chunks = [obj.as_bytes(context) for obj in self._obj]
        return b"".join(chunks)

    def __iter__(self) -> Iterator[Composable]:
        items = self._obj
        return iter(items)

    def __add__(self, other: Composable) -> "Composed":
        if not isinstance(other, Composable):
            return NotImplemented

        # Adding a Composed merges its items, anything else is appended.
        tail = other._obj if isinstance(other, Composed) else [other]
        return Composed(self._obj + tail)

    def join(self, joiner: Union["SQL", str]) -> "Composed":
        """
        Return a new `!Composed` interposing the *joiner* with the `!Composed` items.

        The *joiner* must be a `SQL` or a string which will be interpreted as
        an `SQL`. Any other type raises `!TypeError`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar')  # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        """
        if isinstance(joiner, SQL):
            separator = joiner
        elif isinstance(joiner, str):
            separator = SQL(joiner)
        else:
            raise TypeError(
                "Composed.join() argument must be strings or SQL,"
                f" got {joiner!r} instead"
            )

        return separator.join(self._obj)
170
171
class SQL(Composable):
    """
    A `Composable` representing a snippet of SQL statement.

    `!SQL` exposes `join()` and `format()` methods useful to create a template
    where to merge variable parts of a query (for instance field or table
    names).

    The *string* is used verbatim, without any form of escaping, so it is not
    suitable to represent variable identifiers or values: only use it for
    constant strings representing templates or snippets of SQL statements,
    and use other objects such as `Identifier` or `Literal` to represent
    variable parts.

    Example::

        >>> query = sql.SQL("SELECT {0} FROM {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        SELECT "foo", "bar" FROM "table"
    """

    _obj: str
    _formatter = string.Formatter()

    def __init__(self, obj: str):
        # The base class is initialised before validating the argument.
        super().__init__(obj)
        if isinstance(obj, str):
            return
        raise TypeError(f"SQL values must be strings, got {obj!r} instead")

    def as_string(self, context: Optional[AdaptContext]) -> str:
        # The snippet is already a string: *context* is not needed.
        return self._obj

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Encode with the connection encoding, or UTF-8 without a connection.
        conn = context.connection if context else None
        encoding = pgconn_encoding(conn.pgconn) if conn else "utf-8"
        return self._obj.encode(encoding)

    def format(self, *args: Any, **kwargs: Any) -> Composed:
        """
        Merge `Composable` objects into a template.

        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
            auto-numbered (``{}``) placeholders
        :param kwargs: parameters to replace to named (``{name}``) placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`

        The method is similar to the Python `str.format()` method: the string
        template supports auto-numbered (``{}``), numbered (``{0}``,
        ``{1}``...), and named placeholders (``{name}``), with positional
        arguments replacing the numbered placeholders and keywords replacing
        the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
        are not supported and raise `!ValueError`, as does mixing
        auto-numbered and numbered placeholders in the same template.

        If a `!Composable` object is passed to the template it will be merged
        according to its `as_string()` method. If any other Python object is
        passed, it will be wrapped in a `Literal` object and so escaped
        according to SQL rules.

        Example::

            >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE "id" = %s

            >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE name = 'O''Rourke'

        """
        parts: List[Composable] = []
        # Index of the next auto-numbered argument; None once a numbered
        # placeholder has been found.
        next_auto: Optional[int] = 0
        parsed = self._formatter.parse(self._obj)
        for literal, field, fmt_spec, conversion in parsed:
            if fmt_spec:
                raise ValueError("no format specification supported by SQL")
            if conversion:
                raise ValueError("no format conversion supported by SQL")
            if literal:
                parts.append(SQL(literal))

            # No placeholder follows the literal text.
            if field is None:
                continue

            if field.isdigit():
                # numbered placeholder, e.g. {0}
                if next_auto:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual"
                    )
                parts.append(args[int(field)])
                next_auto = None

            elif field:
                # named placeholder, e.g. {name}
                parts.append(kwargs[field])

            else:
                # auto-numbered placeholder: {}
                if next_auto is None:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic"
                    )
                parts.append(args[next_auto])
                next_auto += 1

        return Composed(parts)

    def join(self, seq: Iterable[Composable]) -> Composed:
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :type seq: iterable of `!Composable`

        Use the `!SQL` object's *string* to separate the elements in *seq*.
        Note that `Composed` objects are iterable too, so they can be used as
        argument for this method. An empty *seq* results in an empty
        `!Composed`.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """
        joined: List[Any] = []
        for idx, item in enumerate(seq):
            # the separator goes before every item except the first
            if idx:
                joined.append(self)
            joined.append(item)

        return Composed(joined)
313
314
class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers usually represent names of database objects, such as tables or
    fields. PostgreSQL identifiers follow `different rules`__ than SQL string
    literals for escaping (e.g. they use double quotes instead of single).

    .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
        SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    Multiple strings can be passed to the object to represent a qualified name,
    i.e. a dot-separated sequence of identifiers.

    Example::

        >>> query = sql.SQL("SELECT {} FROM {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        SELECT "table"."field" FROM "schema"."table"

    At least one string is required, and every part must be a `!str`:
    otherwise `!TypeError` is raised.

    """

    _obj: Sequence[str]

    def __init__(self, *strings: str):
        # init super() now to make the __repr__ not explode in case of error
        super().__init__(strings)

        if not strings:
            raise TypeError("Identifier cannot be empty")

        for part in strings:
            if isinstance(part, str):
                continue
            raise TypeError(
                f"SQL identifier parts must be strings, got {part!r} instead"
            )

    def __repr__(self) -> str:
        # Show the parts as separate arguments, the way they were passed.
        parts = ", ".join(repr(part) for part in self._obj)
        return f"{type(self).__name__}({parts})"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        conn = None
        if context:
            conn = context.connection
        if not conn:
            raise ValueError("a connection is necessary for Identifier")

        escaper = Escaping(conn.pgconn)
        encoding = pgconn_encoding(conn.pgconn)

        # Escape every part on its own, then join them with dots.
        quoted = []
        for part in self._obj:
            quoted.append(escaper.escape_identifier(part.encode(encoding)))
        return b".".join(quoted)
373
374
class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    The usual way to pass values to a query is to include placeholders in it
    and to pass the values as `~cursor.execute()` arguments. If however you
    really really need to include a literal value in the query you can use
    this object.

    The string returned by `!as_string()` follows the normal :ref:`adaptation
    rules <types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("foo")
        >>> s2 = sql.Literal("ba'r")
        >>> s3 = sql.Literal(42)
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'foo', 'ba''r', 42

    """

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        value = self._obj
        # Find the text dumper for the wrapped value and let it quote it.
        transformer = Transformer(context)
        dumper = transformer.get_dumper(value, PyFormat.TEXT)
        return dumper.quote(value)
400
401
class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
    ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
    ``%b``).

    The object is useful to generate SQL queries with a variable number of
    arguments.

    :param name: the name of the placeholder; empty for a positional one.
        It must be a string (or `!TypeError` is raised) and cannot contain
        the ``)`` character (or `!ValueError` is raised).
    :param format: the format code to use in the placeholder.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)

        >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)

    """

    def __init__(self, name: str = "", format: PyFormat = PyFormat.AUTO):
        super().__init__(name)
        if not isinstance(name, str):
            raise TypeError(f"expected string as name, got {name!r}")

        # A closing parenthesis would terminate the %(name) part too early.
        if ")" in name:
            raise ValueError(f"invalid name: {name!r}")

        self._format = format

    def __repr__(self) -> str:
        # Only show the arguments which differ from the defaults.
        parts = []
        if self._obj:
            parts.append(repr(self._obj))
        if self._format != PyFormat.AUTO:
            parts.append(f"format={PyFormat(self._format).name}")

        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_string(self, context: Optional[AdaptContext]) -> str:
        """
        Return the placeholder as a string, e.g. ``%s`` or ``%(name)s``.

        The *context* is not used to build the result.
        """
        fmt = self._format
        # Take the member's value explicitly instead of relying on how the
        # member is rendered inside an f-string: since Python 3.11 format()
        # of an Enum mixed with str follows __str__() and would produce
        # "PyFormat.AUTO" rather than the format code.
        # NOTE(review): assumes PyFormat is a str-mixed Enum whose values are
        # the format codes -- confirm against its definition.
        code = fmt.value if isinstance(fmt, PyFormat) else fmt
        return f"%({self._obj}){code}" if self._obj else f"%{code}"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Encode with the connection encoding, or UTF-8 without a connection.
        conn = context.connection if context else None
        enc = pgconn_encoding(conn.pgconn) if conn else "utf-8"
        return self.as_string(context).encode(enc)
457
458
# Literals
# Ready-made `SQL` snippets for the NULL and DEFAULT keywords: being `SQL`
# instances they are rendered verbatim, with no quoting applied.
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")
462