1"""
2    YAMLEX is a format that allows for things like sls files to be
3    more intuitive.
4
5    It's an extension of YAML that implements all the salt magic:
6    - it implies omap for any dict like.
7    - it implies that string like data are str, not unicode
8    - ...
9
    For example, the file `state.sls` has these contents:
11
12    .. code-block:: yaml
13
14        foo:
15          bar: 42
16          baz: [1, 2, 3]
17
18    The file can be parsed into Python like this
19
20    .. code-block:: python
21
22        from salt.serializers import yamlex
23
24        with open("state.sls", "r") as stream:
25            obj = yamlex.deserialize(stream)
26
27    Check that ``obj`` is an OrderedDict
28
29    .. code-block:: python
30
31        from salt.utils.odict import OrderedDict
32
33        assert isinstance(obj, dict)
34        assert isinstance(obj, OrderedDict)
35
36
    The `__repr__` and `__str__` methods of yamlex objects render a string
    that is itself valid YAML, which makes these objects template friendly.
39
40
41    .. code-block:: python
42
        print("{0}".format(obj))
44
45    returns:
46
47    ::
48
49        {foo: {bar: 42, baz: [1, 2, 3]}}
50
51    and they are still valid YAML:
52
53    .. code-block:: python
54
55        from salt.serializers import yaml
56
57        yml_obj = yaml.deserialize(str(obj))
58        assert yml_obj == obj
59
    yamlex also implements custom tags:
61
62    !aggregate
63
        This tag allows structures to be aggregated.
65
66        For example:
67
68
69        .. code-block:: yaml
70
71            placeholder: !aggregate foo
72            placeholder: !aggregate bar
73            placeholder: !aggregate baz
74
75        is rendered as
76
77        .. code-block:: yaml
78
79            placeholder: [foo, bar, baz]
80
81    !reset
82
        This tag discards the value computed so far for the key.
84
85        .. code-block:: yaml
86
87            placeholder: {!aggregate foo: {foo: 42}}
88            placeholder: {!aggregate foo: {bar: null}}
89            !reset placeholder: {!aggregate foo: {baz: inga}}
90
91        is roughly equivalent to
92
93        .. code-block:: yaml
94
95            placeholder: {!aggregate foo: {baz: inga}}
96
    The document itself is de facto an aggregate mapping.
98"""
99# pylint: disable=invalid-name,no-member,missing-docstring,no-self-use
100# pylint: disable=too-few-public-methods,too-many-public-methods
import collections
import copy
import datetime
import logging
from typing import TextIO
from typing import Union

import yaml
from yaml.constructor import ConstructorError
from yaml.nodes import MappingNode

from .aggregation import aggregate
from .aggregation import Map
from .aggregation import Sequence
114
# Names exported by ``from <module> import *``.
__all__ = ["deserialize", "serialize", "available"]

# Module-scoped logger, used below to warn about misplaced ``!aggregate`` tags.
log = logging.getLogger(__name__)

# Module-level flag exported through ``__all__``.
# NOTE(review): presumably read by the serializer framework to decide whether
# this serializer is usable -- confirm against the callers.
available = True

# prefer C bindings over python when available
BaseLoader = getattr(yaml, "CSafeLoader", yaml.SafeLoader)
# CSafeDumper causes repr errors in python3, so use the pure Python one
try:
    # Depending on how PyYAML was built, yaml.SafeDumper may actually be
    # yaml.cyaml.CSafeDumper (i.e. the C dumper instead of pure Python).
    BaseDumper = yaml.dumper.SafeDumper
except AttributeError:
    # Here just in case, but yaml.dumper.SafeDumper should always exist
    BaseDumper = yaml.SafeDumper

# Maps a raw parser error message to a friendlier wording.
# NOTE(review): nothing in the visible part of this module reads ERROR_MAP --
# confirm whether it is still consumed elsewhere.
ERROR_MAP = {
    ("found character '\\t' " "that cannot start any token"): "Illegal tab character"
}
135
136
def deserialize(stream_or_string: Union[str, TextIO], **options):
    """
    Deserialize a string or stream-like object into a Python data structure.

    :param stream_or_string: stream or string to deserialize.
    :param options: options given to lower yaml module. ``Loader`` defaults
        to the yamlex ``Loader`` class unless the caller supplies one.
    :return: whatever ``yaml.load`` builds from the input.
    """

    # Only fill in the loader when the caller did not pick one explicitly.
    # NOTE(review): a caller-supplied ``Loader`` is used as-is, so the safety
    # of loading untrusted input depends on that choice.
    options.setdefault("Loader", Loader)
    return yaml.load(stream_or_string, **options)
147
148
def serialize(obj, **options):
    """
    Dump a Python data structure as YAML text.

    :param obj: the data structure to serialize
    :param options: options given to lower yaml module. ``Dumper`` defaults
        to the yamlex ``Dumper`` class and ``default_flow_style`` to ``None``
        unless the caller supplies them.
    :return: the dumped text without its trailing document-end marker or
        final newline.
    """

    options.setdefault("Dumper", Dumper)
    options.setdefault("default_flow_style", None)
    dumped = yaml.dump(obj, **options)

    # Checked in order: the longer "\n...\n" trailer must win over the bare
    # newline, and only one trailer is ever removed.
    for trailer in ("\n...\n", "\n"):
        if dumped.endswith(trailer):
            return dumped[: -len(trailer)]
    return dumped
165
166
class Loader(BaseLoader):  # pylint: disable=W0232
    """
    Create a custom YAML loader that uses the custom constructor. This allows
    for the YAML loading defaults to be manipulated based on needs within salt
    to make things like sls file more intuitive.

    The document root is always tagged ``!aggregate``, mappings are built as
    :class:`SLSMap`, strings as :class:`SLSString`, and the ``!aggregate`` and
    ``!reset`` tags are resolved by the ``construct_sls_*`` methods.
    """

    DEFAULT_SCALAR_TAG = "tag:yaml.org,2002:str"
    DEFAULT_SEQUENCE_TAG = "tag:yaml.org,2002:seq"
    DEFAULT_MAPPING_TAG = "tag:yaml.org,2002:omap"

    def compose_document(self):
        """
        Compose the document and force its root node to be ``!aggregate``.
        """
        node = BaseLoader.compose_document(self)
        node.tag = "!aggregate"
        return node

    def construct_yaml_omap(self, node):
        """
        Build the SLSMap

        Keys that occur more than once are merged with ``merge_recursive``
        unless the key carries the ``!reset`` tag, in which case the new value
        replaces the previous one.

        :param node: the mapping node to construct.
        :return: an :class:`SLSMap` holding the constructed pairs.
        :raises ConstructorError: if ``node`` is not a mapping node or a key
            is not hashable.
        """
        sls_map = SLSMap()
        if not isinstance(node, MappingNode):
            raise ConstructorError(
                None,
                None,
                f"expected a mapping node, but found {node.id}",
                node.start_mark,
            )

        self.flatten_mapping(node)

        for key_node, value_node in node.value:

            # !reset instruction applies on document only.
            # It tells to reset previous decoded value for this present key.
            reset = key_node.tag == "!reset"

            # The !aggregate tag is meant for values, not keys. Rather than
            # rejecting the document, warn and move the tag onto the value,
            # then give the key back its implicitly resolved tag.
            if key_node.tag == "!aggregate":
                log.warning("!aggregate applies on values only, not on keys")
                value_node.tag = key_node.tag
                key_node.tag = self.resolve_sls_tag(key_node)[0]

            key = self.construct_object(key_node, deep=False)
            try:
                hash(key)
            except TypeError as exc:
                err = (
                    f"While constructing a mapping {node.start_mark} found "
                    f"unacceptable key {key_node.start_mark}"
                )
                raise ConstructorError(err) from exc
            value = self.construct_object(value_node, deep=False)
            if key in sls_map and not reset:
                value = merge_recursive(sls_map[key], value)
            sls_map[key] = value
        return sls_map

    def construct_sls_str(self, node):
        """
        Build the SLSString.
        """

        # Ensure obj is str, not py2 unicode or py3 bytes
        obj = self.construct_scalar(node)
        return SLSString(obj)

    def construct_sls_int(self, node):
        """
        Build a Python int from an integer scalar node.

        Leading zeros are stripped so that a zero-padded value is read as a
        decimal number. Values carrying a ``0b`` or ``0x`` prefix (optionally
        signed) are parsed in the matching base; ``int()`` in base 10 would
        reject them with ``ValueError``.

        :param node: scalar node whose ``value`` holds the integer text.
        :return: the parsed integer.
        """
        if node.value == "0":
            pass
        elif node.value.startswith("0") and not node.value.startswith(("0b", "0x")):
            node.value = node.value.lstrip("0")
            # If value was all zeros, node.value would have been reduced to
            # an empty string. Change it to '0'.
            if node.value == "":
                node.value = "0"
        if node.value.lstrip("+-").startswith(("0b", "0x")):
            # Base 0 lets int() honour the binary / hexadecimal prefix.
            return int(node.value, 0)
        return int(node.value)

    def construct_sls_aggregate(self, node):
        """
        Build the aggregated object for a node tagged ``!aggregate``.

        The node is constructed under the tag it would have had without
        ``!aggregate`` and the result is wrapped: mappings become
        :class:`AggregatedMap`, sequences and scalars become
        :class:`AggregatedSequence` (empty when the value is ``None``).

        :raises ConstructorError: if the underlying tag cannot be resolved.
        """
        try:
            tag, deep = self.resolve_sls_tag(node)
        except Exception as exc:  # pylint: disable=broad-except
            raise ConstructorError("unable to build aggregate") from exc

        # Work on a copy so the caller's node keeps its !aggregate tag.
        node = copy.copy(node)
        node.tag = tag
        obj = self.construct_object(node, deep)
        if obj is None:
            return AggregatedSequence()
        elif tag == self.DEFAULT_MAPPING_TAG:
            return AggregatedMap(obj)
        elif tag == self.DEFAULT_SEQUENCE_TAG:
            return AggregatedSequence(obj)
        return AggregatedSequence([obj])

    def construct_sls_reset(self, node):
        """
        Build the object for a node tagged ``!reset``.

        The node is constructed under the tag it would have had without
        ``!reset``; the reset itself is applied by ``construct_yaml_omap``.

        :raises ConstructorError: if the underlying tag cannot be resolved.
        """
        try:
            tag, deep = self.resolve_sls_tag(node)
        except Exception as exc:  # pylint: disable=broad-except
            raise ConstructorError("unable to build reset") from exc

        # Work on a copy so the caller's node keeps its !reset tag.
        node = copy.copy(node)
        node.tag = tag

        return self.construct_object(node, deep)

    def resolve_sls_tag(self, node):
        """
        Return the ``(tag, deep)`` pair to use when constructing ``node``
        without its custom tag.

        Scalars get their implicitly resolved tag and ``deep=False``;
        sequences and mappings get the loader's default sequence / mapping tag
        and ``deep=True``.

        :raises ConstructorError: if ``node`` is none of those node kinds.
        """
        if isinstance(node, yaml.nodes.ScalarNode):
            # search implicit tag
            tag = self.resolve(yaml.nodes.ScalarNode, node.value, [True, True])
            deep = False
        elif isinstance(node, yaml.nodes.SequenceNode):
            tag = self.DEFAULT_SEQUENCE_TAG
            deep = True
        elif isinstance(node, yaml.nodes.MappingNode):
            tag = self.DEFAULT_MAPPING_TAG
            deep = True
        else:
            raise ConstructorError("unable to resolve tag")
        return tag, deep
291
292
# Constructors bound to one exact tag: the two yamlex custom tags first, then
# the replacements for the stock omap / str / int constructors.
for _tag, _constructor in (
    ("!aggregate", Loader.construct_sls_aggregate),  # custom type
    ("!reset", Loader.construct_sls_reset),  # custom type
    ("tag:yaml.org,2002:omap", Loader.construct_yaml_omap),  # our overwrite
    ("tag:yaml.org,2002:str", Loader.construct_sls_str),  # our overwrite
    ("tag:yaml.org,2002:int", Loader.construct_sls_int),  # our overwrite
):
    Loader.add_constructor(_tag, _constructor)

# Constructors registered per tag prefix, in the same order as before.
for _tag_prefix, _constructor in (
    ("tag:yaml.org,2002:null", Loader.construct_yaml_null),
    ("tag:yaml.org,2002:bool", Loader.construct_yaml_bool),
    ("tag:yaml.org,2002:float", Loader.construct_yaml_float),
    ("tag:yaml.org,2002:binary", Loader.construct_yaml_binary),
    ("tag:yaml.org,2002:timestamp", Loader.construct_yaml_timestamp),
    ("tag:yaml.org,2002:pairs", Loader.construct_yaml_pairs),
    ("tag:yaml.org,2002:set", Loader.construct_yaml_set),
    ("tag:yaml.org,2002:seq", Loader.construct_yaml_seq),
    ("tag:yaml.org,2002:map", Loader.construct_yaml_map),
):
    Loader.add_multi_constructor(_tag_prefix, _constructor)

# Keep the loop variables out of the module namespace.
del _tag, _tag_prefix, _constructor
315
316
class SLSMap(collections.OrderedDict):
    """
    Ordered mapping whose ``str()`` and ``repr()`` are YAML friendly.

    .. code-block:: python

        from collections import OrderedDict

        mapping = OrderedDict([("a", "b"), ("c", None)])
        print(mapping)
        # OrderedDict([('a', 'b'), ('c', None)])

        sls_map = SLSMap(mapping)
        print(sls_map.__str__())
        # {a: b, c: null}

    """

    def __str__(self):
        # Flow style keeps the whole mapping on a single line.
        flow_yaml = serialize(self, default_flow_style=True)
        return flow_yaml

    def __repr__(self, _repr_running=None):
        # ``_repr_running`` is accepted but not used here.
        flow_yaml = serialize(self, default_flow_style=True)
        return flow_yaml
340
341
class SLSString(str):
    """
    String whose ``str()`` and ``repr()`` are YAML friendly.

    .. code-block:: python

        scalar = str("foo")
        print(scalar)
        # foo

        sls_scalar = SLSString(scalar)
        print(sls_scalar)
        # "foo"

    """

    def __str__(self):
        # Dump as a double-quoted YAML scalar.
        quoted_yaml = serialize(self, default_style='"')
        return quoted_yaml

    def __repr__(self):
        # Same rendering as ``__str__``.
        quoted_yaml = serialize(self, default_style='"')
        return quoted_yaml
363
364
class AggregatedMap(SLSMap, Map):
    """
    Mapping built for ``!aggregate`` mapping nodes and handed to
    ``aggregate`` as ``map_class`` by ``merge_recursive``.
    """
367
368
class AggregatedSequence(Sequence):
    """
    Sequence built for ``!aggregate`` sequence and scalar nodes and handed to
    ``aggregate`` as ``sequence_class`` by ``merge_recursive``.
    """
371
372
class Dumper(BaseDumper):  # pylint: disable=W0232
    """
    sls dumper, used by default in ``serialize``.
    """

    def represent_odict(self, data):
        """
        Represent a dict-like object as a plain ``map`` node.

        :param data: object exposing ``items()``.
        :return: the node produced by ``represent_mapping``.
        """
        # A list of pairs is passed instead of the mapping itself.
        # NOTE(review): presumably so PyYAML keeps the pairs in their
        # existing order -- confirm against the PyYAML representer.
        pairs = list(data.items())
        return self.represent_mapping("tag:yaml.org,2002:map", pairs)
380
381
# Representers registered per base type, in the same order as before.
for _data_type, _representer in (
    (type(None), Dumper.represent_none),
    (bytes, Dumper.represent_binary),
    (str, Dumper.represent_str),
    (bool, Dumper.represent_bool),
    (int, Dumper.represent_int),
    (float, Dumper.represent_float),
    (list, Dumper.represent_list),
    (tuple, Dumper.represent_list),
    # make every dict like obj to be represented as a map
    (dict, Dumper.represent_odict),
    (set, Dumper.represent_set),
    (datetime.date, Dumper.represent_date),
    (datetime.datetime, Dumper.represent_datetime),
    (None, Dumper.represent_undefined),
):
    Dumper.add_multi_representer(_data_type, _representer)

# Keep the loop variables out of the module namespace.
del _data_type, _representer
397
398
def merge_recursive(obj_a, obj_b, level: Union[bool, int] = False):
    """
    Merge obj_b into obj_a.

    Thin wrapper around ``aggregate`` that fixes the container classes to
    ``AggregatedMap`` and ``AggregatedSequence``.

    :param obj_a: the value decoded so far.
    :param obj_b: the value to merge into it.
    :param level: forwarded unchanged to ``aggregate``.
        NOTE(review): its exact meaning is defined by ``aggregate`` -- see
        that function for details.
    :return: whatever ``aggregate`` returns for the two objects.
    """
    return aggregate(
        obj_a, obj_b, level, map_class=AggregatedMap, sequence_class=AggregatedSequence
    )
406