1from __future__ import (absolute_import, division, print_function)
2__metaclass__ = type
3
4import io
5import yaml
6
7from ansible.module_utils.six import PY3
8from ansible.parsing.yaml.loader import AnsibleLoader
9from ansible.parsing.yaml.dumper import AnsibleDumper
10
11
class YamlTestUtils(object):
    """Mixin class to combine with a unittest.TestCase subclass.

    ``_dump_load_cycle`` relies on ``self.assertEqual`` provided by TestCase.
    """

    def _loader(self, stream):
        """Return the loader used to read yaml back from ``stream``.

        Vault related tests will want to override this. Vault cases should
        set up an AnsibleLoader that has the vault password.
        """
        return AnsibleLoader(stream)

    def _yaml_dump(self, obj, stream=None, dumper=None):
        """Run ``yaml.dump`` so that the output is text (py2 unicode / py3 str).

        :arg obj: the python object to serialize.
        :arg stream: text stream to write to. When None the yaml is returned
            as a string instead (plain ``yaml.dump`` semantics).
        :arg dumper: Dumper class to use. None selects PyYAML's default
            ``yaml.Dumper``; None is not forwarded as-is because ``yaml.dump``
            instantiates whatever it receives as ``Dumper``.
        :returns: the yaml text when ``stream`` is None, otherwise None.
        """
        kwargs = {'Dumper': yaml.Dumper if dumper is None else dumper}
        if not PY3:
            # py2 PyYAML defaults to encoding='utf-8' (byte str); ask for unicode.
            kwargs['encoding'] = None
        return yaml.dump(obj, stream, **kwargs)

    def _dump_stream(self, obj, stream, dumper=None):
        """Dump to a py2-unicode or py3-string stream."""
        return self._yaml_dump(obj, stream, dumper=dumper)

    def _dump_string(self, obj, dumper=None):
        """Dump to a py2-unicode or py3-string and return it."""
        return self._yaml_dump(obj, dumper=dumper)

    def _dump_load_cycle(self, obj):
        """Round-trip ``obj`` through AnsibleDumper and ``self._loader`` three times.

        Each pass through a dump or load revs the 'generation'. Asserts that
        the yaml text and the loaded python objects stay identical across
        generations 1, 2 and 3.
        """
        # obj to yaml string (gen 1)
        string_from_object_dump = self._dump_string(obj, dumper=AnsibleDumper)

        # wrap a stream/file like StringIO around that yaml
        stream_from_object_dump = io.StringIO(string_from_object_dump)
        loader = self._loader(stream_from_object_dump)
        # load the yaml stream to create a new instance of the object (gen 2)
        obj_2 = loader.get_data()

        # dump the gen 2 objects directly to strings
        string_from_object_dump_2 = self._dump_string(obj_2,
                                                      dumper=AnsibleDumper)

        # The gen 1 and gen 2 yaml strings
        self.assertEqual(string_from_object_dump, string_from_object_dump_2)
        # the gen 1 (orig) and gen 2 py object
        self.assertEqual(obj, obj_2)

        # again! gen 3... load strings into py objects
        stream_3 = io.StringIO(string_from_object_dump_2)
        loader_3 = self._loader(stream_3)
        obj_3 = loader_3.get_data()

        string_from_object_dump_3 = self._dump_string(obj_3, dumper=AnsibleDumper)

        self.assertEqual(obj, obj_3)
        # should be transitive, but custom __eq__ implementations may not be
        self.assertEqual(obj_2, obj_3)
        self.assertEqual(string_from_object_dump, string_from_object_dump_3)

    def _old_dump_load_cycle(self, obj):
        '''Dump the passed in object to yaml, load it back up, dump again, compare.

        The object is dumped both to a string and to a stream. Each result is
        loaded back with ``self._loader``, and each reloaded object is dumped
        again to a string and to a stream. All yaml texts must be identical
        and both reloaded objects must equal ``obj``.

        :returns: dict of intermediate values (keys ``obj``, ``yaml_string``,
            ``yaml_string_from_stream``, ``obj_from_stream``,
            ``obj_from_string``, ``yaml_string_obj_from_string``) for callers
            that want to make further assertions.
        '''
        # gen 1: the same object dumped to a string and to a stream
        yaml_string = self._dump_string(obj, dumper=AnsibleDumper)
        stream = io.StringIO()
        self._dump_stream(obj, stream, dumper=AnsibleDumper)
        yaml_string_from_stream = stream.getvalue()

        # gen 2: load each gen 1 text back into python objects
        stream.seek(0)  # rewind so the loader reads what was just written
        obj_from_stream = self._loader(stream).get_data()
        obj_from_string = self._loader(io.StringIO(yaml_string)).get_data()

        # re-dump each gen 2 object to its own stream ...
        stream_obj_from_stream = io.StringIO()
        stream_obj_from_string = io.StringIO()
        self._dump_stream(obj_from_stream, stream_obj_from_stream, dumper=AnsibleDumper)
        self._dump_stream(obj_from_string, stream_obj_from_string, dumper=AnsibleDumper)
        yaml_string_stream_obj_from_stream = stream_obj_from_stream.getvalue()
        yaml_string_stream_obj_from_string = stream_obj_from_string.getvalue()

        # ... and to a string
        yaml_string_obj_from_stream = self._dump_string(obj_from_stream, dumper=AnsibleDumper)
        yaml_string_obj_from_string = self._dump_string(obj_from_string, dumper=AnsibleDumper)

        # every dump path of every generation must produce the same yaml text
        assert yaml_string == yaml_string_from_stream
        assert yaml_string == yaml_string_obj_from_stream
        assert yaml_string == yaml_string_obj_from_string
        assert yaml_string == yaml_string_stream_obj_from_stream
        assert yaml_string == yaml_string_stream_obj_from_string
        # objects are compared with objects only; comparing obj with its own
        # yaml text is not a meaningful round-trip check
        assert obj == obj_from_stream
        assert obj == obj_from_string
        assert obj_from_stream == obj_from_string
        return {'obj': obj,
                'yaml_string': yaml_string,
                'yaml_string_from_stream': yaml_string_from_stream,
                'obj_from_stream': obj_from_stream,
                'obj_from_string': obj_from_string,
                'yaml_string_obj_from_string': yaml_string_obj_from_string}
125