1"""
2Custom YAML loading in Salt
3"""
4
5
6import re
7import warnings
8
9import salt.utils.stringutils
10import yaml  # pylint: disable=blacklisted-import
11from yaml.constructor import ConstructorError
12from yaml.nodes import MappingNode, SequenceNode
13
# Rebind PyYAML's default ``Loader``/``Dumper`` names to the ``CLoader`` /
# ``CDumper`` classes when the yaml module exposes them (presumably only when
# PyYAML was built with libyaml -- confirm against the PyYAML docs).
# SaltYamlSafeLoader below derives from ``yaml.SafeLoader`` and is therefore
# not affected by this rebinding.
try:
    yaml.Loader = yaml.CLoader
    yaml.Dumper = yaml.CDumper
except Exception:  # pylint: disable=broad-except
    # Deliberate best effort: any failure leaves whatever has not been rebound
    # yet untouched. Note the assignments run in order, so a failure on the
    # second line does not undo the first.
    pass


# Public API exported by ``from <this module> import *``.
__all__ = ["SaltYamlSafeLoader", "load", "safe_load"]
22
23
class DuplicateKeyWarning(RuntimeWarning):
    """
    Warned when duplicate keys exist

    A ``RuntimeWarning`` subclass that serves purely as a warning category; it
    adds no behaviour of its own.
    """


# Register an "always" filter for this category at import time so that a
# DuplicateKeyWarning is shown every time it is issued rather than only once
# per location.
warnings.simplefilter("always", category=DuplicateKeyWarning)
31
32
33# with code integrated from https://gist.github.com/844388
class SaltYamlSafeLoader(yaml.SafeLoader):
    """
    Create a custom YAML loader that uses the custom constructor. This allows
    for the YAML loading defaults to be manipulated based on needs within salt
    to make things like sls file more intuitive.

    Differences from ``yaml.SafeLoader`` implemented in this class:

    - mappings are built with a caller supplied ``dictclass``;
    - duplicate or unhashable mapping keys raise ``ConstructorError``;
    - integers with leading zeros have the zeros stripped before conversion;
    - strings written as ``u'...'`` / ``u"..."`` are unwrapped;
    - timestamp scalars are routed through :meth:`construct_scalar`.
    """

    def __init__(self, stream, dictclass=dict):
        """
        :param stream: YAML source, passed unchanged to ``yaml.SafeLoader``.
        :param dictclass: Type instantiated (with no arguments) for every
            mapping this loader builds. Anything other than ``dict`` is
            assumed to be an ordered mapping type.
        """
        super().__init__(stream)
        self.dictclass = dictclass
        # NOTE(review): PyYAML documents ``add_constructor`` as a class
        # method, so the registrations below presumably apply to the whole
        # class rather than to this instance only -- confirm before relying
        # on per-instance behaviour.
        if dictclass is not dict:
            # then assume ordered dict and use it for both !map and !omap
            self.add_constructor("tag:yaml.org,2002:map", type(self).construct_yaml_map)
            self.add_constructor(
                "tag:yaml.org,2002:omap", type(self).construct_yaml_map
            )
        self.add_constructor("tag:yaml.org,2002:str", type(self).construct_yaml_str)
        self.add_constructor(
            "tag:yaml.org,2002:python/unicode", type(self).construct_unicode
        )
        self.add_constructor("tag:yaml.org,2002:timestamp", type(self).construct_scalar)

    def construct_yaml_map(self, node):
        """
        Two-step constructor for mapping nodes.

        The empty ``dictclass`` instance is yielded first and only filled in
        when the generator is resumed.
        """
        data = self.dictclass()
        yield data
        value = self.construct_mapping(node)
        data.update(value)

    def construct_unicode(self, node):
        """
        Return the raw node value for ``python/unicode`` tagged scalars.
        """
        return node.value

    def construct_mapping(self, node, deep=False):
        """
        Build the mapping for YAML

        Merge keys are flattened into ``node`` first (see
        :meth:`flatten_mapping`), then every key/value pair is constructed in
        order and stored in a new ``self.dictclass`` instance.

        :param node: The node to convert; must be a ``MappingNode``.
        :param deep: Forwarded to ``construct_object`` for keys and values.
        :return: The populated ``self.dictclass`` instance.
        :raises ConstructorError: If ``node`` is not a mapping node, if a key
            is not hashable, or if the same key occurs more than once.
        """
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                "expected a mapping node, but found {}".format(node.id),
                node.start_mark,
            )

        self.flatten_mapping(node)

        context = "while constructing a mapping"
        mapping = self.dictclass()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            try:
                hash(key)
            except TypeError as exc:
                # Keep the original TypeError attached as the explicit cause.
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found unacceptable key {}".format(key_node.value),
                    key_node.start_mark,
                ) from exc
            value = self.construct_object(value_node, deep=deep)
            if key in mapping:
                # Duplicate keys are an error here instead of "last one wins".
                raise ConstructorError(
                    context,
                    node.start_mark,
                    "found conflicting ID '{}'".format(key),
                    key_node.start_mark,
                )
            mapping[key] = value
        return mapping

    def construct_scalar(self, node):
        """
        Verify integers and pass them in correctly if they are declared
        as octal

        ``node.value`` is rewritten in place before being handed to the
        parent implementation:

        - ``int`` nodes that start with ``0`` (other than a lone ``0`` or a
          ``0b`` / ``0x`` prefix) lose their leading zeros;
        - ``str`` nodes that look like a quoted unicode literal are evaluated
          to obtain the inner string.
        """
        if node.tag == "tag:yaml.org,2002:int":
            has_plain_leading_zero = (
                node.value != "0"
                and node.value.startswith("0")
                and not node.value.startswith(("0b", "0x"))
            )
            if has_plain_leading_zero:
                # If value was all zeros, stripping leaves an empty string;
                # fall back to '0' in that case.
                node.value = node.value.lstrip("0") or "0"
        elif node.tag == "tag:yaml.org,2002:str":
            # If any string comes in as a quoted unicode literal, eval it into
            # the proper unicode string type.
            if re.match(r'^u([\'"]).+\1$', node.value, flags=re.IGNORECASE):
                # SECURITY NOTE(review): the pattern only checks how the text
                # starts and ends, so an arbitrary expression between two
                # string literals also matches, and passing empty globals to
                # eval() does not remove access to builtins. This is only
                # acceptable for trusted documents; consider
                # ast.literal_eval() if untrusted YAML can reach this loader.
                node.value = eval(node.value, {}, {})  # pylint: disable=W0123
        return super().construct_scalar(node)

    def construct_yaml_str(self, node):
        """
        Construct a string scalar and pass the result through
        ``salt.utils.stringutils.to_unicode``.
        """
        value = self.construct_scalar(node)
        return salt.utils.stringutils.to_unicode(value)

    def fetch_plain(self):
        """
        Handle unicode literal strings which appear inline in the YAML

        The parent scanner is tried first. If it fails with
        ``found unexpected ':'`` and the scalar starts with ``u'`` or ``u"``,
        the scanner is rewound, the ``u`` prefix is skipped and the scalar is
        fetched as a quoted string instead. In every other case the original
        ``ScannerError`` propagates with the scanner position left where the
        failure occurred.
        """
        orig_line = self.line
        orig_column = self.column
        orig_pointer = self.pointer
        try:
            return super().fetch_plain()
        except yaml.scanner.ScannerError as exc:
            if exc.problem != "found unexpected ':'":
                # Nothing has been rewound yet; just propagate.
                raise
            problem_line = self.line
            problem_column = self.column
            problem_pointer = self.pointer
            # Reset to prior position
            self.line = orig_line
            self.column = orig_column
            self.pointer = orig_pointer
            # Might be a unicode literal string, check for 2nd char and
            # call the appropriate fetch func if it's a quote
            quote_char = self.peek(1) if self.peek(0) == "u" else None
            if quote_char in ("'", '"'):
                # Skip the "u" prefix by advancing the column and
                # pointer by 1
                self.column += 1
                self.pointer += 1
                if quote_char == "'":
                    return self.fetch_single()
                return self.fetch_double()
            # This wasn't a unicode literal string, so the caught exception
            # was correct. Put the scanner back where the failure happened
            # (on both the "no u" and the "u without quote" paths) and
            # re-raise the caught exception.
            self.line = problem_line
            self.column = problem_column
            self.pointer = problem_pointer
            raise

    def flatten_mapping(self, node):
        """
        Expand YAML merge keys (``<<``) of ``node`` in place.

        Every merge entry is removed from ``node.value`` and the key/value
        node pairs of the mapping(s) it refers to are collected. Collected
        pairs whose key node value already occurs among the remaining
        entries are dropped, and the rest is placed in front of them. Keys
        tagged ``value`` are retagged as plain strings.

        :raises ConstructorError: If a merge value is neither a mapping nor a
            sequence made up solely of mappings.
        """
        merge = []
        index = 0
        # Index based loop because entries are deleted while scanning.
        while index < len(node.value):
            key_node, value_node = node.value[index]

            if key_node.tag == "tag:yaml.org,2002:merge":
                # Do not advance: the next entry slides into this slot.
                del node.value[index]
                if isinstance(value_node, MappingNode):
                    self.flatten_mapping(value_node)
                    merge.extend(value_node.value)
                elif isinstance(value_node, SequenceNode):
                    submerge = []
                    for subnode in value_node.value:
                        if not isinstance(subnode, MappingNode):
                            raise ConstructorError(
                                "while constructing a mapping",
                                node.start_mark,
                                "expected a mapping for merging, but found {}".format(
                                    subnode.id
                                ),
                                subnode.start_mark,
                            )
                        self.flatten_mapping(subnode)
                        submerge.append(subnode.value)
                    # Later mappings of the sequence are emitted first.
                    for value in reversed(submerge):
                        merge.extend(value)
                else:
                    raise ConstructorError(
                        "while constructing a mapping",
                        node.start_mark,
                        "expected a mapping or list of mappings for merging, but"
                        " found {}".format(value_node.id),
                        value_node.start_mark,
                    )
            elif key_node.tag == "tag:yaml.org,2002:value":
                key_node.tag = "tag:yaml.org,2002:str"
                index += 1
            else:
                index += 1
        if merge:
            # Here we need to discard any duplicate entries based on key_node.
            # Kept as a list on purpose: key node values are not guaranteed
            # to be hashable.
            existing_nodes = [name_node.value for name_node, _ in node.value]
            mergeable_items = [x for x in merge if x[0].value not in existing_nodes]

            node.value = mergeable_items + node.value
216
217
def load(stream, Loader=SaltYamlSafeLoader):
    """
    Parse ``stream`` with ``yaml.load`` and return the result.

    ``Loader`` is forwarded as the loader class and defaults to
    ``SaltYamlSafeLoader``.
    """
    parsed = yaml.load(stream, Loader=Loader)
    return parsed
220
221
def safe_load(stream, Loader=SaltYamlSafeLoader):
    """
    .. versionadded:: 2018.3.0

    Convenience wrapper around ``yaml.load`` that uses our custom loader
    (``SaltYamlSafeLoader``) unless another ``Loader`` is given.
    """
    parsed = yaml.load(stream, Loader=Loader)
    return parsed
229