"""
Custom YAML loading in Salt
"""


import warnings

import salt.utils.stringutils
import yaml  # pylint: disable=blacklisted-import
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode, SequenceNode

# prefer C bindings over python when available
BaseLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)


__all__ = ["SaltYamlSafeLoader", "load", "safe_load"]


class DuplicateKeyWarning(RuntimeWarning):
    """
    Warned when duplicate keys exist
    """


warnings.simplefilter("always", category=DuplicateKeyWarning)


# with code integrated from https://gist.github.com/844388
class SaltYamlSafeLoader(BaseLoader):
    """
    Create a custom YAML loader that uses the custom constructor. This allows
    for the YAML loading defaults to be manipulated based on needs within salt
    to make things like sls file more intuitive.
    """

    def __init__(self, stream, dictclass=dict):
        """
        :param stream: the YAML input, handed unchanged to the base loader.
        :param dictclass: mapping type used for constructed mappings. When it
            is not ``dict`` (e.g. an ordered dict) it is also used for
            ``!!map`` and ``!!omap`` nodes.
        """
        super().__init__(stream)
        cls = type(self)
        # PyYAML's ``add_constructor`` is a classmethod, so registering the
        # overrides through it would rewrite the constructor table of the
        # whole class: one load with an ordered ``dictclass`` would change how
        # ``!!omap`` is built for every later load in the process. Keep the
        # overrides on this instance instead; PyYAML resolves constructors
        # through ``self.yaml_constructors``.
        constructors = dict(self.yaml_constructors)
        if dictclass is not dict:
            # then assume ordered dict and use it for both !map and !omap
            constructors["tag:yaml.org,2002:map"] = cls.construct_yaml_map
            constructors["tag:yaml.org,2002:omap"] = cls.construct_yaml_map
        constructors["tag:yaml.org,2002:str"] = cls.construct_yaml_str
        constructors["tag:yaml.org,2002:python/unicode"] = cls.construct_unicode
        # Timestamps go through construct_scalar, i.e. they are returned as
        # the raw scalar text rather than converted to date/datetime objects.
        constructors["tag:yaml.org,2002:timestamp"] = cls.construct_scalar
        self.yaml_constructors = constructors
        self.dictclass = dictclass

    def construct_yaml_map(self, node):
        """
        Build a mapping of type ``self.dictclass`` for ``node``.

        Written as a generator (PyYAML's two-step construction): the empty
        mapping is yielded first and filled afterwards, so the object exists
        before its contents are constructed.
        """
        data = self.dictclass()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_unicode(self, node):
        """
        Return the raw scalar text of a ``!!python/unicode`` node.
        """
        return node.value

    def construct_mapping(self, node, deep=False):
        """
        Build the mapping for YAML

        Merge keys are resolved first (see :meth:`flatten_mapping`), then each
        key/value pair is constructed into a new ``self.dictclass`` instance.

        :raises ConstructorError: if ``node`` is not a mapping node, if a key
            is unhashable, or if the same key appears more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found {}".format(node.id),
                node.start_mark,
            )

        self.flatten_mapping(node)

        context = "while constructing a mapping"
        mapping = self.dictclass()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found unacceptable key {}".format(key_node.value),
                    key_node.start_mark,
                )
            value = self.construct_object(value_node, deep=deep)
            # Duplicate keys are an error (rather than last-one-wins) so that
            # repeated IDs in SLS files are reported instead of silently lost.
            if key in mapping:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found conflicting ID '{}'".format(key),
                    key_node.start_mark,
                )
            mapping[key] = value
        return mapping

    def construct_scalar(self, node):
        """
        Verify integers and pass them in correctly if they are declared
        as octal

        Integer scalars with leading zeros (other than ``0b``/``0x`` literals)
        have those zeros stripped, so e.g. ``0644`` is read as the decimal
        integer 644 instead of octal 420. An all-zero value becomes ``0``.
        Note that ``node.value`` is modified in place.
        """
        if node.tag == "tag:yaml.org,2002:int":
            if node.value == "0":
                pass
            elif node.value.startswith("0") and not node.value.startswith(("0b", "0x")):
                # NOTE(review): signed values such as "-0644"/"+0644" do not
                # start with "0" and are still read as octal by PyYAML.
                node.value = node.value.lstrip("0")
                # If value was all zeros, node.value would have been reduced to
                # an empty string. Change it to '0'.
                if node.value == "":
                    node.value = "0"
        return super().construct_scalar(node)

    def construct_yaml_str(self, node):
        """
        Construct a string scalar, normalized to text via ``to_unicode``.
        """
        value = self.construct_scalar(node)
        return salt.utils.stringutils.to_unicode(value)

    @staticmethod
    def _merge_key(key_node):
        """
        Return a hashable stand-in for ``key_node`` used to match merged keys.

        Scalar keys are matched on their raw text. Sequence/mapping keys carry
        an unhashable list of nodes as their value, so they are only matched
        against themselves (by node identity).
        """
        value = key_node.value
        try:
            hash(value)
        except TypeError:
            return (None, id(key_node))
        return value

    def flatten_mapping(self, node):
        """
        Resolve YAML merge keys (``<<``) in ``node`` in place.

        Every merge entry is removed from ``node.value`` and the key/value
        pairs of the referenced mapping(s) are prepended instead. Unlike
        PyYAML's own implementation, the merged pairs are de-duplicated here,
        because :meth:`construct_mapping` rejects repeated keys:

        * keys written explicitly in ``node`` override merged ones;
        * when several merged mappings share a key, the one with merge
          precedence wins (the earliest mapping in a ``<<: [...]`` list, or
          the later of two separate ``<<`` entries), matching PyYAML.

        :raises ConstructorError: if a merge value is not a mapping or a
            sequence of mappings.
        """
        merge = []
        index = 0
        while index < len(node.value):
            key_node, value_node = node.value[index]

            if key_node.tag == "tag:yaml.org,2002:merge":
                # Drop the "<<" entry itself; its contents are collected below.
                del node.value[index]
                if isinstance(value_node, MappingNode):
                    self.flatten_mapping(value_node)
                    merge.extend(value_node.value)
                elif isinstance(value_node, SequenceNode):
                    submerge = []
                    for subnode in value_node.value:
                        if not isinstance(subnode, MappingNode):
                            raise ConstructorError(
                                "while constructing a mapping",
                                node.start_mark,
                                "expected a mapping for merging, but found {}".format(
                                    subnode.id
                                ),
                                subnode.start_mark,
                            )
                        self.flatten_mapping(subnode)
                        submerge.append(subnode.value)
                    # Earlier mappings in the list take precedence, so they
                    # are placed last (the last occurrence wins below).
                    submerge.reverse()
                    for value in submerge:
                        merge.extend(value)
                else:
                    raise ConstructorError(
                        "while constructing a mapping",
                        node.start_mark,
                        "expected a mapping or list of mappings for merging, but"
                        " found {}".format(value_node.id),
                        value_node.start_mark,
                    )
            elif key_node.tag == "tag:yaml.org,2002:value":
                key_node.tag = "tag:yaml.org,2002:str"
                index += 1
            else:
                index += 1
        if merge:
            # Here we need to discard any duplicate entries based on key_node:
            # explicit keys beat merged ones, and among merged pairs only one
            # entry per key may survive, or construct_mapping would report a
            # conflicting ID that the user never wrote.
            explicit = {self._merge_key(name_node) for name_node, _ in node.value}
            merged = {}
            for item in merge:
                key = self._merge_key(item[0])
                if key not in explicit:
                    # Re-assignment keeps the first position but the last
                    # value, just as sequential dict assignment would.
                    merged[key] = item

            node.value = list(merged.values()) + node.value


def load(stream, Loader=SaltYamlSafeLoader):
    """
    Load a single YAML document from ``stream`` using ``Loader``
    (``SaltYamlSafeLoader`` by default).
    """
    return yaml.load(stream, Loader=Loader)


def safe_load(stream, Loader=SaltYamlSafeLoader):
    """
    .. versionadded:: 2018.3.0

    Helper function which automagically uses our custom loader.
    """
    return yaml.load(stream, Loader=Loader)