1"""
2Custom YAML loading in Salt
3"""
4
5
6import warnings
7
8import salt.utils.stringutils
9import yaml  # pylint: disable=blacklisted-import
10from yaml.constructor import ConstructorError
11from yaml.nodes import MappingNode, SequenceNode
12
# Use the libyaml-backed C loader when this PyYAML build provides it and fall
# back to the pure-Python safe loader otherwise.
try:
    BaseLoader = yaml.CSafeLoader
except AttributeError:
    BaseLoader = yaml.SafeLoader


# Public API of this module.
__all__ = [
    "SaltYamlSafeLoader",
    "load",
    "safe_load",
]
18
19
class DuplicateKeyWarning(RuntimeWarning):
    """
    Warning category for YAML mappings in which the same key appears twice.
    """


# Show every DuplicateKeyWarning instead of only the first one per location.
# This adds a process-wide entry to the ``warnings`` filter list at import time.
warnings.simplefilter("always", DuplicateKeyWarning)
27
28
29# with code integrated from https://gist.github.com/844388
class SaltYamlSafeLoader(BaseLoader):
    """
    Create a custom YAML loader that uses the custom constructor. This allows
    for the YAML loading defaults to be manipulated based on needs within salt
    to make things like sls file more intuitive.

    Differences from the base safe loader implemented in this class:

    * mappings are built with ``dictclass``; when it is not ``dict``,
      ``!!omap`` nodes are built with it as well,
    * duplicate mapping keys raise ``ConstructorError``,
    * integers written with leading zeros (``0755``) have those zeros
      stripped before conversion,
    * timestamps are returned as the raw scalar string,
    * ``<<`` merge keys never override keys defined in the mapping itself.
    """

    def __init__(self, stream, dictclass=dict):
        super().__init__(stream)
        # NOTE: PyYAML's add_constructor() is a classmethod. The calls below
        # therefore update the constructor table shared by every instance of
        # this class, not just this one. Any constructor registered here must
        # behave correctly for whatever ``dictclass`` a *later* instance uses.
        # The map/omap constructors registered below read ``self.dictclass``
        # at load time for exactly that reason.
        if dictclass is not dict:
            # then assume ordered dict and use it for both !map and !omap
            self.add_constructor("tag:yaml.org,2002:map", type(self).construct_yaml_map)
            self.add_constructor(
                "tag:yaml.org,2002:omap", type(self).construct_yaml_omap
            )
        self.add_constructor("tag:yaml.org,2002:str", type(self).construct_yaml_str)
        self.add_constructor(
            "tag:yaml.org,2002:python/unicode", type(self).construct_unicode
        )
        # Keep timestamps as their raw scalar string instead of converting them.
        self.add_constructor("tag:yaml.org,2002:timestamp", type(self).construct_scalar)
        self.dictclass = dictclass

    def construct_yaml_map(self, node):
        """
        Construct a ``!!map`` node into a new ``self.dictclass`` instance.

        This is a generator. The empty container is yielded before it is
        populated, following PyYAML's two-step constructor protocol.
        """
        data = self.dictclass()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_yaml_omap(self, node):
        """
        Construct a ``!!omap`` node.

        Loaders created with a ``dictclass`` other than ``dict`` build the
        node like a regular mapping, using ``dictclass``. Plain ``dict``
        loaders delegate to the base safe loader's ``!!omap`` constructor.
        The constructor table is shared across instances (see ``__init__``),
        so deciding per instance keeps an ordered loader from changing how
        later ``dict`` loaders read ``!!omap``.
        """
        if self.dictclass is dict:
            yield from super().construct_yaml_omap(node)
        else:
            yield from self.construct_yaml_map(node)

    def construct_unicode(self, node):
        """
        Return the raw scalar value of a ``!!python/unicode`` node unchanged.
        """
        return node.value

    def construct_mapping(self, node, deep=False):
        """
        Build the mapping for YAML.

        ``<<`` merge keys are expanded first (see ``flatten_mapping``). The
        result is a new ``self.dictclass`` instance.

        :raises ConstructorError: if ``node`` is not a mapping node, if a key
            is unhashable, or if the same key occurs more than once (reported
            as a conflicting ID).
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found {}".format(node.id),
                node.start_mark,
            )

        self.flatten_mapping(node)

        context = "while constructing a mapping"
        mapping = self.dictclass()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError as exc:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found unacceptable key {}".format(key_node.value),
                    key_node.start_mark,
                ) from exc
            value = self.construct_object(value_node, deep=deep)
            # Duplicate keys are an error (e.g. a repeated state ID in an SLS
            # file) rather than silently letting the last value win.
            if key in mapping:
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found conflicting ID '{}'".format(key),
                    key_node.start_mark,
                )
            mapping[key] = value
        return mapping

    def construct_scalar(self, node):
        """
        Construct a scalar. Integers with leading zeros are read as decimal.

        For nodes tagged ``tag:yaml.org,2002:int`` whose text starts with
        ``0`` (and is not a ``0b``/``0x`` literal), the leading zeros are
        stripped so the value is not interpreted as octal. ``node.value`` is
        rewritten in place.
        """
        if node.tag == "tag:yaml.org,2002:int":
            value = node.value
            if value.startswith("0") and not value.startswith(("0b", "0x")):
                # "_" is a digit separator in YAML ints. Stripping it together
                # with the zeros keeps e.g. "0_" from collapsing to a bare "_",
                # which the int constructor cannot parse. An all-zero value
                # reduces to "" and becomes "0".
                node.value = value.lstrip("0_") or "0"
            # NOTE(review): signed literals such as "-010" do not start with
            # "0" and are left untouched; confirm that is intended.
        return super().construct_scalar(node)

    def construct_yaml_str(self, node):
        """
        Construct a ``!!str`` scalar and return it as text.

        The value is passed through ``salt.utils.stringutils.to_unicode``.
        """
        value = self.construct_scalar(node)
        return salt.utils.stringutils.to_unicode(value)

    def flatten_mapping(self, node):
        """
        Expand the ``<<`` merge keys of ``node`` in place.

        Each merge key is removed from ``node.value``, and the key/value pairs
        of the mapping(s) it references are prepended. Keys defined in the
        mapping itself always win over merged keys. For ``<<: [*a, *b]``,
        keys from mappings earlier in the sequence win over later ones; a
        later ``<<`` key in the same mapping wins over an earlier one. Pairs
        that would be overridden are dropped here, because
        ``construct_mapping`` rejects duplicate keys. Duplicates inside a
        single merged mapping are kept so they are still reported.

        ``!!value`` key tags are rewritten to ``!!str``.

        :raises ConstructorError: if a merge value is not a mapping or a
            sequence of mappings.
        """
        # Merge sources (lists of key/value node pairs), ordered from lowest
        # to highest precedence.
        sources = []
        index = 0
        while index < len(node.value):
            key_node, value_node = node.value[index]

            if key_node.tag == "tag:yaml.org,2002:merge":
                del node.value[index]
                if isinstance(value_node, MappingNode):
                    self.flatten_mapping(value_node)
                    sources.append(list(value_node.value))
                elif isinstance(value_node, SequenceNode):
                    submerge = []
                    for subnode in value_node.value:
                        if not isinstance(subnode, MappingNode):
                            raise ConstructorError(
                                "while constructing a mapping",
                                node.start_mark,
                                "expected a mapping for merging, but found {}".format(
                                    subnode.id
                                ),
                                subnode.start_mark,
                            )
                        self.flatten_mapping(subnode)
                        submerge.append(list(subnode.value))
                    # Earlier entries of the sequence take precedence, so they
                    # go last in increasing-precedence order.
                    submerge.reverse()
                    sources.extend(submerge)
                else:
                    raise ConstructorError(
                        "while constructing a mapping",
                        node.start_mark,
                        "expected a mapping or list of mappings for merging, but"
                        " found {}".format(value_node.id),
                        value_node.start_mark,
                    )
            elif key_node.tag == "tag:yaml.org,2002:value":
                key_node.tag = "tag:yaml.org,2002:str"
                index += 1
            else:
                index += 1
        if sources:
            # Start with the highest-precedence source and work down. Drop
            # every merged pair whose key the mapping itself or a
            # higher-precedence source already provides. ``fresh`` is computed
            # before ``seen`` is extended, so duplicates within one source
            # survive and construct_mapping() still reports them. Keys are
            # compared by raw node value, as the original filtering did; a list
            # is used because non-scalar key nodes hold unhashable values.
            seen = [name_node.value for name_node, _ in node.value]
            merged = []
            for source in reversed(sources):
                fresh = [item for item in source if item[0].value not in seen]
                seen.extend(item[0].value for item in fresh)
                merged[:0] = fresh
            node.value = merged + node.value
166
167
def load(stream, Loader=SaltYamlSafeLoader):
    """
    Parse a single YAML document from ``stream``.

    ``Loader`` defaults to :class:`SaltYamlSafeLoader`. Any other loader
    class supplied by the caller is passed to ``yaml.load`` unchanged.
    """
    loader_class = Loader
    return yaml.load(stream, Loader=loader_class)
170
171
def safe_load(stream, Loader=SaltYamlSafeLoader):
    """
    .. versionadded:: 2018.3.0

    Parse a single YAML document from ``stream``, using Salt's custom
    loader unless the caller supplies a different ``Loader``.

    Despite the name, ``Loader`` is passed through unchanged, so how safe
    the load is depends on the loader class that is supplied.
    """
    document = yaml.load(stream, Loader=Loader)
    return document
179