1"""SQL composition utility module
2"""
3
4# psycopg/sql.py - SQL composition utility module
5#
6# Copyright (C) 2016-2019 Daniele Varrazzo  <daniele.varrazzo@gmail.com>
7# Copyright (C) 2020-2021 The Psycopg Team
8#
9# psycopg2 is free software: you can redistribute it and/or modify it
10# under the terms of the GNU Lesser General Public License as published
11# by the Free Software Foundation, either version 3 of the License, or
12# (at your option) any later version.
13#
14# In addition, as a special exception, the copyright holders give
15# permission to link this program with the OpenSSL library (or with
16# modified versions of OpenSSL that use the same license as OpenSSL),
17# and distribute linked combinations including the two.
18#
19# You must obey the GNU Lesser General Public License in all respects for
20# all of the code used other than OpenSSL.
21#
22# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
23# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
24# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
25# License for more details.
26
27import string
28
29from psycopg2 import extensions as ext
30
31
# Shared formatter instance. Only its parse() method is used (by SQL.format)
# to split a template into literal text and replacement fields.
_formatter = string.Formatter()
33
34
class Composable:
    """
    Abstract base class for objects that can be used to compose an SQL string.

    Instances can be passed directly to `~cursor.execute()`,
    `~cursor.executemany()` and `~cursor.copy_expert()` wherever a query
    string is expected.

    Two `!Composable` objects can be concatenated with the ``+`` operator,
    producing a `Composed` instance that holds both of them. The ``*``
    operator with an integer argument is supported too: it produces a
    `!Composed` instance holding the left operand repeated that many times.
    """
    def __init__(self, wrapped):
        # Subclasses validate their input before handing it over here.
        self._wrapped = wrapped

    def __repr__(self):
        cls_name = type(self).__name__
        return f"{cls_name}({self._wrapped!r})"

    def as_string(self, context):
        """
        Return the string value of the object.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        `~cursor.execute()`, `~cursor.executemany()` and
        `~cursor.copy_expert()` call this method automatically when they
        receive a `!Composable` instead of a query string. Subclasses must
        override it: the base implementation raises `NotImplementedError`.
        """
        raise NotImplementedError

    def __add__(self, other):
        if not isinstance(other, Composable):
            return NotImplemented
        # A Composed operand gets its items spliced in; any other
        # Composable is wrapped first so both sides are Composed.
        if isinstance(other, Composed):
            return Composed([self]) + other
        return Composed([self]) + Composed([other])

    def __mul__(self, n):
        repeated = [self] * n
        return Composed(repeated)

    def __eq__(self, other):
        # Objects of different concrete classes never compare equal, even
        # when they wrap the same value.
        if type(self) is not type(other):
            return False
        return self._wrapped == other._wrapped

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal
84
85
class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    Such objects are usually produced by the `!Composable` operators and
    methods, but a `!Composed` can also be built directly by passing it a
    sequence of `!Composable` objects.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("insert into "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        insert into "table"

    `!Composed` objects are iterable, so they can be passed for instance to
    `SQL.join`.
    """
    def __init__(self, seq):
        def checked(item):
            # Reject anything that cannot render itself via as_string().
            if not isinstance(item, Composable):
                raise TypeError(
                    f"Composed elements must be Composable, got {item!r} instead")
            return item

        super().__init__([checked(item) for item in seq])

    @property
    def seq(self):
        """The list of the content of the `!Composed` (a copy)."""
        return self._wrapped[:]

    def as_string(self, context):
        # Render every element in order and concatenate with no separator.
        return ''.join(part.as_string(context) for part in self._wrapped)

    def __iter__(self):
        return iter(self._wrapped)

    def __add__(self, other):
        if isinstance(other, Composed):
            tail = other._wrapped
        elif isinstance(other, Composable):
            tail = [other]
        else:
            return NotImplemented
        return Composed(self._wrapped + tail)

    def join(self, joiner):
        """
        Return a new `!Composed` interposing the *joiner* with the `!Composed` items.

        The *joiner* must be a `SQL`, or a string, which is then wrapped in
        an `SQL`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar')  # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        """
        if not isinstance(joiner, (str, SQL)):
            raise TypeError(
                "Composed.join() argument must be a string or an SQL")
        if isinstance(joiner, str):
            joiner = SQL(joiner)
        return joiner.join(self)
157
158
class SQL(Composable):
    """
    A `Composable` representing a snippet of SQL statement.

    `!SQL` provides `join()` and `format()` methods, which make it a handy
    template into which the variable parts of a query (for instance field or
    table names) can be merged.

    The *string* is emitted verbatim, without any escaping, so it must not be
    used for variable identifiers or values. Only pass it constant strings
    representing templates or snippets of SQL statements, and represent the
    variable parts with other objects such as `Identifier` or `Literal`.

    Example::

        >>> query = sql.SQL("select {0} from {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        select "foo", "bar" from "table"
    """
    def __init__(self, string):
        if isinstance(string, str):
            super().__init__(string)
        else:
            raise TypeError("SQL values must be strings")

    @property
    def string(self):
        """The string wrapped by the `!SQL` object."""
        return self._wrapped

    def as_string(self, context):
        # Raw SQL text: the context is irrelevant, nothing is escaped.
        return self._wrapped

    def format(self, *args, **kwargs):
        """
        Merge `Composable` objects into a template.

        :param `Composable` args: parameters to replace to numbered
            (``{0}``, ``{1}``) or auto-numbered (``{}``) placeholders
        :param `Composable` kwargs: parameters to replace to named (``{name}``)
            placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`

        Works like Python's `str.format()`. The template may contain
        auto-numbered (``{}``), numbered (``{0}``, ``{1}``...) and named
        (``{name}``) placeholders. Positional arguments fill the numbered
        ones and keyword arguments fill the named ones. Placeholder modifiers
        (``{0!r}``, ``{0:<10}``) are not supported, and only `!Composable`
        objects may be passed in.

        Example::

            >>> print(sql.SQL("select * from {} where {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            select * from "people" where "id" = %s

            >>> print(sql.SQL("select * from {tbl} where {pkey} = %s")
            ...     .format(tbl=sql.Identifier('people'), pkey=sql.Identifier('id'))
            ...     .as_string(conn))
            select * from "people" where "id" = %s

        """
        parts = []
        auto_index = 0        # next positional index consumed by "{}"
        manual_used = False   # set once a "{0}"-style field has been seen

        for literal, field, spec, conversion in _formatter.parse(self._wrapped):
            if spec:
                raise ValueError("no format specification supported by SQL")
            if conversion:
                raise ValueError("no format conversion supported by SQL")
            if literal:
                parts.append(SQL(literal))
            if field is None:
                # Literal text with no replacement field after it.
                continue

            if field.isdigit():
                if auto_index:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual")
                parts.append(args[int(field)])
                manual_used = True
            elif field == "":
                if manual_used:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic")
                parts.append(args[auto_index])
                auto_index += 1
            else:
                parts.append(kwargs[field])

        return Composed(parts)

    def join(self, seq):
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :type seq: iterable of `!Composable`

        The `!SQL` object's *string* is placed between consecutive elements
        of *seq*. Since `Composed` objects are iterable, they are valid
        arguments as well.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """
        parts = []
        for position, item in enumerate(seq):
            # The separator goes before every element except the first.
            if position:
                parts.append(self)
            parts.append(item)
        return Composed(parts)
288
289
class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers usually represent names of database objects, such as tables or
    fields. PostgreSQL identifiers follow `different rules`__ than SQL string
    literals for escaping (e.g. they use double quotes instead of single).

    .. __: https://www.postgresql.org/docs/current/static/
        sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    Multiple strings can be passed to the object to represent a qualified name,
    i.e. a dot-separated sequence of identifiers.

    Example::

        >>> query = sql.SQL("select {} from {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        select "table"."field" from "schema"."table"

    """
    def __init__(self, *strings):
        """
        :param strings: one or more name components.
        :raise TypeError: if no string is passed, or if any component is not
            a `!str`.
        """
        if not strings:
            raise TypeError("Identifier cannot be empty")

        for s in strings:
            if not isinstance(s, str):
                raise TypeError("SQL identifier parts must be strings")

        # Stored as the tuple of components, in the order given.
        super().__init__(strings)

    @property
    def strings(self):
        """A tuple with the strings wrapped by the `Identifier`."""
        return self._wrapped

    @property
    def string(self):
        """
        The string wrapped by the `Identifier`.

        :raise AttributeError: if the `!Identifier` wraps more than one
            string (i.e. it represents a qualified name); use `strings`
            in that case.
        """
        if len(self._wrapped) == 1:
            return self._wrapped[0]
        raise AttributeError("the Identifier wraps more than one string")

    def __repr__(self):
        args = ', '.join(map(repr, self._wrapped))
        return f"{self.__class__.__name__}({args})"

    def as_string(self, context):
        # Quote every component separately, then join them with '.' to form
        # the (possibly qualified) name.
        return '.'.join(ext.quote_ident(s, context) for s in self._wrapped)
351
352
class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    The usual approach is to put placeholders in the query and pass the
    values as `~cursor.execute()` arguments. When a literal value really has
    to be embedded in the query text, use this object instead.

    `!as_string()` renders the value according to the normal
    :ref:`adaptation rules <python-types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("foo")
        >>> s2 = sql.Literal("ba'r")
        >>> s3 = sql.Literal(42)
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'foo', 'ba''r', 42

    """
    @property
    def wrapped(self):
        """The object wrapped by the `!Literal`."""
        return self._wrapped

    def as_string(self, context):
        # Resolve the connection. It is passed to the adapter (if the adapter
        # supports prepare()) and supplies the encoding used to decode bytes.
        if isinstance(context, ext.connection):
            conn = context
        elif isinstance(context, ext.cursor):
            conn = context.connection
        else:
            raise TypeError("context must be a connection or a cursor")

        adapter = ext.adapt(self._wrapped)
        if hasattr(adapter, 'prepare'):
            adapter.prepare(conn)

        quoted = adapter.getquoted()
        if not isinstance(quoted, bytes):
            return quoted
        # Bytes are decoded with the Python codec mapped to the connection's
        # encoding name.
        return quoted.decode(ext.encodings[conn.encoding])
396
397
class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    With a name, a named placeholder is generated (e.g. ``%(name)s``);
    without one, a positional placeholder is generated (e.g. ``%s``).

    This is handy for building SQL queries that take a variable number of
    arguments.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("insert into table ({}) values ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        insert into table ("foo", "bar", "baz") values (%s, %s, %s)

        >>> q2 = sql.SQL("insert into table ({}) values ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        insert into table ("foo", "bar", "baz") values (%(foo)s, %(bar)s, %(baz)s)

    """

    def __init__(self, name=None):
        if name is not None:
            if not isinstance(name, str):
                raise TypeError(f"expected string or None as name, got {name!r}")
            # A ')' would terminate the "%(name)s" marker early.
            if ')' in name:
                raise ValueError(f"invalid name: {name!r}")

        super().__init__(name)

    @property
    def name(self):
        """The name of the `!Placeholder`."""
        return self._wrapped

    def __repr__(self):
        inner = '' if self._wrapped is None else repr(self._wrapped)
        return f"{self.__class__.__name__}({inner})"

    def as_string(self, context):
        if self._wrapped is None:
            return "%s"
        return f"%({self._wrapped})s"
451
452
# Ready-made keyword snippets. They are SQL objects, emitted verbatim
# (not escaped Literal values).
DEFAULT = SQL("DEFAULT")
NULL = SQL("NULL")
456