# -*- coding: utf-8 -*-
# :Project: pglast -- Serialization logic
# :Created: mer 02 ago 2017 15:46:11 CEST
# :Author: Lele Gaifax <lele@metapensiero.it>
# :License: GNU General Public License version 3 or later
# :Copyright: © 2017, 2018, 2019 Lele Gaifax
#

from contextlib import contextmanager
from io import StringIO
from re import match

from .error import Error
from .node import List, Node, Scalar
from .keywords import RESERVED_KEYWORDS
from .parser import parse_plpgsql, parse_sql


NODE_PRINTERS = {}
"Registry of specialized printers, keyed by their `tag`."


SPECIAL_FUNCTIONS = {}
"Registry of specialized function printers, keyed by their qualified name."


class PrinterAlreadyPresentError(Error):
    "Exception raised trying to register another function for a tag already present."


def get_printer_for_node_tag(parent_node_tag, node_tag):
    """Get specific printer implementation for given `node_tag`.

    If there is a more specific printer for it, when it's inside a particular
    `parent_node_tag`, return that instead.

    :param parent_node_tag: the tag of the parent node, possibly ``None``
    :param node_tag: the tag of the node to be printed
    :returns: the printer callable registered in :data:`NODE_PRINTERS`
    :raises NotImplementedError: when no printer is registered for `node_tag`
    """

    try:
        return NODE_PRINTERS[(parent_node_tag, node_tag)]
    except KeyError:
        pass

    try:
        return NODE_PRINTERS[node_tag]
    except KeyError:
        # The KeyError is an implementation detail of the lookup: do not chain it
        raise NotImplementedError("Printer for node %r is not implemented yet"
                                  % (node_tag,)) from None


def node_printer(*node_tags, override=False):
    r"""Decorator to register a specific printer implementation for given `node_tag`.

    :param \*node_tags: one or two node tags
    :param bool override:
           when ``True`` the function will be registered even if already present in the
           :data:`NODE_PRINTERS` registry
    :raises ValueError:
            when the decorator is applied and `node_tags` does not contain exactly one or
            two items
    :raises PrinterAlreadyPresentError:
            when the decorator is applied, `override` is ``False`` and a printer is already
            registered for any of the resulting keys; in that case the registry is left
            untouched

    When `node_tags` contains a single item then the decorated function is the *generic* one,
    and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it must
    contain two elements: the first may be either a scalar value or a sequence of parent tags,
    and the function will be registered under the key ``(parent_tag, tag)``.
    """

    def decorator(impl):
        if len(node_tags) == 1:
            parent_tags = (None,)
            tag = node_tags[0]
        elif len(node_tags) == 2:
            parent_tags, tag = node_tags
            if not isinstance(parent_tags, (list, tuple)):
                parent_tags = (parent_tags,)
        else:
            raise ValueError("Must specify one or two tags, got %d instead" % len(node_tags))

        keys = [tag if parent_tag is None else (parent_tag, tag)
                for parent_tag in parent_tags]

        # Validate every key before touching the registry, so that a conflict on a later
        # parent tag does not leave the earlier ones already registered
        if not override:
            for key in keys:
                if key in NODE_PRINTERS:
                    # Wrap `key` in a 1-tuple: it may itself be a (parent_tag, tag) tuple,
                    # which the % operator would otherwise take as multiple arguments
                    raise PrinterAlreadyPresentError("A printer is already registered for tag %r"
                                                     % (key,))

        for key in keys:
            NODE_PRINTERS[key] = impl
        return impl
    return decorator


def special_function(name, override=False):
    """Decorator to declare a particular PostgreSQL function `name` as *special*, with a
    specific printer.

    :param str name: the qualified name of the PG function
    :param bool override: when ``True`` the function will be registered even if already
                          present in the :data:`SPECIAL_FUNCTIONS` registry
    :raises PrinterAlreadyPresentError:
            when the decorator is applied, `override` is ``False`` and a printer is already
            registered for `name`
    """

    def decorator(impl):
        if not override and name in SPECIAL_FUNCTIONS:
            raise PrinterAlreadyPresentError("A printer is already registered for function %r"
                                             % (name,))
        SPECIAL_FUNCTIONS[name] = impl
        return impl
    return decorator


class OutputStream(StringIO):
    "A stream that has a concept of a *pending separator* between consecutive writes."

    def __init__(self):
        super().__init__()
        self.pending_separator = False
        self.last_emitted_char = ' '

    def separator(self):
        """Possibly insert a single whitespace character before next output.

        When the last character written is not a space, set the `pending_separator` flag to
        ``True``: the next call to :meth:`write` will prepend a single whitespace to its
        argument if that begins with an alphanumeric character or with one of the *special
        characters* recognized by :meth:`maybe_write_space`.
        """

        if not self.last_emitted_char.isspace():
            self.pending_separator = True

    def maybe_write_space(self, nextc=None, _special_chars=frozenset("""_*+/-"'=<>$@""")):
        """Emit a space when needed.

        :param nextc: either None or the next character that will be written
        :return: the number of characters written to the stream

        If the last character written was not a space, and `nextc` is either ``None``, an
        alphanumeric character or a *special character*, then emit a single whitespace.
        """

        if not self.last_emitted_char.isspace():
            if nextc is None or nextc.isalnum() or nextc in _special_chars:
                return self.write(' ')
        return 0

    def write(self, s):
        """Emit string `s`.

        :param str s: the string to emit
        :return: the number of characters of `s` written to the stream, not counting the
                 possible separator space

        When `s` is not empty and `pending_separator` is ``True``, call
        :meth:`maybe_write_space` with the first character of `s` (unless `s` is itself a
        single space), so that a whitespace gets emitted when needed, and then reset
        `pending_separator` to ``False``.
        """

        count = 0
        if s:
            if self.pending_separator:
                if s != ' ':
                    self.maybe_write_space(s[0])
                self.pending_separator = False
            count = super().write(s)
            self.last_emitted_char = s[-1]

        return count

    def writes(self, s):
        "Shortcut for ``self.write(s); self.separator()``."

        count = self.write(s)
        self.separator()
        return count

    def swrite(self, s):
        "Shortcut for ``self.maybe_write_space(s[0]); self.write(s)``."

        count = self.maybe_write_space(s[0])
        return count + self.write(s)

    def swrites(self, s):
        "Shortcut for ``self.swrite(s); self.separator()``."

        count = self.swrite(s)
        self.separator()
        return count


class RawStream(OutputStream):
    """Basic SQL parse tree writer.

    :param int expression_level:
           start the stream with the given expression level depth, 0 by default
    :param bool separate_statements:
           ``True`` by default, tells whether multiple statements shall be separated by an
           empty line
    :param bool special_functions:
           ``False`` by default, when ``True`` some functions are treated in a special way and
           emitted as equivalent constructs
    :param bool comma_at_eoln:
           ``False`` by default, when ``True`` put the comma right after each item instead of
           at the beginning of the *next* item line

    This augments :class:`OutputStream` and implements the basic machinery needed to serialize
    the *parse tree* produced by :func:`~.parser.parse_sql()` back to a textual representation,
    without any adornment.
    """

    def __init__(self, expression_level=0, separate_statements=True, special_functions=False,
                 comma_at_eoln=False):
        super().__init__()
        self.expression_level = expression_level
        self.separate_statements = separate_statements
        self.special_functions = special_functions
        self.comma_at_eoln = comma_at_eoln
        self.current_column = 0

    def __call__(self, sql, plpgsql=False):
        """Main entry point: execute :meth:`print_node` on each statement in `sql`.

        :param sql: either the source SQL in textual form, a :class:`~.node.Node` instance or
                    a :class:`~.node.List` instance
        :param bool plpgsql: whether `sql` is really a ``plpgsql`` statement
        :returns: a string with the equivalent SQL obtained by serializing the syntax tree
        :raises ValueError: when `sql` is of none of the accepted types
        """

        if isinstance(sql, str):
            sql = Node(parse_plpgsql(sql) if plpgsql else parse_sql(sql))
        elif isinstance(sql, Node):
            sql = [sql]
        elif not isinstance(sql, List):
            raise ValueError("Unexpected value for 'sql', must be either a string,"
                             " a Node instance or a List instance, got %r" % type(sql))

        first = True
        for statement in sql:
            if first:
                first = False
            else:
                self.write(';')
                self.newline()
                if self.separate_statements:
                    self.newline()
            self.print_node(statement)
        return self.getvalue()

    def dedent(self):
        "Do nothing, shall be overridden by the prettifier subclass."

    def get_printer_for_function(self, name):
        """Look for a specific printer for function `name` in :data:`SPECIAL_FUNCTIONS`.

        :param str name: the qualified name of the function
        :returns: either ``None`` or a callable

        When the option `special_functions` is ``True``, return the printer function associated
        with `name`, if present. In all other cases return ``None``.
        """

        if self.special_functions:
            return SPECIAL_FUNCTIONS.get(name)
        return None

    def indent(self, amount=0, relative=True):
        "Do nothing, shall be overridden by the prettifier subclass."

    def newline(self):
        "Emit a single whitespace, shall be overridden by the prettifier subclass."

        self.separator()

    def space(self, count=1):
        "Emit a single whitespace, shall be overridden by the prettifier subclass."

        self.separator()

    @contextmanager
    def push_indent(self, amount=0, relative=True):
        "Create a no-op context manager, shall be overridden by the prettifier subclass."

        yield

    @contextmanager
    def expression(self):
        """Create a context manager that will wrap subexpressions within parentheses.

        The `expression_level` is restored even when the wrapped block raises an exception;
        in that case the closing parenthesis is not emitted.
        """

        self.expression_level += 1
        try:
            if self.expression_level > 1:
                self.write('(')
            yield
            if self.expression_level > 1:
                self.write(')')
        finally:
            self.expression_level -= 1

    def _concat_nodes(self, nodes, sep=' ', are_names=False):
        """Concatenate given `nodes`, using `sep` as the separator.

        :param nodes: a sequence of nodes
        :param str sep: the separator between them
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :returns: a string

        Use a temporary :class:`RawStream` instance to print the list of nodes and return the
        result.
        """

        # Propagate `special_functions` so that the compact rendering of a list agrees with
        # the one emitted by this stream; `comma_at_eoln` is deliberately not propagated, as it
        # would alter the spacing of the single-line form
        rawstream = RawStream(expression_level=self.expression_level,
                              special_functions=self.special_functions)
        rawstream.print_list(nodes, sep, are_names=are_names, standalone_items=False)
        return rawstream.getvalue()

    def _write_quoted_string(self, s):
        "Emit the `s` as a single-quoted literal constant."

        self.write("'%s'" % s.replace("'", "''"))

    def _print_scalar(self, node, is_name, is_symbol):
        "Print the scalar `node`, special-casing string literals."

        value = node.value
        if is_symbol:
            self.write(value)
        elif is_name:
            # The `scalar` represent a name of a column/table/alias: when any of its
            # characters is not a lower case letter, a digit or underscore, it must be
            # double quoted. Use \Z rather than $, that would also match right before a
            # trailing newline, leaving such a name unquoted
            if not match(r'[a-z_][a-z0-9_]*\Z', value) or value in RESERVED_KEYWORDS:
                value = '"%s"' % value.replace('"', '""')
            self.write(value)
        elif node.parent_node.node_tag == 'String':
            self._write_quoted_string(value)
        else:
            self.write(str(value))

    def print_name(self, nodes, sep='.'):
        "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed."

        if isinstance(nodes, (List, list)):
            self.print_list(nodes, sep, standalone_items=False, are_names=True)
        else:
            self.print_node(nodes, is_name=True)

    def print_symbol(self, nodes, sep='.'):
        "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed."

        if isinstance(nodes, (List, list)):
            self.print_list(nodes, sep, standalone_items=False, are_names=True, is_symbol=True)
        else:
            self.print_node(nodes, is_name=True, is_symbol=True)

    def print_node(self, node, is_name=False, is_symbol=False):
        """Lookup the specific printer for the given `node` and execute it.

        :param node: an instance of :class:`~.node.Node` or :class:`~.node.Scalar`
        :param bool is_name:
               whether this is a *name* of something, that may need to be double quoted
        :param bool is_symbol:
               whether this is the name of an *operator*, should not be double quoted
        """

        if isinstance(node, Scalar):
            self._print_scalar(node, is_name, is_symbol)
        else:
            parent_node_tag = node.parent_node and node.parent_node.node_tag
            printer = get_printer_for_node_tag(parent_node_tag, node.node_tag)
            if is_name and node.node_tag == 'String':
                printer(node, self, is_name=is_name, is_symbol=is_symbol)
            else:
                printer(node, self)
        self.separator()

    def _print_items(self, items, sep, newline, are_names=False, is_symbol=False):
        "Print each of the `items`, emitting `sep` and possibly a newline between them."

        last = len(items) - 1
        for idx, item in enumerate(items):
            if idx > 0:
                if sep == ',' and self.comma_at_eoln:
                    self.write(sep)
                    if newline:
                        self.newline()
                    else:
                        self.write(' ')
                else:
                    if not are_names:
                        if newline:
                            self.newline()
                    if sep:
                        self.write(sep)
                        if sep != '.':
                            self.write(' ')
            # Only the last element of a symbol (e.g. the operator after its schema) is
            # printed verbatim
            self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last)

    def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None,
                   are_names=False, is_symbol=False):
        """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`.

        :param nodes: a sequence of :class:`~.node.Node` instances
        :param str sep: the separator between them
        :param int relative_indent:
               if given, the relative amount of indentation to apply before the first item, by
               default computed automatically from the length of the separator `sep`
        :param bool standalone_items: whether a newline will be emitted before each item
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :param bool is_symbol:
               whether the nodes are actually a *symbol* such as an *operator name*, in which
               case the last one must be printed verbatim (e.g. ``"MySchema".===``)
        """

        if relative_indent is None:
            if are_names or is_symbol:
                relative_indent = 0
            else:
                relative_indent = (-(len(sep) + 1)
                                   if sep and (sep != ',' or not self.comma_at_eoln)
                                   else 0)

        if standalone_items is None:
            standalone_items = not all(isinstance(n, Node)
                                       and n.node_tag in ('A_Const', 'ColumnRef',
                                                          'SetToDefault', 'RangeVar')
                                       for n in nodes)

        with self.push_indent(relative_indent):
            self._print_items(nodes, sep, standalone_items, are_names=are_names,
                              is_symbol=is_symbol)

    def print_lists(self, lists, sep=',', relative_indent=None, standalone_items=None,
                    are_names=False, sublist_open='(', sublist_close=')', sublist_sep=',',
                    sublist_relative_indent=None):
        """Execute :meth:`print_list` on all the `lists` items.

        :param lists: a sequence of sequences of :class:`~.node.Node` instances
        :param str sep: passed as is to :meth:`print_list`
        :param int relative_indent: passed as is to :meth:`print_list`
        :param bool standalone_items: passed as is to :meth:`print_list`
        :param bool are_names: passed as is to :meth:`print_list`
        :param str sublist_open: the string that will be emitted before each sublist
        :param str sublist_close: the string that will be emitted after each sublist
        :param str sublist_sep: the separator between each sublist
        :param int sublist_relative_indent:
               if given, the relative amount of indentation to apply before the first sublist,
               by default computed automatically from the length of the separator `sublist_sep`
        """

        if sublist_relative_indent is None:
            sublist_relative_indent = (-(len(sublist_sep) + 1)
                                       if sublist_sep and (sublist_sep != ','
                                                           or not self.comma_at_eoln)
                                       else 0)

        with self.push_indent(sublist_relative_indent):
            self.write(sublist_open)
            first = True
            for lst in lists:
                if first:
                    first = False
                else:
                    if self.comma_at_eoln:
                        self.write(sublist_sep)
                        self.newline()
                        self.write(sublist_open)
                    else:
                        self.newline()
                        self.write(sublist_sep)
                        self.write(' ')
                        self.write(sublist_open)
                self.print_list(lst, sep, relative_indent, standalone_items, are_names)
                self.write(sublist_close)


class IndentedStream(RawStream):
    r"""Indented SQL parse tree writer.

    :param int compact_lists_margin:
           an integer value that, if given, is used to print lists on a single line, when they
           do not exceed the given margin on the right
    :param int split_string_literals_threshold:
           an integer value that, if given, is used as the threshold beyond that a string
           literal gets split in successive chunks of that length
    :param \*\*options: other options accepted by :class:`RawStream`

    This augments :class:`RawStream` to emit a prettified representation of a *parse tree*.
    """

    def __init__(self, compact_lists_margin=None, split_string_literals_threshold=None,
                 **options):
        super().__init__(**options)
        self.compact_lists_margin = compact_lists_margin
        self.split_string_literals_threshold = split_string_literals_threshold
        self.current_indent = 0
        self.indentation_stack = []

    def dedent(self):
        "Pop the indentation level from the stack and set `current_indent` to that."

        self.current_indent = self.indentation_stack.pop()

    def indent(self, amount=0, relative=True):
        """Push current indentation level to the stack, then set it adding `amount` to the
        `current_column` if `relative` is ``True`` otherwise to `current_indent`.
        """

        self.indentation_stack.append(self.current_indent)
        base_indent = (self.current_column if relative else self.current_indent)
        assert base_indent + amount >= 0
        self.current_indent = base_indent + amount

    @contextmanager
    def push_indent(self, amount=0, relative=True):
        """Create a context manager that calls :meth:`indent` and :meth:`dedent` around a block
        of code.

        This is just an helper to simplify code that adjust the indentation level:

        .. code-block:: python

          with output.push_indent(4):
              # code that emits something with the new indentation

        The previous indentation level is restored even when the block raises an exception.
        """

        if self.pending_separator and relative:
            # Account for the whitespace that the next write will emit
            amount += 1
        if self.current_column == 0 and relative:
            # At the start of a line the indentation has not been emitted yet
            amount += self.current_indent
        self.indent(amount, relative)
        try:
            yield
        finally:
            self.dedent()

    def newline(self):
        "Emit a newline."

        self.write('\n')

    def space(self, count=1):
        "Emit consecutive spaces."

        self.write(' '*count)

    def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None,
                   are_names=False, is_symbol=False):
        """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`.

        :param nodes: a sequence of :class:`~.node.Node` instances
        :param str sep: the separator between them
        :param int relative_indent:
               if given, the relative amount of indentation to apply before the first item, by
               default computed automatically from the length of the separator `sep`
        :param bool standalone_items: whether a newline will be emitted before each item
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :param bool is_symbol:
               whether the nodes are actually an *operator name*, in which case the last one
               must be printed verbatim (such as ``"MySchema".===``)
        """

        if standalone_items is None:
            clm = self.compact_lists_margin
            if clm is not None and clm > 0:
                # Try the single-line rendering first, and use it if it fits the margin
                rawlist = self._concat_nodes(nodes, sep, are_names)
                if self.current_column + len(rawlist) < clm:
                    self.write(rawlist)
                    return

            standalone_items = not all(
                (isinstance(n, Node)
                 and n.node_tag in ('A_Const', 'ColumnRef', 'SetToDefault', 'RangeVar'))
                for n in nodes)

        if (((sep != ',' or not self.comma_at_eoln)
             and len(nodes) > 1
             and len(sep) > 1
             and relative_indent is None
             and not are_names
             and not is_symbol
             and standalone_items)):
            self.write(' '*(len(sep) + 1))  # separator added automatically

        super().print_list(nodes, sep, relative_indent, standalone_items, are_names, is_symbol)

    def _write_quoted_string(self, s):
        """Possibly split `s` string in successive chunks.

        When the ``split_string_literals_threshold`` option is greater than 0 and the length of
        `s` exceeds that value, split the string into multiple chunks, each emitted as a
        separate quoted literal on its own line. When `s` contains newlines the literal is
        emitted with the ``E`` prefix, escaping backslashes and newlines.
        """

        sslt = self.split_string_literals_threshold
        if sslt is None or sslt <= 0:
            super()._write_quoted_string(s)
        else:
            multiline = '\n' in s
            if multiline:
                self.write('E')
            with self.push_indent():
                while True:
                    chunk = s[:sslt]
                    s = s[sslt:]
                    # Avoid splitting on backslash; stop when the string is exhausted, as it
                    # may legitimately end with a backslash
                    while chunk.endswith("\\") and s:
                        chunk += s[0]
                        s = s[1:]
                    chunk = chunk.replace("'", "''")
                    if multiline:
                        chunk = chunk.replace("\\", "\\\\")
                        chunk = chunk.replace("\n", "\\n")
                    self.write("'%s'" % chunk)
                    if s:
                        self.newline()
                    else:
                        break

    def write(self, s):
        """Write string `s` to the stream, adjusting the `current_column` accordingly.

        :param str s: the string to emit
        :return: the number of characters written to the stream

        If `s` is a newline character (``\\n``) set `current_column` to 0. Otherwise when
        `current_column` is 0 and `current_indent` is greater than 0 emit a number of
        whitespaces *before* emitting `s`, to indent it as expected.
        """

        if s and s != '\n' and self.current_column == 0 and self.current_indent > 0:
            self.current_column = super().write(' ' * self.current_indent)

        count = super().write(s)
        if s == '\n':
            self.current_column = 0
        else:
            self.current_column += count

        return count