1# -*- coding: utf-8 -*-
2# :Project:   pglast -- Serialization logic
3# :Created:   mer 02 ago 2017 15:46:11 CEST
4# :Author:    Lele Gaifax <lele@metapensiero.it>
5# :License:   GNU General Public License version 3 or later
6# :Copyright: © 2017, 2018, 2019 Lele Gaifax
7#
8
9from contextlib import contextmanager
10from io import StringIO
11from re import match
12
13from .error import Error
14from .node import List, Node, Scalar
15from .keywords import RESERVED_KEYWORDS
16from .parser import parse_plpgsql, parse_sql
17
18
#: Registry of specialized printers, keyed by their `tag`.
NODE_PRINTERS = dict()


#: Registry of specialized function printers, keyed by their qualified name.
SPECIAL_FUNCTIONS = dict()
25
26
class PrinterAlreadyPresentError(Error):
    """Raised when registering a printer under a key that already has one.

    The registering decorators accept ``override=True`` to replace the existing printer
    instead.
    """
29
30
def get_printer_for_node_tag(parent_node_tag, node_tag):
    """Get specific printer implementation for given `node_tag`.

    :param parent_node_tag: the tag of the parent node, or ``None``
    :param str node_tag: the tag of the node to be printed
    :returns: the printer callable registered in :data:`NODE_PRINTERS`
    :raises NotImplementedError: when no printer is registered for `node_tag`

    If there is a more specific printer for it, registered under the key
    ``(parent_node_tag, node_tag)``, return that instead of the generic one.
    """

    try:
        return NODE_PRINTERS[(parent_node_tag, node_tag)]
    except KeyError:
        pass

    try:
        return NODE_PRINTERS[node_tag]
    except KeyError:
        # Suppress the chained KeyError: it only adds noise to the traceback
        raise NotImplementedError("Printer for node %r is not implemented yet"
                                  % (node_tag,)) from None
46
47
def node_printer(*node_tags, override=False):
    r"""Decorator to register a specific printer implementation for given `node_tag`.

    :param \*node_tags: one or two node tags
    :param bool override:
           when ``True`` the function will be registered even if already present in the
           :data:`NODE_PRINTERS` registry
    :raises ValueError: when the decorator is applied and `node_tags` does not contain one
                        or two items
    :raises PrinterAlreadyPresentError: when the decorator is applied, `override` is
                                        ``False`` and a printer is already registered under
                                        one of the resulting keys

    When `node_tags` contains a single item then the decorated function is the *generic* one,
    and it will be registered in :data:`NODE_PRINTERS` with that key alone. Otherwise it must
    contain two elements: the first may be either a scalar value or a sequence of parent tags,
    and the function will be registered under the key ``(parent_tag, tag)``.
    """

    def decorator(impl):
        if len(node_tags) == 1:
            parent_tags = (None,)
            tag = node_tags[0]
        elif len(node_tags) == 2:
            parent_tags, tag = node_tags
            if not isinstance(parent_tags, (list, tuple)):
                parent_tags = (parent_tags,)
        else:
            raise ValueError("Must specify one or two tags, got %d instead" % len(node_tags))

        for parent_tag in parent_tags:
            t = tag if parent_tag is None else (parent_tag, tag)
            if not override and t in NODE_PRINTERS:
                # Wrap `t` in a 1-tuple: when it is itself a (parent_tag, tag) tuple a bare
                # "%" would treat its items as separate format arguments and fail
                raise PrinterAlreadyPresentError("A printer is already registered for tag %r"
                                                 % (t,))
            NODE_PRINTERS[t] = impl
        return impl
    return decorator
81
82
def special_function(name, override=False):
    """Decorator to declare a particular PostgreSQL function `name` as *special*, with a
    specific printer.

    :param str name: the qualified name of the PG function
    :param bool override: when ``True`` the function will be registered even if already
                          present in the :data:`SPECIAL_FUNCTIONS` registry
    :returns: a decorator that registers the decorated function in
              :data:`SPECIAL_FUNCTIONS` and returns it unchanged
    :raises PrinterAlreadyPresentError: when the decorator is applied, `override` is
                                        ``False`` and `name` is already registered
    """

    def decorator(impl):
        if not override and name in SPECIAL_FUNCTIONS:
            # 1-tuple keeps the formatting correct whatever the type of `name`
            raise PrinterAlreadyPresentError("A printer is already registered for function %r"
                                             % (name,))
        SPECIAL_FUNCTIONS[name] = impl
        return impl
    return decorator
99
100
class OutputStream(StringIO):
    """A stream that has a concept of a *pending separator* between consecutive writes.

    Besides the usual :class:`~io.StringIO` machinery it remembers the last character written,
    so that a single whitespace can be inserted lazily between two tokens only when needed.
    """

    # Characters that, like alphanumerics, require a space between them and a preceding
    # non-space character. An immutable class constant, used as the default value of the
    # `_special_chars` argument of maybe_write_space() instead of a mutable set.
    _SPECIAL_CHARS = frozenset("""_*+/-"'=<>$@""")

    def __init__(self):
        super().__init__()
        # When True, the next write() may prepend a single space (see maybe_write_space())
        self.pending_separator = False
        # Starts as a space, so that nothing gets prepended to the very first write
        self.last_emitted_char = ' '

    def separator(self):
        """Possibly insert a single whitespace character before next output.

        When the last character written is not a space, set the `pending_separator` flag to
        ``True``: the next call to :meth:`write` will then prepend a single whitespace to its
        argument when that begins with an alphanumeric or a *special* character (see
        :meth:`maybe_write_space`).
        """

        if not self.last_emitted_char.isspace():
            self.pending_separator = True

    def maybe_write_space(self, nextc=None, _special_chars=_SPECIAL_CHARS):
        """Emit a space when needed.

        :param nextc: either None or the next character that will be written
        :return: the number of characters written to the stream

        If the last character written was not a space, and `nextc` is either ``None``, an
        alphanumeric character or a *special character* (one of ``_*+/-"'=<>$@``), then emit a
        single whitespace.
        """

        if not self.last_emitted_char.isspace():
            if nextc is None or nextc.isalnum() or nextc in _special_chars:
                return self.write(' ')
        return 0

    def write(self, s):
        """Emit string `s`.

        :param str s: the string to emit
        :return: the number of characters of `s` written to the stream; a separator space
                 possibly emitted before it is *not* included

        When `s` is not empty and `pending_separator` is ``True``, first call
        :meth:`maybe_write_space` with the first character of `s` (unless `s` is itself a
        single space), then reset `pending_separator` to ``False`` and write out `s`.
        """

        count = 0
        if s:
            if self.pending_separator:
                if s != ' ':
                    # Dispatched through self.write(), so an overriding write() in a subclass
                    # also sees the extra space
                    self.maybe_write_space(s[0])
                self.pending_separator = False
            count = super().write(s)
            self.last_emitted_char = s[-1]

        return count

    def writes(self, s):
        "Shortcut for ``self.write(s); self.separator()``."

        count = self.write(s)
        self.separator()
        return count

    def swrite(self, s):
        """Shortcut for ``self.maybe_write_space(s[0]); self.write(s)``.

        An empty `s` is a no-op and returns 0.
        """

        if not s:
            return 0
        count = self.maybe_write_space(s[0])
        return count + self.write(s)

    def swrites(self, s):
        "Shortcut for ``self.swrite(s); self.separator()``."

        count = self.swrite(s)
        self.separator()
        return count
176
177
class RawStream(OutputStream):
    """Basic SQL parse tree writer.

    :param int expression_level:
           start the stream with the given expression level depth, 0 by default
    :param bool separate_statements:
           ``True`` by default, tells whether multiple statements shall be separated by an
           empty line
    :param bool special_functions:
           ``False`` by default, when ``True`` some functions are treated in a special way and
           emitted as equivalent constructs
    :param bool comma_at_eoln:
           ``False`` by default, when ``True`` put the comma right after each item instead of
           at the beginning of the *next* item line

    This augments :class:`OutputStream` and implements the basic machinery needed to serialize
    the *parse tree* produced by :func:`~.parser.parse_sql()` back to a textual representation,
    without any adornment.
    """

    def __init__(self, expression_level=0, separate_statements=True, special_functions=False,
                 comma_at_eoln=False):
        super().__init__()
        self.expression_level = expression_level
        self.separate_statements = separate_statements
        self.special_functions = special_functions
        self.comma_at_eoln = comma_at_eoln
        self.current_column = 0

    def __call__(self, sql, plpgsql=False):
        """Main entry point: execute :meth:`print_node` on each statement in `sql`.

        :param sql: either the source SQL in textual form, a :class:`~.node.Node` instance or
                    a :class:`~.node.List` instance
        :param bool plpgsql: whether `sql` is really a ``plpgsql`` statement
        :returns: a string with the equivalent SQL obtained by serializing the syntax tree
        :raises ValueError: when `sql` is of any other type

        Consecutive statements are separated by a semicolon followed by a :meth:`newline`, and
        by a second :meth:`newline` when `separate_statements` is ``True``.
        """

        if isinstance(sql, str):
            sql = Node(parse_plpgsql(sql) if plpgsql else parse_sql(sql))
        elif isinstance(sql, Node):
            sql = [sql]
        elif not isinstance(sql, List):
            raise ValueError("Unexpected value for 'sql', must be either a string,"
                             " a Node instance or a List instance, got %r" % type(sql))

        first = True
        for statement in sql:
            if first:
                first = False
            else:
                self.write(';')
                self.newline()
                if self.separate_statements:
                    self.newline()
            self.print_node(statement)
        return self.getvalue()

    def dedent(self):
        "Do nothing, shall be overridden by the prettifier subclass."

    def get_printer_for_function(self, name):
        """Look for a specific printer for function `name` in :data:`SPECIAL_FUNCTIONS`.

        :param str name: the qualified name of the function
        :returns: either ``None`` or a callable

        When the option `special_functions` is ``True``, return the printer function associated
        with `name`, if present. In all other cases return ``None``.
        """

        if self.special_functions:
            return SPECIAL_FUNCTIONS.get(name)

    def indent(self, amount=0, relative=True):
        "Do nothing, shall be overridden by the prettifier subclass."

    def newline(self):
        "Request a separator space, shall be overridden by the prettifier subclass."

        self.separator()

    def space(self, count=1):
        """Request a separator space, shall be overridden by the prettifier subclass.

        The `count` argument is ignored here.
        """

        self.separator()

    @contextmanager
    def push_indent(self, amount=0, relative=True):
        "Create a no-op context manager, shall be overridden by the prettifier subclass."

        yield

    @contextmanager
    def expression(self):
        """Create a context manager that will wrap subexpressions within parentheses.

        Only nested expressions (level greater than 1) get the parentheses. The expression
        level is restored even when the wrapped block raises an exception.
        """

        self.expression_level += 1
        try:
            if self.expression_level > 1:
                self.write('(')
            yield
            if self.expression_level > 1:
                self.write(')')
        finally:
            self.expression_level -= 1

    def _concat_nodes(self, nodes, sep=' ', are_names=False):
        """Concatenate given `nodes`, using `sep` as the separator.

        :param nodes: a sequence of nodes
        :param str sep: the separator between them
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :returns: a string

        Use a temporary :class:`RawStream` instance, configured with the same options as this
        one, to print the list of nodes and return the result.
        """

        # Propagate the options, in particular `special_functions`: otherwise the
        # concatenated text could render special functions differently from what this
        # stream would emit for the very same nodes
        rawstream = RawStream(expression_level=self.expression_level,
                              separate_statements=self.separate_statements,
                              special_functions=self.special_functions,
                              comma_at_eoln=self.comma_at_eoln)
        rawstream.print_list(nodes, sep, are_names=are_names, standalone_items=False)
        return rawstream.getvalue()

    def _write_quoted_string(self, s):
        "Emit the `s` as a single-quoted literal constant."

        self.write("'%s'" % s.replace("'", "''"))

    def _print_scalar(self, node, is_name, is_symbol):
        "Print the scalar `node`, special-casing string literals."

        value = node.value
        if is_symbol:
            self.write(value)
        elif is_name:
            # The `scalar` represent a name of a column/table/alias: when any of its
            # characters is not a lower case letter, a digit or underscore, it must be
            # double quoted. Anchor with `\Z`, not `$`: the latter also matches just before
            # a trailing newline, so a name like "foo\n" would be emitted unquoted
            if not match(r'[a-z_][a-z0-9_]*\Z', value) or value in RESERVED_KEYWORDS:
                value = '"%s"' % value.replace('"', '""')
            self.write(value)
        elif node.parent_node.node_tag == 'String':
            self._write_quoted_string(value)
        else:
            self.write(str(value))

    def print_name(self, nodes, sep='.'):
        "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed."

        if isinstance(nodes, (List, list)):
            self.print_list(nodes, sep, standalone_items=False, are_names=True)
        else:
            self.print_node(nodes, is_name=True)

    def print_symbol(self, nodes, sep='.'):
        "Helper method, execute :meth:`print_node` or :meth:`print_list` as needed."

        if isinstance(nodes, (List, list)):
            self.print_list(nodes, sep, standalone_items=False, are_names=True, is_symbol=True)
        else:
            self.print_node(nodes, is_name=True, is_symbol=True)

    def print_node(self, node, is_name=False, is_symbol=False):
        """Lookup the specific printer for the given `node` and execute it.

        :param node: an instance of :class:`~.node.Node` or :class:`~.node.Scalar`
        :param bool is_name:
               whether this is a *name* of something, that may need to be double quoted
        :param bool is_symbol:
               whether this is the name of an *operator*, should not be double quoted
        """

        if isinstance(node, Scalar):
            self._print_scalar(node, is_name, is_symbol)
        else:
            parent_node_tag = node.parent_node and node.parent_node.node_tag
            printer = get_printer_for_node_tag(parent_node_tag, node.node_tag)
            if is_name and node.node_tag == 'String':
                printer(node, self, is_name=is_name, is_symbol=is_symbol)
            else:
                printer(node, self)
        self.separator()

    def _print_items(self, items, sep, newline, are_names=False, is_symbol=False):
        "Print `items` separated by `sep`; only the last one may be printed as a symbol."

        last = len(items) - 1
        for idx, item in enumerate(items):
            if idx > 0:
                if sep == ',' and self.comma_at_eoln:
                    self.write(sep)
                    if newline:
                        self.newline()
                    else:
                        self.write(' ')
                else:
                    if not are_names:
                        if newline:
                            self.newline()
                    if sep:
                        self.write(sep)
                        if sep != '.':
                            self.write(' ')
            self.print_node(item, is_name=are_names, is_symbol=is_symbol and idx == last)

    def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None,
                   are_names=False, is_symbol=False):
        """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`.

        :param nodes: a sequence of :class:`~.node.Node` instances
        :param str sep: the separator between them
        :param int relative_indent:
               if given, the relative amount of indentation to apply before the first item, by
               default computed automatically from the length of the separator `sep`
        :param bool standalone_items: whether a newline will be emitted before each item
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :param bool is_symbol:
               whether the nodes are actually a *symbol* such as an *operator name*, in which
               case the last one must be printed verbatim (e.g. ``"MySchema".===``)
        """

        if relative_indent is None:
            if are_names or is_symbol:
                relative_indent = 0
            else:
                relative_indent = (-(len(sep) + 1)
                                   if sep and (sep != ',' or not self.comma_at_eoln)
                                   else 0)

        if standalone_items is None:
            standalone_items = not all(isinstance(n, Node)
                                       and n.node_tag in ('A_Const', 'ColumnRef',
                                                          'SetToDefault', 'RangeVar')
                                       for n in nodes)

        with self.push_indent(relative_indent):
            self._print_items(nodes, sep, standalone_items, are_names=are_names,
                              is_symbol=is_symbol)

    def print_lists(self, lists, sep=',', relative_indent=None, standalone_items=None,
                    are_names=False, sublist_open='(', sublist_close=')', sublist_sep=',',
                    sublist_relative_indent=None):
        """Execute :meth:`print_list` on all the `lists` items.

        :param lists: a sequence of sequences of :class:`~.node.Node` instances
        :param str sep: passed as is to :meth:`print_list`
        :param int relative_indent: passed as is to :meth:`print_list`
        :param bool standalone_items: passed as is to :meth:`print_list`
        :param bool are_names: passed as is to :meth:`print_list`
        :param str sublist_open: the string that will be emitted before each sublist
        :param str sublist_close: the string that will be emitted after each sublist
        :param str sublist_sep: the separator between each sublist
        :param int sublist_relative_indent:
               if given, the relative amount of indentation to apply before the first sublist,
               by default computed automatically from the length of the separator `sublist_sep`
        """

        if sublist_relative_indent is None:
            sublist_relative_indent = (-(len(sublist_sep) + 1)
                                       if sublist_sep and (sublist_sep != ','
                                                           or not self.comma_at_eoln)
                                       else 0)

        with self.push_indent(sublist_relative_indent):
            self.write(sublist_open)
            first = True
            for lst in lists:
                if first:
                    first = False
                else:
                    if self.comma_at_eoln:
                        self.write(sublist_sep)
                        self.newline()
                        self.write(sublist_open)
                    else:
                        self.newline()
                        self.write(sublist_sep)
                        self.write(' ')
                        self.write(sublist_open)
                self.print_list(lst, sep, relative_indent, standalone_items, are_names)
                self.write(sublist_close)
458
459
class IndentedStream(RawStream):
    r"""Indented SQL parse tree writer.

    :param int compact_lists_margin:
           an integer value that, if given, is used to print lists on a single line, when they
           do not exceed the given margin on the right
    :param int split_string_literals_threshold:
           an integer value that, if given, is used as the threshold beyond which a string
           literal gets split in successive chunks of that length
    :param \*\*options: other options accepted by :class:`RawStream`

    This augments :class:`RawStream` to emit a prettified representation of a *parse tree*.
    """

    def __init__(self, compact_lists_margin=None, split_string_literals_threshold=None,
                 **options):
        super().__init__(**options)
        self.compact_lists_margin = compact_lists_margin
        self.split_string_literals_threshold = split_string_literals_threshold
        # Number of spaces write() emits at the start of each new line
        self.current_indent = 0
        # Previous values of current_indent, restored by dedent()
        self.indentation_stack = []

    def dedent(self):
        "Pop the indentation level from the stack and set `current_indent` to that."

        self.current_indent = self.indentation_stack.pop()

    def indent(self, amount=0, relative=True):
        """Push current indentation level to the stack, then set it adding `amount` to the
        `current_column` if `relative` is ``True`` otherwise to `current_indent`.
        """

        self.indentation_stack.append(self.current_indent)
        base_indent = (self.current_column if relative else self.current_indent)
        assert base_indent + amount >= 0
        self.current_indent = base_indent + amount

    @contextmanager
    def push_indent(self, amount=0, relative=True):
        """Create a context manager that calls :meth:`indent` and :meth:`dedent` around a block
        of code.

        This is just an helper to simplify code that adjust the indentation level:

        .. code-block:: python

          with output.push_indent(4):
              # code that emits something with the new indentation

        The previous indentation level is restored even when the block raises an exception.
        """

        if self.pending_separator and relative:
            # A separator space may be emitted before the next token
            amount += 1
        if self.current_column == 0 and relative:
            # At the start of a line current_column does not yet include the indentation
            # that write() will emit before the next token
            amount += self.current_indent
        self.indent(amount, relative)
        try:
            yield
        finally:
            self.dedent()

    def newline(self):
        "Emit a newline."

        self.write('\n')

    def space(self, count=1):
        "Emit consecutive spaces."

        self.write(' '*count)

    def print_list(self, nodes, sep=',', relative_indent=None, standalone_items=None,
                   are_names=False, is_symbol=False):
        """Execute :meth:`print_node` on all the `nodes`, separating them with `sep`.

        :param nodes: a sequence of :class:`~.node.Node` instances
        :param str sep: the separator between them
        :param int relative_indent:
               if given, the relative amount of indentation to apply before the first item, by
               default computed automatically from the length of the separator `sep`
        :param bool standalone_items: whether a newline will be emitted before each item
        :param bool are_names:
               whether the nodes are actually *names*, which possibly require to be enclosed
               between double-quotes
        :param bool is_symbol:
               whether the nodes are actually an *operator name*, in which case the last one
               must be printed verbatim (such as ``"MySchema".===``)

        When `standalone_items` is not given and `compact_lists_margin` is positive, the
        whole list is emitted on the current line if it fits before that margin.
        """

        if standalone_items is None:
            clm = self.compact_lists_margin
            if clm is not None and clm > 0:
                rawlist = self._concat_nodes(nodes, sep, are_names)
                if self.current_column + len(rawlist) < clm:
                    self.write(rawlist)
                    return

            standalone_items = not all(
                (isinstance(n, Node)
                 and n.node_tag in ('A_Const', 'ColumnRef', 'SetToDefault', 'RangeVar'))
                for n in nodes)

        if (((sep != ',' or not self.comma_at_eoln)
             and len(nodes) > 1
             and len(sep) > 1
             and relative_indent is None
             and not are_names
             and not is_symbol
             and standalone_items)):
            self.write(' '*(len(sep) + 1))  # separator added automatically

        super().print_list(nodes, sep, relative_indent, standalone_items, are_names, is_symbol)

    def _write_quoted_string(self, s):
        """Possibly split `s` string in successive chunks.

        When the ``split_string_literals_threshold`` option is greater than 0, emit `s` as a
        sequence of quoted chunks of (about) that length, one per line. When `s` contains
        newlines an ``E`` prefix is emitted and backslashes and newlines are escaped.
        """

        sslt = self.split_string_literals_threshold
        if sslt is None or sslt <= 0:
            super()._write_quoted_string(s)
        else:
            multiline = '\n' in s
            if multiline:
                self.write('E')
            with self.push_indent():
                while True:
                    chunk = s[:sslt]
                    s = s[sslt:]
                    # Avoid splitting on backslash; stop when nothing is left, otherwise a
                    # string ending with a backslash would raise an IndexError on s[0]
                    while s and chunk.endswith("\\"):
                        chunk += s[0]
                        s = s[1:]
                    chunk = chunk.replace("'", "''")
                    if multiline:
                        chunk = chunk.replace("\\", "\\\\")
                        chunk = chunk.replace("\n", "\\n")
                    self.write("'%s'" % chunk)
                    if s:
                        self.newline()
                    else:
                        break

    def write(self, s):
        """Write string `s` to the stream, adjusting the `current_column` accordingly.

        :param str s: the string to emit
        :return: the number of characters written to the stream

        If `s` is a newline character (``\\n``) set `current_column` to 0. Otherwise when
        `current_column` is 0 and `current_indent` is greater than 0 emit a number of
        whitespaces *before* emitting `s`, to indent it as expected.
        """

        if s and s != '\n' and self.current_column == 0 and self.current_indent > 0:
            self.current_column = super().write(' ' * self.current_indent)

        count = super().write(s)
        if s == '\n':
            self.current_column = 0
        else:
            self.current_column += count

        return count
623