1# -*- coding: utf-8 -*-
2# This code is part of Ansible, but is an independent component.
3# This particular file snippet, and this file snippet only, is BSD licensed.
4# Modules you write using this snippet, which is embedded dynamically by Ansible
5# still belong to the author of the module, and may assign their own license
6# to the complete work.
7#
8# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
9#
10# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)
11
12from __future__ import (absolute_import, division, print_function)
13__metaclass__ = type
14
15import re
16
17
# Input patterns for is_input_dangerous function.
# Each pattern is applied with a search, so it may match anywhere in the input.
# None of them is compiled with re.DOTALL, so '.*' does not span a newline.
#
# 1. A single or double quote followed, later on the same line, by '--'
#    (for example: "foo' --").
PATTERN_1 = re.compile(r'(\'|\").*--')

# 2. UNION, INTERSECT or EXCEPT followed, later on the same line, by SELECT
#    (case-insensitive).
PATTERN_2 = re.compile(r'(UNION|INTERSECT|EXCEPT).*SELECT', re.IGNORECASE)

# 3. A ';' followed, later on the same line, by any of the listed
#    key words (case-insensitive).
PATTERN_3 = re.compile(r';.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)', re.IGNORECASE)
29
30
class SQLParseError(Exception):
    """Raised when an SQL identifier cannot be parsed or quoted."""
33
34
class UnclosedQuoteError(SQLParseError):
    """Raised when a quoted identifier has no closing quote character."""
37
38
# Each table maps an identifier type to the maximum number of dot levels
# allowed when specifying an identifier of that type.  For example, a
# PostgreSQL column can be specified by up to 4 levels:
# database.schema.table.column
_PG_IDENTIFIER_TO_DOT_LEVEL = {
    'database': 1,
    'schema': 2,
    'table': 3,
    'column': 4,
    'role': 1,
    'tablespace': 1,
    'sequence': 3,
    'publication': 1,
}
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = {
    'database': 1,
    'table': 2,
    'column': 3,
    'role': 1,
    'vars': 1,
}
53
54
def _find_end_quote(identifier, quote_char):
    """Return the index of the first unescaped ``quote_char`` in ``identifier``.

    Two consecutive ``quote_char`` characters are treated as one escaped
    quote and are skipped.  ``identifier`` is expected to be the text that
    follows an opening quote, so the returned index is the position of the
    closing quote within ``identifier``.

    Raises:
        UnclosedQuoteError: if no unescaped ``quote_char`` is found.
    """
    length = len(identifier)
    start = 0
    while True:
        try:
            # Search from ``start`` instead of re-slicing the string on
            # every escaped pair; the result is already an absolute index.
            quote = identifier.index(quote_char, start)
        except ValueError:
            raise UnclosedQuoteError
        following = quote + 1
        if following < length and identifier[following] == quote_char:
            # Doubled quote: an escaped literal quote, keep looking after it.
            start = quote + 2
            continue
        # Either the quote is the last character or it is followed by
        # something other than a quote: this is the closing quote.
        return quote
75
76
def _identifier_parse(identifier, quote_char):
    """Split a dotted identifier into a list of quoted fragments.

    A leading fragment that is already wrapped in ``quote_char`` is kept
    as is.  An unquoted fragment is wrapped in ``quote_char`` and any
    ``quote_char`` inside it is doubled.  The remainder after a separating
    dot is parsed recursively.

    Raises:
        SQLParseError: if ``identifier`` is empty, or if a quoted fragment
            is followed by anything other than a dot.
    """
    if not identifier:
        raise SQLParseError('Identifier name unspecified or unquoted trailing dot')

    if identifier.startswith(quote_char):
        try:
            closing = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
        except UnclosedQuoteError:
            # No closing quote: handle the whole text as unquoted below.
            pass
        else:
            if closing >= len(identifier) - 1:
                # The closing quote is the last character: single fragment.
                return [identifier]
            if identifier[closing + 1] != '.':
                raise SQLParseError('User escaped identifiers must escape extra quotes')
            fragments = _identifier_parse(identifier[closing + 2:], quote_char)
            fragments.insert(0, identifier[:closing + 1])
            return fragments

    def wrap(name):
        # Double embedded quote characters, then surround with quotes.
        return ''.join((quote_char, name.replace(quote_char, quote_char * 2), quote_char))

    separator = identifier.find('.')
    # No dot, or a dot in the first or last position: the whole text is one name.
    if separator <= 0 or separator >= len(identifier) - 1:
        return [wrap(identifier)]

    fragments = _identifier_parse(identifier[separator + 1:], quote_char)
    fragments.insert(0, wrap(identifier[:separator]))
    return fragments
122
123
def pg_quote_identifier(identifier, id_type):
    """Return ``identifier`` quoted with double quotes for PostgreSQL.

    The identifier is split into fragments by ``_identifier_parse`` and the
    fragments are joined back with dots.

    Raises:
        SQLParseError: if the identifier cannot be parsed, or if it has
            more fragments than ``_PG_IDENTIFIER_TO_DOT_LEVEL`` allows
            for ``id_type``.
        KeyError: if ``id_type`` is not in ``_PG_IDENTIFIER_TO_DOT_LEVEL``.
    """
    fragments = _identifier_parse(identifier, quote_char='"')
    max_levels = _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(fragments) <= max_levels:
        return '.'.join(fragments)
    raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, max_levels))
129
130
def mysql_quote_identifier(identifier, id_type):
    """Return ``identifier`` quoted with backticks for MySQL.

    The identifier is split into fragments by ``_identifier_parse`` and the
    fragments are joined back with dots.  A fragment that is exactly a
    quoted asterisk is emitted as a bare ``*``.

    Raises:
        SQLParseError: if the identifier cannot be parsed, or if the number
            of fragments minus one exceeds the value stored in
            ``_MYSQL_IDENTIFIER_TO_DOT_LEVEL`` for ``id_type``.
        KeyError: if ``id_type`` is not in ``_MYSQL_IDENTIFIER_TO_DOT_LEVEL``.
    """
    fragments = _identifier_parse(identifier, quote_char='`')
    dot_limit = _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]
    if len(fragments) - 1 > dot_limit:
        raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, dot_limit))

    # A quoted asterisk is un-quoted again so it is emitted as a plain '*'.
    return '.'.join('*' if part == '`*`' else part for part in fragments)
144
145
def is_input_dangerous(string):
    """Tell whether the passed string looks potentially dangerous.

    The string is searched with PATTERN_1, PATTERN_2 and PATTERN_3;
    a match of any of them marks the input as dangerous.
    Can be used to prevent SQL injections.

    Note: use this function only when you can't use
      psycopg2's cursor.execute method parametrized
      (typically with DDL queries).

    Returns:
        True if any pattern matches, False otherwise.  An empty or
        None value is never reported as dangerous.
    """
    if not string:
        return False

    return any(pattern.search(string) for pattern in (PATTERN_1, PATTERN_2, PATTERN_3))
162
163
def check_input(module, *args):
    """Fail the module if any passed value is potentially dangerous.

    Wrapper for is_input_dangerous function.

    Args:
        module: object providing ``fail_json(msg=...)``.  It is called
            only when at least one dangerous value has been found.
        *args: values to check.  Strings are checked as they are, lists
            are checked item by item, ``None`` and booleans are skipped,
            and any other value is checked via its ``str()`` form.  The
            same rules apply to every item of a list.
    """
    # ``str`` plus, on Python 2, ``unicode`` (on Python 3 both are ``str``).
    text_types = (str, type(u''))

    dangerous_elements = []

    for elem in args:
        # Treat a scalar as a one-item list so that list items get exactly
        # the same handling as top-level arguments.  Without this, a
        # non-string list item would be passed straight to the regex
        # search and raise TypeError.
        candidates = elem if isinstance(elem, list) else [elem]

        for candidate in candidates:
            if candidate is None or isinstance(candidate, bool):
                continue

            if not isinstance(candidate, text_types):
                candidate = str(candidate)

            if is_input_dangerous(candidate):
                dangerous_elements.append(candidate)

    if dangerous_elements:
        module.fail_json(msg="Passed input '%s' is "
                             "potentially dangerous" % ', '.join(dangerous_elements))
191