# -*- coding: utf-8 -*-
# This code is part of Ansible, but is an independent component.
# This particular file snippet, and this file snippet only, is BSD licensed.
# Modules you write using this snippet, which is embedded dynamically by Ansible
# still belong to the author of the module, and may assign their own license
# to the complete work.
#
# Copyright (c) 2014, Toshio Kuratomi <tkuratomi@ansible.com>
#
# Simplified BSD License (see licenses/simplified_bsd.txt or https://opensource.org/licenses/BSD-2-Clause)

from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import re


# Input patterns for is_input_dangerous function.
#
# Every pattern is compiled with re.DOTALL so that '.*' also spans line
# breaks: a newline is plain whitespace to the SQL parser, so a payload
# split over several lines must be caught just like a single-line one.
#
# 1. '"' in string and '--' in string or
# "'" in string and '--' in string
PATTERN_1 = re.compile(r'(\'|\").*--', re.DOTALL)

# 2. union \ intersect \ except + select
PATTERN_2 = re.compile(r'(UNION|INTERSECT|EXCEPT).*SELECT', re.IGNORECASE | re.DOTALL)

# 3. ';' and any KEY_WORDS
PATTERN_3 = re.compile(r';.*(SELECT|UPDATE|INSERT|DELETE|DROP|TRUNCATE|ALTER)', re.IGNORECASE | re.DOTALL)

# Text types that can be handed to the regular expressions unchanged.
# type(u'') is ``unicode`` on Python 2 and ``str`` on Python 3.
_TEXT_TYPES = (str, type(u''))


class SQLParseError(Exception):
    """Raised when an identifier cannot be parsed or has too many parts."""
    pass


class UnclosedQuoteError(SQLParseError):
    """Raised when a quoted identifier has no terminating quote character."""
    pass


# maps a type of identifier to the maximum number of dot levels that are
# allowed to specify that identifier.  For example, a database column can be
# specified by up to 4 levels: database.schema.table.column
_PG_IDENTIFIER_TO_DOT_LEVEL = dict(
    database=1,
    schema=2,
    table=3,
    column=4,
    role=1,
    tablespace=1,
    sequence=3,
    publication=1,
)
_MYSQL_IDENTIFIER_TO_DOT_LEVEL = dict(database=1, table=2, column=3, role=1, vars=1)


def _find_end_quote(identifier, quote_char):
    """Return the index of the closing quote in ``identifier``.

    ``identifier`` is the text that follows an opening quote.  A doubled
    quote character is an escaped literal quote and is skipped.

    :param identifier: text after the opening quote character.
    :param quote_char: the single quote character in use ('"' or '`').
    :returns: index, within ``identifier``, of the first unescaped quote.
    :raises UnclosedQuoteError: if no unescaped quote character is found.
    """
    search_from = 0
    while True:
        quote = identifier.find(quote_char, search_from)
        if quote == -1:
            raise UnclosedQuoteError
        # The one-character slice is simply empty when the quote is the last
        # character, so no IndexError handling is needed here.
        if identifier[quote + 1:quote + 2] != quote_char:
            return quote
        # Doubled quote: skip both characters and keep searching.
        search_from = quote + 2


def _quote_fragment(fragment, quote_char):
    """Wrap ``fragment`` in ``quote_char``, doubling any embedded quote chars."""
    return ''.join((quote_char, fragment.replace(quote_char, quote_char * 2), quote_char))


def _identifier_parse(identifier, quote_char):
    """Split a dotted identifier into a list of individually quoted fragments.

    A leading fragment that is already quoted by the user is kept verbatim;
    an unquoted fragment is quoted and its embedded quote characters are
    doubled.  The remainder after the separating dot is parsed recursively.

    :param identifier: the raw identifier, e.g. ``schema.table``.
    :param quote_char: the quote character of the SQL dialect.
    :returns: list of quoted fragments, in order.
    :raises SQLParseError: if ``identifier`` is empty, or a user-quoted
        fragment is followed by anything other than a dot.
    """
    if not identifier:
        raise SQLParseError('Identifier name unspecified or unquoted trailing dot')

    if identifier.startswith(quote_char):
        try:
            end_quote = _find_end_quote(identifier[1:], quote_char=quote_char) + 1
        except UnclosedQuoteError:
            # No closing quote: the leading quote is treated as ordinary data
            # and the identifier is handled by the unquoted path below.
            pass
        else:
            if end_quote >= len(identifier) - 1:
                # The closing quote is the last character: a single fragment.
                return [identifier]
            if identifier[end_quote + 1] != '.':
                raise SQLParseError('User escaped identifiers must escape extra quotes')
            dot = end_quote + 1
            fragments = _identifier_parse(identifier[dot + 1:], quote_char)
            fragments.insert(0, identifier[:dot])
            return fragments

    dot = identifier.find('.')
    # No dot, a leading dot or a trailing dot: the whole string becomes one
    # quoted fragment.
    if dot <= 0 or dot >= len(identifier) - 1:
        return [_quote_fragment(identifier, quote_char)]

    fragments = _identifier_parse(identifier[dot + 1:], quote_char)
    fragments.insert(0, _quote_fragment(identifier[:dot], quote_char))
    return fragments


def pg_quote_identifier(identifier, id_type):
    """Quote a PostgreSQL identifier using double quotes.

    :param identifier: raw, possibly dotted and partially quoted identifier.
    :param id_type: a key of ``_PG_IDENTIFIER_TO_DOT_LEVEL`` (for example
        ``'table'``); it bounds the number of fragments allowed.
    :returns: the quoted fragments joined with dots.
    :raises SQLParseError: if the identifier cannot be parsed or has more
        fragments than ``id_type`` allows.
    :raises KeyError: if ``id_type`` is not a known identifier type.
    """
    identifier_fragments = _identifier_parse(identifier, quote_char='"')
    if len(identifier_fragments) > _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]:
        raise SQLParseError('PostgreSQL does not support %s with more than %i dots' % (id_type, _PG_IDENTIFIER_TO_DOT_LEVEL[id_type]))
    return '.'.join(identifier_fragments)


def mysql_quote_identifier(identifier, id_type):
    """Quote a MySQL identifier using backticks.

    A fragment consisting solely of ``*`` is emitted unquoted.

    :param identifier: raw, possibly dotted and partially quoted identifier.
    :param id_type: a key of ``_MYSQL_IDENTIFIER_TO_DOT_LEVEL``.
    :returns: the quoted fragments joined with dots.
    :raises SQLParseError: if the identifier cannot be parsed or the number
        of fragments minus one exceeds the value stored for ``id_type``.
    :raises KeyError: if ``id_type`` is not a known identifier type.
    """
    identifier_fragments = _identifier_parse(identifier, quote_char='`')
    if (len(identifier_fragments) - 1) > _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]:
        raise SQLParseError('MySQL does not support %s with more than %i dots' % (id_type, _MYSQL_IDENTIFIER_TO_DOT_LEVEL[id_type]))

    special_cased_fragments = [
        '*' if fragment == '`*`' else fragment
        for fragment in identifier_fragments
    ]

    return '.'.join(special_cased_fragments)


def is_input_dangerous(string):
    """Check if the passed string is potentially dangerous.
    Can be used to prevent SQL injections.

    Note: use this function only when you can't use
      psycopg2's cursor.execute method parametrized
      (typically with DDL queries).

    :param string: the text to inspect; falsy values are never dangerous.
    :returns: True if any of PATTERN_1, PATTERN_2 or PATTERN_3 matches,
        False otherwise.
    """
    if not string:
        return False

    return any(pattern.search(string) for pattern in (PATTERN_1, PATTERN_2, PATTERN_3))


def check_input(module, *args):
    """Wrapper for is_input_dangerous function.

    Every positional argument is inspected; the items of a list argument are
    inspected one by one.  ``None`` and booleans are skipped, text is checked
    as is, and any other value is converted with ``str()`` first.

    :param module: object whose ``fail_json(msg=...)`` method is called when
        at least one value is considered dangerous.
    :param args: values to check.
    """
    dangerous_elements = []

    for arg in args:
        # List items follow the same rules as top-level values, so that a
        # non-string item (for example an int) cannot reach the regular
        # expressions and raise TypeError.
        candidates = arg if isinstance(arg, list) else (arg,)
        for value in candidates:
            if value is None or isinstance(value, bool):
                continue
            if not isinstance(value, _TEXT_TYPES):
                value = str(value)
            if is_input_dangerous(value):
                dangerous_elements.append(value)

    if dangerous_elements:
        module.fail_json(msg="Passed input '%s' is "
                             "potentially dangerous" % ', '.join(dangerous_elements))