#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause

from sqlparse import sql, tokens as T
from sqlparse.utils import offset, indent


class AlignedIndentFilter:
    """Reindent a statement so clause keywords line up under ``SELECT``.

    Each keyword matched by :attr:`split_words` is moved onto its own line
    and left-padded so that it ends in the same column as ``SELECT``
    (``len('select')`` characters wide).  Every level of ``indent`` adds
    ``2 + len('select')`` further columns (see :meth:`nl`).  The token tree
    passed to :meth:`process` is modified in place.
    """

    join_words = (r'((LEFT\s+|RIGHT\s+|FULL\s+)?'
                  r'(INNER\s+|OUTER\s+|STRAIGHT\s+)?|'
                  r'(CROSS\s+|NATURAL\s+)?)?JOIN\b')
    by_words = r'(GROUP|ORDER)\s+BY\b'
    # Keywords that are pushed onto a new line.  _next_token matches them
    # against Keyword tokens; join_words / by_words are regex patterns.
    split_words = ('FROM',
                   join_words, 'ON', by_words,
                   'WHERE', 'AND', 'OR',
                   'HAVING', 'LIMIT',
                   'UNION', 'VALUES',
                   'SET', 'BETWEEN', 'EXCEPT')

    def __init__(self, char=' ', n='\n'):
        """Create the filter.

        :param char: padding character repeated to build indentation.
        :param n: line-break string emitted at the start of every
            whitespace token produced by :meth:`nl`.
        """
        self.n = n  # line break emitted by nl()
        self.offset = 0  # extra columns added to every nl() padding
        self.indent = 0  # nesting level; multiplied into nl() padding
        self.char = char  # padding character
        # width of the alignment column: keywords end where SELECT ends
        self._max_kwd_len = len('select')

    def nl(self, offset=1):
        """Return a newline whitespace token padded to the alignment column.

        :param offset: either an ``int`` added to the base padding (the
            default ``1`` leaves a single space after the ``SELECT``
            column), or a keyword string, in which case the padding is
            reduced by the keyword's length so the keyword ends in the
            ``SELECT`` column.
        :returns: a ``T.Whitespace`` token of the form
            ``n + char * width``.
        """
        # a keyword string is turned into a negative offset of its length
        offset = offset if isinstance(offset, int) else -len(offset)
        # each nesting level adds room for the keyword plus a space and
        # the opening parenthesis
        indent = self.indent * (2 + self._max_kwd_len)

        width = self._max_kwd_len + offset + indent + self.offset
        return sql.Token(T.Whitespace, self.n + self.char * width)

    def _process_statement(self, tlist):
        """Reindent a whole statement.

        At the outermost level (``self.indent == 0``) a leading whitespace
        token is dropped first so the statement starts flush left.
        """
        if (tlist.tokens and tlist.tokens[0].is_whitespace
                and self.indent == 0):
            tlist.tokens.pop(0)

        # Process the main query body.  Wrapping the tokens in a plain
        # TokenList makes _process dispatch to _process_default instead
        # of back here.  NOTE(review): this relies on TokenList sharing the
        # passed list so edits land in tlist -- confirm in sqlparse.sql.
        self._process(sql.TokenList(tlist.tokens))

    def _process_parenthesis(self, tlist):
        """Reindent a parenthesised subquery; other parentheses are untouched.

        A newline is inserted after the first token (presumably the
        opening ``(``) and before the last one (presumably ``)``), and the
        contents are processed one ``indent`` level deeper.
        """
        # if this isn't a subquery, don't re-indent
        _, token = tlist.token_next_by(m=(T.DML, 'SELECT'))
        if token is not None:
            # presumably indent() raises self.indent for the with-body
            with indent(self):
                tlist.insert_after(tlist[0], self.nl('SELECT'))
                # process the inside of the parenthesis
                self._process_default(tlist)

            # de-indent last parenthesis
            tlist.insert_before(tlist[-1], self.nl())

    def _process_identifierlist(self, tlist):
        """Put every identifier after the first on its own aligned line.

        The first identifier (e.g. the first selected column) stays on the
        keyword's line.  An empty identifier list is left unchanged.
        """
        # Materialise the identifiers up front: the loop below inserts
        # tokens into tlist, which must not happen while iterating it.
        identifiers = list(tlist.get_identifiers())
        # slicing is safe on an empty list, unlike pop(0)
        for token in identifiers[1:]:
            tlist.insert_before(token, self.nl())
        self._process_default(tlist)

    def _process_case(self, tlist):
        """Align the branches of a ``CASE`` expression.

        Every ``WHEN`` / ``ELSE`` / ``END`` after the first branch starts a
        new line aligned under the first ``WHEN``.  Each condition is
        right-padded to the widest condition so the values line up.
        """
        offset_ = len('case ') + len('when ')
        cases = tlist.get_cases(skip_ws=True)
        # align the end as well
        end_token = tlist.token_next_by(m=(T.Keyword, 'END'))[1]
        cases.append((None, [end_token]))

        condition_width = [len(' '.join(map(str, cond))) if cond else 0
                           for cond, _ in cases]
        # never empty: the END entry was appended above
        max_cond_width = max(condition_width)

        for i, (cond, value) in enumerate(cases):
            # cond is None for the ELSE branch and the appended END
            stmt = cond[0] if cond else value[0]

            if i > 0:
                tlist.insert_before(stmt, self.nl(offset_ - len(str(stmt))))
            if cond:
                ws = sql.Token(T.Whitespace, self.char * (
                    max_cond_width - condition_width[i]))
                tlist.insert_after(cond[-1], ws)

    def _next_token(self, tlist, idx=-1):
        """Return ``(index, token)`` of the next split keyword after *idx*.

        A ``BETWEEN`` and the ``AND`` that immediately follows it are
        skipped, so ``BETWEEN x AND y`` is never broken across lines.
        The token part is ``None`` (falsy) when nothing is found.
        """
        split_words = T.Keyword, self.split_words, True
        tidx, token = tlist.token_next_by(m=split_words, idx=idx)
        # treat "BETWEEN x and y" as a single statement
        if token and token.normalized == 'BETWEEN':
            tidx, token = self._next_token(tlist, tidx)
            if token and token.normalized == 'AND':
                tidx, token = self._next_token(tlist, tidx)
        return tidx, token

    def _split_kwds(self, tlist):
        """Insert an aligned newline before every split keyword in *tlist*."""
        tidx, token = self._next_token(tlist)
        while token:
            # joins, group/order by are special case. only consider the first
            # word as aligner
            if (
                token.match(T.Keyword, self.join_words, regex=True)
                or token.match(T.Keyword, self.by_words, regex=True)
            ):
                token_indent = token.value.split()[0]
            else:
                token_indent = str(token)
            tlist.insert_before(token, self.nl(token_indent))
            # step past the whitespace token just inserted ahead of it
            tidx += 1
            tidx, token = self._next_token(tlist, tidx)

    def _process_default(self, tlist):
        """Split keywords in *tlist*, then recurse into each sub-group."""
        self._split_kwds(tlist)
        # process any sub-sub statements
        for sgroup in tlist.get_sublists():
            idx = tlist.token_index(sgroup)
            _, prev_ = tlist.token_prev(idx)
            # HACK: make "group/order by" work. Longer than max_len.
            # The 3 extra columns account for the "BY " that follows the
            # aligned GROUP/ORDER word.
            offset_ = 3 if (
                prev_ and prev_.match(T.Keyword, self.by_words, regex=True)
            ) else 0
            with offset(self, offset_):
                self._process(sgroup)

    def _process(self, tlist):
        """Dispatch on the token class name, e.g. ``_process_parenthesis``.

        Classes without a matching ``_process_<lowercased name>`` method
        fall back to :meth:`_process_default`.
        """
        func_name = '_process_{cls}'.format(cls=type(tlist).__name__)
        func = getattr(self, func_name.lower(), self._process_default)
        func(tlist)

    def process(self, stmt):
        """Reindent *stmt* in place and return it."""
        self._process(stmt)
        return stmt