1import re
2import sqlparse
3from sqlparse.sql import IdentifierList, Identifier, Function
4from sqlparse.tokens import Keyword, DML, Punctuation
5
# Named patterns used to pick out the trailing "word" of a piece of text.
# Every pattern is anchored with ``$`` so it only matches at the end of the
# input; the patterns differ in which characters may be part of that word.
cleanup_regex = {
    # Word characters only (letters, digits, underscore).
    'alphanum_underscore': re.compile(r'(\w+)$'),
    # Anything but whitespace, parentheses, colon and comma.
    'many_punctuations': re.compile(r'([^():,\s]+)$'),
    # Anything but whitespace, parentheses, colon, comma and period.
    'most_punctuations': re.compile(r'([^\.():,\s]+)$'),
    # Anything but whitespace.
    'all_punctuations': re.compile(r'([^\s]+)$'),
}
16
def last_word(text, include='alphanum_underscore'):
    r"""
    Find the last word in a sentence.

    ``text`` is the string to inspect. ``include`` names the entry of
    ``cleanup_regex`` that decides which characters count as part of a word;
    an unknown name raises ``KeyError``.

    Returns the trailing word, or ``''`` when ``text`` is empty, ends with
    whitespace, or does not end with a character the chosen pattern accepts.

    >>> last_word('abc')
    'abc'
    >>> last_word(' abc')
    'abc'
    >>> last_word('')
    ''
    >>> last_word(' ')
    ''
    >>> last_word('abc ')
    ''
    >>> last_word('abc def')
    'def'
    >>> last_word('abc def ')
    ''
    >>> last_word('abc def;')
    ''
    >>> last_word('bac $def')
    'def'
    >>> last_word('bac $def', include='most_punctuations')
    '$def'
    >>> last_word('bac \\def', include='most_punctuations')
    '\\def'
    >>> last_word('bac \\def;', include='most_punctuations')
    '\\def;'
    >>> last_word('bac::def', include='most_punctuations')
    'def'
    """
    # The explicit whitespace test is required in addition to the regex:
    # ``$`` also matches just before a trailing newline, so 'abc\n' would
    # otherwise yield 'abc'.
    if not text or text[-1].isspace():
        return ''

    match = cleanup_regex[include].search(text)
    return match.group(0) if match else ''
61
62
# This code is borrowed from the sqlparse example script.
64# <url>
def is_subselect(parsed):
    """Tell whether *parsed* is a token group containing a statement keyword.

    Returns False when ``parsed.is_group`` is falsy. Otherwise returns True
    if any direct child token has type ``DML`` and a value that is,
    case-insensitively, one of SELECT, INSERT, UPDATE, CREATE or DELETE.
    """
    if not parsed.is_group:
        return False

    statement_keywords = ('SELECT', 'INSERT', 'UPDATE', 'CREATE', 'DELETE')
    return any(
        token.ttype is DML and token.value.upper() in statement_keywords
        for token in parsed.tokens
    )
73
def extract_from_part(parsed, stop_at_punctuation=True):
    """Yield the tokens of *parsed* that follow a table-introducing keyword.

    Tokens are skipped until one of COPY, FROM, INTO, UPDATE, TABLE or JOIN
    is seen. From then on tokens are yielded one by one; sub-selects are
    descended into recursively. Iteration ends at the first keyword other
    than FROM or a ``*JOIN`` variant, and, when ``stop_at_punctuation`` is
    true, at the first punctuation token.
    """
    table_keywords = ('COPY', 'FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)
    in_table_section = False

    for token in parsed.tokens:
        if not in_table_section:
            is_keyword = token.ttype is Keyword or token.ttype is Keyword.DML
            if is_keyword and token.value.upper() in table_keywords:
                in_table_section = True
            elif isinstance(token, IdentifierList):
                # 'SELECT a, FROM abc' makes the parser treat FROM as part
                # of the column list, so look for it inside the list too.
                if any(ident.ttype is Keyword and
                       ident.value.upper() == 'FROM'
                       for ident in token.get_identifiers()):
                    in_table_section = True
            continue

        if is_subselect(token):
            for nested in extract_from_part(token, stop_at_punctuation):
                yield nested
        elif stop_at_punctuation and token.ttype is Punctuation:
            return
        elif token.ttype is Keyword:
            # An incomplete nested select such as
            # 'SELECT * FROM (SELECT id FROM user' is not recognized as a
            # sub-select, so its second FROM arrives here as a plain
            # keyword and must not end the scan. The same holds for JOIN and
            # its variants (INNER JOIN, FULL OUTER JOIN, ...), as in
            # 'SELECT * FROM abc JOIN def'. Any other keyword ends the scan.
            keyword = token.value.upper()
            if keyword != 'FROM' and not keyword.endswith('JOIN'):
                return
            yield token
        else:
            yield token
108
def extract_table_identifiers(token_stream):
    """Yield (schema_name, table_name, table_alias) tuples.

    Each item of *token_stream* is handled according to its type:
    ``IdentifierList`` members and single ``Identifier`` tokens contribute
    their parent name, real name and alias; a ``Function`` contributes its
    name as both table name and alias. Other items are ignored.
    """
    for token in token_stream:
        if isinstance(token, IdentifierList):
            for member in token.get_identifiers():
                # Keywords such as FROM sometimes end up inside the list;
                # they lack the name accessors, so skip them.
                try:
                    schema = member.get_parent_name()
                    table = member.get_real_name()
                except AttributeError:
                    continue
                if not table:
                    continue
                yield (schema, table, member.get_alias())
        elif isinstance(token, Identifier):
            table = token.get_real_name()
            schema = token.get_parent_name()
            if table:
                yield (schema, table, token.get_alias())
                continue
            # No real name available: fall back to the generic name, which
            # also serves as the alias when none was given.
            fallback = token.get_name()
            yield (None, fallback, token.get_alias() or fallback)
        elif isinstance(token, Function):
            yield (None, token.get_name(), token.get_name())
135
# extract_tables is inspired by the examples in the sqlparse library.
def extract_tables(sql):
    """Extract the table names from an SQL statement.

    Only the first statement found in *sql* is inspected.

    Returns a list of (schema, table, alias) tuples; the list is empty when
    *sql* does not parse into any statement.

    """
    parsed = sqlparse.parse(sql)
    if not parsed:
        return []

    statement = parsed[0]

    # INSERT statements must stop looking for tables at the sign of first
    # Punctuation. eg: INSERT INTO abc (col1, col2) VALUES (1, 2)
    # abc is the table name, but if we don't stop at the first lparen, then
    # we'll identify abc, col1 and col2 as table names.
    #
    # token_first() returns None when the statement holds nothing but
    # whitespace, so guard before reading ``.value``.
    first_token = statement.token_first()
    insert_stmt = (first_token is not None and
                   first_token.value.lower() == 'insert')
    stream = extract_from_part(statement, stop_at_punctuation=insert_stmt)
    return list(extract_table_identifiers(stream))
154
def find_prev_keyword(sql):
    """Find the last SQL keyword (or opening parenthesis) in a statement.

    Only the first statement parsed from *sql* is examined. The logical
    operators AND, OR, NOT and BETWEEN are not treated as keywords here.

    Returns a ``(token, text)`` tuple: ``token`` is the matching token object
    and ``text`` is the query with everything after that token removed.
    Returns ``(None, '')`` when *sql* is blank or no such token exists.
    """
    if not sql.strip():
        return None, ''

    statement = sqlparse.parse(sql)[0]
    leaves = list(statement.flatten())
    skipped_keywords = ('AND', 'OR', 'NOT', 'BETWEEN')

    # Walk the flattened token list backwards by position. The position in
    # the flattened list is what we need: statement.token_index(token)
    # cannot be used because the token may be nested inside a child
    # TokenList, in which case token_index raises ValueError, e.g.
    #   p = sqlparse.parse('select * from foo where bar')
    #   t = list(p.flatten())[-3]  # The "Where" token
    #   p.token_index(t)  # Throws ValueError: not in list
    for position in range(len(leaves) - 1, -1, -1):
        token = leaves[position]
        if token.value != '(':
            if not token.is_keyword:
                continue
            if token.value.upper() in skipped_keywords:
                continue

        # Re-join everything up to and including the matched token to get
        # the query text truncated right after it.
        truncated = ''.join(leaf.value for leaf in leaves[:position + 1])
        return token, truncated

    return None, ''
188
189
def query_starts_with(query, prefixes):
    """Check if the query starts with any item from *prefixes*.

    The query is lower-cased and stripped of comments before its first
    whitespace-delimited word is compared against the lower-cased prefixes.
    Returns False when nothing but whitespace remains after formatting.
    """
    lowered_prefixes = [prefix.lower() for prefix in prefixes]
    formatted_sql = sqlparse.format(query.lower(), strip_comments=True)
    # Test the split result rather than the raw string: a non-empty string
    # made only of whitespace splits into an empty list, and indexing it
    # would raise IndexError.
    words = formatted_sql.split()
    return bool(words) and words[0] in lowered_prefixes
195
196
def queries_start_with(queries, prefixes):
    """Check if any queries start with any item from *prefixes*.

    *queries* is split into individual statements with ``sqlparse.split``;
    empty statements are ignored.
    """
    return any(
        statement and query_starts_with(statement, prefixes) is True
        for statement in sqlparse.split(queries)
    )
203
204
def query_has_where_clause(query):
    """Check if the query contains a where-clause.

    Only the direct child tokens of each parsed statement are examined.
    """
    for statement in sqlparse.parse(query):
        for token in statement:
            if isinstance(token, sqlparse.sql.Where):
                return True
    return False
212
213
def is_destructive(queries):
    """Returns if any of the queries in *queries* is destructive.

    A statement counts as destructive when it starts with drop, shutdown,
    delete, truncate or alter, or when it is an update without a
    where-clause.
    """
    always_destructive = ('drop', 'shutdown', 'delete', 'truncate', 'alter')

    for statement in sqlparse.split(queries):
        if not statement:
            continue
        if query_starts_with(statement, always_destructive) is True:
            return True
        # An UPDATE is only dangerous when it is not narrowed by WHERE.
        is_update = query_starts_with(statement, ['update']) is True
        if is_update and not query_has_where_clause(statement):
            return True

    return False
227
228
def is_open_quote(sql):
    """Returns true if the query contains an unclosed quote."""

    # The input may hold one or more semi-colon separated commands; a single
    # command with an open quote is enough.
    # NOTE(review): _parsed_is_open_quote is not defined in the visible part
    # of this module -- confirm it exists, otherwise this raises NameError.
    for statement in sqlparse.parse(sql):
        if _parsed_is_open_quote(statement):
            return True
    return False
235
236
if __name__ == '__main__':
    # Ad-hoc manual check: print the tables found in an incomplete nested
    # select when this module is run as a script.
    print(extract_tables('select * from (select t. from tabl t'))
240
241
def is_dropping_database(queries, dbname):
    """Determine if the query is dropping a specific database.

    Every statement in *queries* that starts with DROP or CREATE followed by
    DATABASE or SCHEMA, and whose first identifier names *dbname*, updates
    the answer: DROP sets it to True, CREATE sets it back to False. The
    result therefore reflects the last such statement. Names are compared
    case-insensitively with surrounding backticks and double quotes removed.

    Returns False when *dbname* is None.
    """
    if dbname is None:
        return False

    def normalize_db_name(db):
        return db.lower().strip('`"')

    dbname = normalize_db_name(dbname)
    result = False

    for query in sqlparse.parse(queries):
        keywords = [t for t in query.tokens if t.is_keyword]
        if len(keywords) < 2:
            continue

        action = keywords[0].normalized
        if action not in ("DROP", "CREATE"):
            continue
        if keywords[1].value.lower() not in ("database", "schema"):
            continue

        database_token = next(
            (t for t in query.tokens if isinstance(t, Identifier)), None
        )
        if database_token is None:
            continue

        # get_name() may yield None for an identifier without a usable
        # name; skip it instead of failing inside normalize_db_name.
        name = database_token.get_name()
        if name is not None and normalize_db_name(name) == dbname:
            result = action == "DROP"

    return result
268