#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

8from sqlparse import sql, tokens as T
9from sqlparse.utils import offset, indent
10
11
class AlignedIndentFilter:
    """Reindent a parsed statement so that the major SQL keywords
    (SELECT, FROM, WHERE, JOIN, ...) end up right-aligned in a common
    column, producing the 'aligned' output style.
    """

    # Any JOIN variant; only the first word is used as the alignment
    # anchor (see _split_kwds).
    join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?'
                  r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|'
                  r'(CROSS\s+|NATURAL\s+)?)?JOIN\b')
    by_words = r'(GROUP|ORDER)\s+BY\b'
    # Keywords that force a line break (and alignment) before them.
    split_words = ('FROM',
                   join_words, 'ON', by_words,
                   'WHERE', 'AND', 'OR',
                   'HAVING', 'LIMIT',
                   'UNION', 'VALUES',
                   'SET', 'BETWEEN', 'EXCEPT')

    def __init__(self, char=' ', n='\n'):
        """
        :param char: padding character used to build indentation
        :param n: newline string inserted between output lines
        """
        self.n = n
        # Extra horizontal offset; temporarily adjusted via utils.offset().
        self.offset = 0
        # Parenthesis nesting level; temporarily adjusted via utils.indent().
        self.indent = 0
        self.char = char
        # Keywords are right-aligned to the width of 'select'.
        self._max_kwd_len = len('select')

    def nl(self, offset=1):
        """Return a whitespace token containing a newline plus padding.

        *offset* is either an int (extra columns to pad) or a keyword
        string whose length is subtracted so that the keyword ends up
        right-aligned against the SELECT column.
        """
        # offset = 1 represent a single space after SELECT
        offset = -len(offset) if not isinstance(offset, int) else offset
        # add two for the space and parenthesis
        indent = self.indent * (2 + self._max_kwd_len)

        return sql.Token(T.Whitespace, self.n + self.char * (
            self._max_kwd_len + offset + indent + self.offset))

    def _process_statement(self, tlist):
        # Drop leading whitespace of a top-level statement only (nested
        # statements keep theirs, since self.indent > 0 there).
        if len(tlist.tokens) > 0 and tlist.tokens[0].is_whitespace \
                and self.indent == 0:
            tlist.tokens.pop(0)

        # process the main query body
        self._process(sql.TokenList(tlist.tokens))

    def _process_parenthesis(self, tlist):
        # if this isn't a subquery, don't re-indent
        _, token = tlist.token_next_by(m=(T.DML, 'SELECT'))
        if token is not None:
            with indent(self):
                tlist.insert_after(tlist[0], self.nl('SELECT'))
                # process the inside of the parenthesis
                self._process_default(tlist)

            # de-indent last parenthesis
            tlist.insert_before(tlist[-1], self.nl())

    def _process_identifierlist(self, tlist):
        # Columns being selected: the first identifier stays on the
        # current line; each following identifier starts a new aligned
        # line. Slicing with [1:] also tolerates an empty list, where
        # the old pop(0) would have raised IndexError.
        for token in list(tlist.get_identifiers())[1:]:
            tlist.insert_before(token, self.nl())
        self._process_default(tlist)

    def _process_case(self, tlist):
        offset_ = len('case ') + len('when ')
        cases = tlist.get_cases(skip_ws=True)
        # align the end as well
        end_token = tlist.token_next_by(m=(T.Keyword, 'END'))[1]
        cases.append((None, [end_token]))

        # Width of each WHEN condition, so values can be padded to line
        # up after the widest condition.
        condition_width = [len(' '.join(map(str, cond))) if cond else 0
                           for cond, _ in cases]
        max_cond_width = max(condition_width)

        for i, (cond, value) in enumerate(cases):
            # cond is None when 'else or end'
            stmt = cond[0] if cond else value[0]

            if i > 0:
                tlist.insert_before(stmt, self.nl(offset_ - len(str(stmt))))
            if cond:
                # Pad after the condition so all THEN values align.
                ws = sql.Token(T.Whitespace, self.char * (
                    max_cond_width - condition_width[i]))
                tlist.insert_after(cond[-1], ws)

    def _next_token(self, tlist, idx=-1):
        """Return (index, token) of the next split keyword after *idx*."""
        split_words = T.Keyword, self.split_words, True
        tidx, token = tlist.token_next_by(m=split_words, idx=idx)
        # treat "BETWEEN x and y" as a single statement
        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)
        return tidx, token

    def _split_kwds(self, tlist):
        """Insert an aligned newline before every split keyword."""
        tidx, token = self._next_token(tlist)
        while token:
            # joins, group/order by are special case. only consider the first
            # word as aligner
            if (
                token.match(T.Keyword, self.join_words, regex=True)
                or token.match(T.Keyword, self.by_words, regex=True)
            ):
                token_indent = token.value.split()[0]
            else:
                token_indent = str(token)
            tlist.insert_before(token, self.nl(token_indent))
            # The insertion shifted the keyword one position to the right.
            tidx += 1
            tidx, token = self._next_token(tlist, tidx)

    def _process_default(self, tlist):
        self._split_kwds(tlist)
        # process any sub-sub statements
        for sgroup in tlist.get_sublists():
            idx = tlist.token_index(sgroup)
            pidx, prev_ = tlist.token_prev(idx)
            # HACK: make "group/order by" work. Longer than max_len.
            offset_ = 3 if (
                prev_ and prev_.match(T.Keyword, self.by_words, regex=True)
            ) else 0
            with offset(self, offset_):
                self._process(sgroup)

    def _process(self, tlist):
        # Dispatch to a _process_<classname> handler when one exists,
        # otherwise fall back to the generic keyword splitting.
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

    def process(self, stmt):
        """Entry point: re-indent *stmt* in place and return it."""
        self._process(stmt)
        return stmt
136