"""
Custom YAML loading in Salt
"""


import ast
import re
import warnings

import salt.utils.stringutils
import yaml  # pylint: disable=blacklisted-import
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode, SequenceNode

try:
    # Use the libyaml-backed classes when PyYAML provides them. Note that this
    # rebinds attributes on the shared ``yaml`` module, so it affects every
    # importer of ``yaml``, not just this module.
    yaml.Loader = yaml.CLoader
    yaml.Dumper = yaml.CDumper
except AttributeError:
    # PyYAML was built without libyaml; keep the pure-Python classes.
    pass


__all__ = ["SaltYamlSafeLoader", "load", "safe_load"]


class DuplicateKeyWarning(RuntimeWarning):
    """
    Warned when duplicate keys exist
    """


warnings.simplefilter("always", category=DuplicateKeyWarning)


# with code integrated from https://gist.github.com/844388
class SaltYamlSafeLoader(yaml.SafeLoader):
    """
    Create a custom YAML loader that uses the custom constructor. This allows
    for the YAML loading defaults to be manipulated based on needs within salt
    to make things like sls file more intuitive.

    :param stream: the YAML document(s) to load, passed to ``yaml.SafeLoader``.
    :param dictclass: mapping type used for every mapping that is built.
        Anything other than ``dict`` is assumed to be an ordered mapping and is
        also used for ``!!omap`` nodes.
    """

    def __init__(self, stream, dictclass=dict):
        super().__init__(stream)
        cls = type(self)
        # add_constructor() is a classmethod, so these registrations land on
        # ``type(self)`` and persist for later instances. They do not depend on
        # ``dictclass`` and registering them again is harmless.
        self.add_constructor("tag:yaml.org,2002:str", cls.construct_yaml_str)
        self.add_constructor(
            "tag:yaml.org,2002:python/unicode", cls.construct_unicode
        )
        # Timestamps are kept as their original text instead of datetime objects.
        self.add_constructor("tag:yaml.org,2002:timestamp", cls.construct_scalar)
        if dictclass is not dict:
            # Assume an ordered dict and use it for both !map and !omap. These
            # go into a per-instance copy of the constructor table: registering
            # them on the class would let a single ordered loader change how
            # every later plain-dict loader builds !!omap nodes.
            self.yaml_constructors = dict(self.yaml_constructors)
            self.yaml_constructors["tag:yaml.org,2002:map"] = cls.construct_yaml_map
            self.yaml_constructors["tag:yaml.org,2002:omap"] = cls.construct_yaml_map
        self.dictclass = dictclass

    def construct_yaml_map(self, node):
        """
        Build a ``self.dictclass`` instance for a mapping node.

        The empty container is yielded before it is filled, following PyYAML's
        generator-style constructor convention, then populated from
        :meth:`construct_mapping`.
        """
        data = self.dictclass()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_unicode(self, node):
        """
        Return the raw scalar text of a ``!!python/unicode`` node unchanged.
        """
        return node.value

    def construct_mapping(self, node, deep=False):
        """
        Build the mapping for YAML

        Merge keys (``<<``) are expanded first via :meth:`flatten_mapping`, and
        the result is a ``self.dictclass`` instance.

        :raises ConstructorError: if ``node`` is not a mapping node, if a key is
            unhashable, or if the same key appears more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found {}".format(node.id),
                node.start_mark,
            )

        self.flatten_mapping(node)

        context = "while constructing a mapping"
        mapping = self.dictclass()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found unacceptable key {}".format(key_node.value),
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)
            # Unlike stock PyYAML (last value wins), duplicate keys are an error.
            if key in mapping:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found conflicting ID '{}'".format(key),
                    key_node.start_mark,
                )
            mapping[key] = value
        return mapping

    def construct_scalar(self, node):
        """
        Verify integers and pass them in correctly if they are declared
        as octal

        ``!!int`` scalars written with leading zeros (``010``, ``-010``,
        ``+010``) are rewritten without those zeros so that they are parsed as
        decimal rather than octal; ``0b``/``0x`` prefixed values are left
        alone. ``!!str`` scalars that look like Python unicode literals
        (``u'...'`` or ``u"..."``) are replaced by the string the literal
        denotes. Note that ``node.value`` is modified in place.
        """
        if node.tag == "tag:yaml.org,2002:int":
            sign, digits = "", node.value
            if digits[:1] in ("-", "+"):
                # Handle the sign separately so signed values get the same
                # leading-zero treatment as unsigned ones.
                sign, digits = digits[0], digits[1:]
            if (
                digits != "0"
                and digits.startswith("0")
                and not digits.startswith(("0b", "0x"))
            ):
                # Drop the leading zeros together with any "_" separators mixed
                # in with them, so input like "0_" does not leave a value with
                # no digits. An all-zero value collapses to "0".
                node.value = sign + (digits.lstrip("0_") or "0")
        elif node.tag == "tag:yaml.org,2002:str":
            # If any string comes in as a quoted unicode literal, convert it
            # into the string that literal denotes.
            if re.match(r'^u([\'"]).+\1$', node.value, flags=re.IGNORECASE):
                try:
                    # literal_eval only accepts literals. The regex above also
                    # matches arbitrary expressions (e.g. u'' + f() + ''), which
                    # eval() would have executed.
                    literal = ast.literal_eval(node.value)
                except (
                    ValueError,
                    TypeError,
                    SyntaxError,
                    MemoryError,
                    RecursionError,
                ):
                    # Not a well-formed literal: keep the text verbatim.
                    literal = None
                if isinstance(literal, str):
                    node.value = literal
        return super().construct_scalar(node)

    def construct_yaml_str(self, node):
        """
        Construct a ``!!str`` scalar and coerce it to a unicode string.
        """
        value = self.construct_scalar(node)
        return salt.utils.stringutils.to_unicode(value)

    def fetch_plain(self):
        """
        Handle unicode literal strings which appear inline in the YAML

        If the plain-scalar scan fails with "found unexpected ':'", rewind to
        where the scalar started. If the text there begins with ``u'`` or
        ``u"``, re-scan it as a single- or double-quoted scalar without the
        ``u`` prefix. Otherwise re-raise the original ScannerError with the
        reader left where the failed scan stopped.
        """
        # The reader position is tracked by four fields that must move together.
        # NOTE(review): restoring saved pointers assumes the reader buffer was
        # not compacted while scanning (true for str input) - confirm for
        # file-like streams.
        start = (self.line, self.column, self.pointer, self.index)
        try:
            return super().fetch_plain()
        except yaml.scanner.ScannerError as exc:
            failed_at = (self.line, self.column, self.pointer, self.index)
            if exc.problem == "found unexpected ':'":
                # Reset to prior position
                self.line, self.column, self.pointer, self.index = start
                if self.peek(0) == "u":
                    # Might be a unicode literal string; check the 2nd char and
                    # call the matching fetch function if it is a quote.
                    quote_char = self.peek(1)
                    if quote_char in ("'", '"'):
                        # Skip the "u" prefix by advancing the position by 1.
                        self.column += 1
                        self.pointer += 1
                        self.index += 1
                        if quote_char == "'":
                            return self.fetch_single()
                        return self.fetch_double()
            # This wasn't a unicode literal string, so the caught exception was
            # correct. Restore the failure position and re-raise it.
            self.line, self.column, self.pointer, self.index = failed_at
            raise

    def flatten_mapping(self, node):
        """
        Expand YAML merge keys (``<<``) of ``node`` in place.

        Merged key/value pairs are placed before the mapping's own pairs.
        Merged pairs whose key already appears explicitly in the mapping are
        discarded. When several merged sources define the same key, only the
        one with the highest merge precedence is kept, so
        :meth:`construct_mapping` does not report a spurious conflict.

        :raises ConstructorError: if a merge value is not a mapping or a
            sequence of mappings.
        """
        merge = []
        index = 0
        while index < len(node.value):
            key_node, value_node = node.value[index]

            if key_node.tag == "tag:yaml.org,2002:merge":
                del node.value[index]
                if isinstance(value_node, MappingNode):
                    self.flatten_mapping(value_node)
                    merge.extend(value_node.value)
                elif isinstance(value_node, SequenceNode):
                    submerge = []
                    for subnode in value_node.value:
                        if not isinstance(subnode, MappingNode):
                            raise ConstructorError(
                                "while constructing a mapping",
                                node.start_mark,
                                "expected a mapping for merging, but found {}".format(
                                    subnode.id
                                ),
                                subnode.start_mark,
                            )
                        self.flatten_mapping(subnode)
                        submerge.append(subnode.value)
                    # Reversed so that, in ``merge``, mappings listed earlier
                    # in the sequence come later, i.e. later entries take
                    # precedence.
                    submerge.reverse()
                    for value in submerge:
                        merge.extend(value)
                else:
                    raise ConstructorError(
                        "while constructing a mapping",
                        node.start_mark,
                        "expected a mapping or list of mappings for merging, but"
                        " found {}".format(value_node.id),
                        value_node.start_mark,
                    )
            elif key_node.tag == "tag:yaml.org,2002:value":
                key_node.tag = "tag:yaml.org,2002:str"
                index += 1
            else:
                index += 1
        if merge:
            # Discard duplicate entries based on the key node's value. Explicit
            # keys always win. Among merged pairs, later entries in ``merge``
            # have higher precedence, so walk it backwards and keep the first
            # occurrence of each key. A list (not a set) is used because
            # non-scalar key nodes have unhashable values.
            seen_keys = [name_node.value for name_node, _ in node.value]
            mergeable_items = []
            for item in reversed(merge):
                key_value = item[0].value
                if key_value not in seen_keys:
                    seen_keys.append(key_value)
                    mergeable_items.append(item)
            # Restore the original relative order of the surviving pairs.
            mergeable_items.reverse()

            node.value = mergeable_items + node.value


def load(stream, Loader=SaltYamlSafeLoader):
    """
    Load a single YAML document from ``stream`` using ``Loader``
    (``SaltYamlSafeLoader`` by default).
    """
    return yaml.load(stream, Loader=Loader)


def safe_load(stream, Loader=SaltYamlSafeLoader):
    """
    .. versionadded:: 2018.3.0

    Helper function which automagically uses our custom loader.
    """
    return yaml.load(stream, Loader=Loader)