"""
SQL composition utility module
"""

# Copyright (C) 2020-2021 The Psycopg Team

import codecs
import string
from abc import ABC, abstractmethod
from typing import Any, Iterator, Iterable, List, Optional, Sequence, Union

from .pq import Escaping
from .abc import AdaptContext
from .adapt import Transformer, PyFormat
from ._encodings import pgconn_encoding


def quote(obj: Any, context: Optional[AdaptContext] = None) -> str:
    """
    Adapt a Python object to a quoted SQL string.

    Use this function only if you absolutely want to convert a Python string to
    an SQL quoted literal to use e.g. to generate batch SQL and you won't have
    a connection available when you will need to use it.

    This function is relatively inefficient, because it doesn't cache the
    adaptation rules. If you pass a *context* you can adapt the adaptation
    rules used, otherwise only global rules are used.

    """
    return Literal(obj).as_string(context)


class Composable(ABC):
    """
    Abstract base class for objects that can be used to compose an SQL string.

    `!Composable` objects can be passed directly to
    `~psycopg.Cursor.execute()`, `~psycopg.Cursor.executemany()`,
    `~psycopg.Cursor.copy()` in place of the query string.

    `!Composable` objects can be joined using the ``+`` operator: the result
    will be a `Composed` instance containing the objects joined. The operator
    ``*`` is also supported with an integer argument: the result is a
    `!Composed` instance containing the left argument repeated as many times as
    requested.
    """

    def __init__(self, obj: Any):
        self._obj = obj

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._obj!r})"

    @abstractmethod
    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        """
        Return the value of the object as bytes.

        :param context: the context to evaluate the object into.
        :type context: `connection` or `cursor`

        The method is automatically invoked by `~psycopg.Cursor.execute()`,
        `~psycopg.Cursor.executemany()`, `~psycopg.Cursor.copy()` if a
        `!Composable` is passed instead of the query string.

        """
        raise NotImplementedError

    def as_string(self, context: Optional[AdaptContext]) -> str:
        """
        Return the value of the object as string.

        :param context: the context to evaluate the string into.
        :type context: `connection` or `cursor`

        """
        # Decode with the connection encoding if there is one, else utf-8.
        conn = context.connection if context else None
        enc = pgconn_encoding(conn.pgconn) if conn else "utf-8"
        b = self.as_bytes(context)
        if isinstance(b, bytes):
            return b.decode(enc)
        else:
            # as_bytes() may return a buffer object rather than bytes: decode
            # it through the codec, which accepts bytes-like objects.
            return codecs.lookup(enc).decode(b)[0]

    def __add__(self, other: "Composable") -> "Composed":
        if isinstance(other, Composed):
            return Composed([self]) + other
        if isinstance(other, Composable):
            return Composed([self]) + Composed([other])
        else:
            # Let Python try the reflected operation / raise TypeError.
            return NotImplemented

    def __mul__(self, n: int) -> "Composed":
        return Composed([self] * n)

    # NOTE: defining __eq__ without __hash__ makes instances unhashable.
    def __eq__(self, other: Any) -> bool:
        return type(self) is type(other) and self._obj == other._obj

    def __ne__(self, other: Any) -> bool:
        return not self.__eq__(other)


class Composed(Composable):
    """
    A `Composable` object made of a sequence of `!Composable`.

    The object is usually created using `!Composable` operators and methods.
    However it is possible to create a `!Composed` directly specifying a
    sequence of objects as arguments: if they are not `!Composable` they will
    be wrapped in a `Literal`.

    Example::

        >>> comp = sql.Composed(
        ...     [sql.SQL("INSERT INTO "), sql.Identifier("table")])
        >>> print(comp.as_string(conn))
        INSERT INTO "table"

    `!Composed` objects are iterable (so they can be used in `SQL.join` for
    instance).
    """

    _obj: List[Composable]

    def __init__(self, seq: Sequence[Any]):
        # Plain Python values are wrapped so they get adapted and quoted.
        seq = [
            obj if isinstance(obj, Composable) else Literal(obj) for obj in seq
        ]
        super().__init__(seq)

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        return b"".join(obj.as_bytes(context) for obj in self._obj)

    def __iter__(self) -> Iterator[Composable]:
        return iter(self._obj)

    def __add__(self, other: Composable) -> "Composed":
        # Flatten: adding two Composed concatenates their items rather than
        # nesting one inside the other.
        if isinstance(other, Composed):
            return Composed(self._obj + other._obj)
        if isinstance(other, Composable):
            return Composed(self._obj + [other])
        else:
            return NotImplemented

    def join(self, joiner: Union["SQL", str]) -> "Composed":
        """
        Return a new `!Composed` interposing the *joiner* with the `!Composed` items.

        The *joiner* must be a `SQL` or a string which will be interpreted as
        an `SQL`.

        Example::

            >>> fields = sql.Identifier('foo') + sql.Identifier('bar')  # a Composed
            >>> print(fields.join(', ').as_string(conn))
            "foo", "bar"

        """
        if isinstance(joiner, str):
            joiner = SQL(joiner)
        elif not isinstance(joiner, SQL):
            raise TypeError(
                f"Composed.join() argument must be strings or SQL,"
                f" got {joiner!r} instead"
            )

        return joiner.join(self._obj)


class SQL(Composable):
    """
    A `Composable` representing a snippet of SQL statement.

    `!SQL` exposes `join()` and `format()` methods useful to create a template
    where to merge variable parts of a query (for instance field or table
    names).

    The *string* doesn't undergo any form of escaping, so it is not suitable to
    represent variable identifiers or values: you should only use it to pass
    constant strings representing templates or snippets of SQL statements; use
    other objects such as `Identifier` or `Literal` to represent variable
    parts.

    Example::

        >>> query = sql.SQL("SELECT {0} FROM {1}").format(
        ...    sql.SQL(', ').join([sql.Identifier('foo'), sql.Identifier('bar')]),
        ...    sql.Identifier('table'))
        >>> print(query.as_string(conn))
        SELECT "foo", "bar" FROM "table"
    """

    _obj: str
    _formatter = string.Formatter()

    def __init__(self, obj: str):
        # init super() first so that __repr__ works even if the check fails
        super().__init__(obj)
        if not isinstance(obj, str):
            raise TypeError(f"SQL values must be strings, got {obj!r} instead")

    def as_string(self, context: Optional[AdaptContext]) -> str:
        # The wrapped value is already a str: no encode/decode round trip.
        return self._obj

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        enc = "utf-8"
        if context:
            conn = context.connection
            if conn:
                enc = pgconn_encoding(conn.pgconn)
        return self._obj.encode(enc)

    def format(self, *args: Any, **kwargs: Any) -> Composed:
        """
        Merge `Composable` objects into a template.

        :param args: parameters to replace to numbered (``{0}``, ``{1}``) or
            auto-numbered (``{}``) placeholders
        :param kwargs: parameters to replace to named (``{name}``) placeholders
        :return: the union of the `!SQL` string with placeholders replaced
        :rtype: `Composed`
        :raises ValueError: if the template uses format specifications or
            conversions, or mixes automatic and manual field numbering.
        :raises IndexError: if a positional placeholder has no matching
            argument.
        :raises KeyError: if a named placeholder has no matching keyword
            argument.

        The method is similar to the Python `str.format()` method: the string
        template supports auto-numbered (``{}``), numbered (``{0}``,
        ``{1}``...), and named placeholders (``{name}``), with positional
        arguments replacing the numbered placeholders and keywords replacing
        the named ones. However placeholder modifiers (``{0!r}``, ``{0:<10}``)
        are not supported.

        If a `!Composable` object is passed to the template it will be merged
        according to its `as_string()` method. If any other Python object is
        passed, it will be wrapped in a `Literal` object and so escaped
        according to SQL rules.

        Example::

            >>> print(sql.SQL("SELECT * FROM {} WHERE {} = %s")
            ...     .format(sql.Identifier('people'), sql.Identifier('id'))
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE "id" = %s

            >>> print(sql.SQL("SELECT * FROM {tbl} WHERE name = {name}")
            ...     .format(tbl=sql.Identifier('people'), name="O'Rourke")
            ...     .as_string(conn))
            SELECT * FROM "people" WHERE name = 'O''Rourke'

        """
        rv: List[Composable] = []
        # Next index for automatic numbering ({}); set to None as soon as a
        # manually numbered field ({0}) is seen, to forbid mixing the two.
        autonum: Optional[int] = 0
        for pre, name, spec, conv in self._formatter.parse(self._obj):
            if spec:
                raise ValueError("no format specification supported by SQL")
            if conv:
                raise ValueError("no format conversion supported by SQL")
            if pre:
                rv.append(SQL(pre))

            # Trailing literal text with no replacement field after it.
            if name is None:
                continue

            if name.isdigit():
                # autonum > 0 means an automatic field was already consumed.
                if autonum:
                    raise ValueError(
                        "cannot switch from automatic field numbering to manual"
                    )
                rv.append(args[int(name)])
                autonum = None

            elif not name:
                if autonum is None:
                    raise ValueError(
                        "cannot switch from manual field numbering to automatic"
                    )
                rv.append(args[autonum])
                autonum += 1

            else:
                rv.append(kwargs[name])

        # Non-Composable arguments are wrapped in Literal by Composed.
        return Composed(rv)

    def join(self, seq: Iterable[Composable]) -> Composed:
        """
        Join a sequence of `Composable`.

        :param seq: the elements to join.
        :type seq: iterable of `!Composable`

        Use the `!SQL` object's *string* to separate the elements in *seq*.
        Note that `Composed` objects are iterable too, so they can be used as
        argument for this method.

        Example::

            >>> snip = sql.SQL(', ').join(
            ...     sql.Identifier(n) for n in ['foo', 'bar', 'baz'])
            >>> print(snip.as_string(conn))
            "foo", "bar", "baz"
        """
        rv = []
        it = iter(seq)
        try:
            rv.append(next(it))
        except StopIteration:
            # Empty input: return an empty Composed.
            pass
        else:
            # Interpose self before every element after the first.
            for i in it:
                rv.append(self)
                rv.append(i)

        return Composed(rv)


class Identifier(Composable):
    """
    A `Composable` representing an SQL identifier or a dot-separated sequence.

    Identifiers usually represent names of database objects, such as tables or
    fields. PostgreSQL identifiers follow `different rules`__ than SQL string
    literals for escaping (e.g. they use double quotes instead of single).

    .. __: https://www.postgresql.org/docs/current/sql-syntax-lexical.html# \
        SQL-SYNTAX-IDENTIFIERS

    Example::

        >>> t1 = sql.Identifier("foo")
        >>> t2 = sql.Identifier("ba'r")
        >>> t3 = sql.Identifier('ba"z')
        >>> print(sql.SQL(', ').join([t1, t2, t3]).as_string(conn))
        "foo", "ba'r", "ba""z"

    Multiple strings can be passed to the object to represent a qualified name,
    i.e. a dot-separated sequence of identifiers.

    Example::

        >>> query = sql.SQL("SELECT {} FROM {}").format(
        ...     sql.Identifier("table", "field"),
        ...     sql.Identifier("schema", "table"))
        >>> print(query.as_string(conn))
        SELECT "table"."field" FROM "schema"."table"

    """

    _obj: Sequence[str]

    def __init__(self, *strings: str):
        # init super() now to make the __repr__ not explode in case of error
        super().__init__(strings)

        if not strings:
            raise TypeError("Identifier cannot be empty")

        for s in strings:
            if not isinstance(s, str):
                raise TypeError(
                    f"SQL identifier parts must be strings, got {s!r} instead"
                )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({', '.join(map(repr, self._obj))})"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Escaping is performed through the connection's pgconn, so a
        # connection is mandatory here (unlike SQL or Placeholder).
        conn = context.connection if context else None
        if not conn:
            raise ValueError("a connection is necessary for Identifier")
        esc = Escaping(conn.pgconn)
        enc = pgconn_encoding(conn.pgconn)
        escs = [esc.escape_identifier(s.encode(enc)) for s in self._obj]
        return b".".join(escs)


class Literal(Composable):
    """
    A `Composable` representing an SQL value to include in a query.

    Usually you will want to include placeholders in the query and pass values
    as `~cursor.execute()` arguments. If however you really really need to
    include a literal value in the query you can use this object.

    The string returned by `!as_string()` follows the normal :ref:`adaptation
    rules <types-adaptation>` for Python objects.

    Example::

        >>> s1 = sql.Literal("foo")
        >>> s2 = sql.Literal("ba'r")
        >>> s3 = sql.Literal(42)
        >>> print(sql.SQL(', ').join([s1, s2, s3]).as_string(conn))
        'foo', 'ba''r', 42

    """

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        # Pick the text dumper for the object's type from *context* and let
        # it produce the quoted literal.
        tx = Transformer(context)
        dumper = tx.get_dumper(self._obj, PyFormat.TEXT)
        return dumper.quote(self._obj)


class Placeholder(Composable):
    """A `Composable` representing a placeholder for query parameters.

    If the name is specified, generate a named placeholder (e.g. ``%(name)s``,
    ``%(name)b``), otherwise generate a positional placeholder (e.g. ``%s``,
    ``%b``).

    The object is useful to generate SQL queries with a variable number of
    arguments.

    Examples::

        >>> names = ['foo', 'bar', 'baz']

        >>> q1 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(sql.Placeholder() * len(names)))
        >>> print(q1.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%s, %s, %s)

        >>> q2 = sql.SQL("INSERT INTO my_table ({}) VALUES ({})").format(
        ...     sql.SQL(', ').join(map(sql.Identifier, names)),
        ...     sql.SQL(', ').join(map(sql.Placeholder, names)))
        >>> print(q2.as_string(conn))
        INSERT INTO my_table ("foo", "bar", "baz") VALUES (%(foo)s, %(bar)s, %(baz)s)

    """

    # NOTE(review): equality is inherited from Composable and compares only
    # the name, so placeholders differing only by *format* compare equal.

    def __init__(self, name: str = "", format: PyFormat = PyFormat.AUTO):
        super().__init__(name)
        if not isinstance(name, str):
            raise TypeError(f"expected string as name, got {name!r}")

        # A ")" would terminate the "%(name)" placeholder prematurely.
        if ")" in name:
            raise ValueError(f"invalid name: {name!r}")

        self._format = format

    def __repr__(self) -> str:
        parts = []
        if self._obj:
            parts.append(repr(self._obj))
        if self._format != PyFormat.AUTO:
            parts.append(f"format={PyFormat(self._format).name}")

        return f"{self.__class__.__name__}({', '.join(parts)})"

    def as_string(self, context: Optional[AdaptContext]) -> str:
        # Use the enum *value* (the one-letter code, e.g. "s" or "b")
        # explicitly. Since Python 3.11, format() of a mixed-in Enum member
        # that is not a ReprEnum yields "Class.MEMBER" rather than its value,
        # which would render e.g. "%PyFormat.AUTO". Going through PyFormat()
        # also accepts a plain code string passed as *format*.
        code = PyFormat(self._format).value
        return f"%({self._obj}){code}" if self._obj else f"%{code}"

    def as_bytes(self, context: Optional[AdaptContext]) -> bytes:
        conn = context.connection if context else None
        enc = pgconn_encoding(conn.pgconn) if conn else "utf-8"
        return self.as_string(context).encode(enc)


# Literals
NULL = SQL("NULL")
DEFAULT = SQL("DEFAULT")