1# This file is part of cloud-init. See LICENSE file for license information.
2
import collections
import copy
import glob
import os
import random
import re
import string

from cloudinit.tests import helpers

from cloudinit.handlers import cloud_config
from cloudinit.handlers import (CONTENT_START, CONTENT_END)

from cloudinit import helpers as c_helpers
from cloudinit import util
17
# Glob pattern for the merge source fixtures inside the 'merge_sources'
# resource directory (named like source1-1.yaml, source1-2.yaml, ...).
SOURCE_PAT = "source*.*yaml"
# Filename template for the expected merge result of one fixture group;
# formatted with that group's numeric id (e.g. expected1.yaml).
EXPECTED_PAT = "expected%s.yaml"
# Value types _make_dict picks from below depth 0; None means "emit None".
# NOTE: order matters -- rand.choice() indexes into this list, so reordering
# or extending it changes the data generated for any given seed.
TYPES = [dict, str, list, tuple, None, int]
21
22
def _old_mergedict(src, cand):
    """Legacy reference merge: fold C{cand} into C{src} without overriding.

    Nothing happens unless both arguments are dicts; in that case C{src}
    is returned untouched.  Keys missing from C{src} are taken from
    C{cand} by reference (not copied).  Keys present in both are merged
    recursively, so a non-dict value already in C{src} always wins.

    C{src} is modified in place and is also the return value.
    """
    if not (isinstance(src, dict) and isinstance(cand, dict)):
        return src
    for key, value in cand.items():
        src[key] = (_old_mergedict(src[key], value) if key in src
                    else value)
    return src
36
37
def _old_mergemanydict(*args):
    """Legacy reference counterpart of util.mergemanydict.

    Folds every argument, left to right, into a fresh result dict using
    L{_old_mergedict}.  Earlier arguments therefore win on conflicting
    keys, and non-dict arguments are ignored.

    Each argument is deep-copied before it is merged.  _old_mergedict works
    in place and stores C{cand} values by reference, so without the copy
    the nested dicts of earlier arguments would be placed inside the result
    and then mutated by later merges, silently corrupting the caller's data.
    """
    out = {}
    for arg in args:
        out = _old_mergedict(out, copy.deepcopy(arg))
    return out
43
44
def _random_str(rand):
    """Return a random alphanumeric string drawn from C{rand}.

    The length is picked with rand.randint(1, 2 ** 8), then each character
    is chosen from ASCII letters and digits, in that order of RNG draws.
    """
    alphabet = string.ascii_letters + string.digits
    length = rand.randint(1, 2 ** 8)
    return ''.join(rand.choice(alphabet) for _ in range(length))
50
51
class _NoMoreException(Exception):
    """Raised by _make_dict once the maximum nesting depth is reached.

    Container-building callers catch it and leave out the child value
    that would have been generated at that depth.
    """
54
55
def _make_dict(current_depth, max_depth, rand):
    """Generate one random value, possibly a nested container, for fuzzing.

    At depth 0 the value is always a dict; deeper down its type is picked
    with rand.choice(TYPES).  Dicts, lists and tuples hold 0-5 children
    generated one level deeper, and a child whose generation raises
    _NoMoreException is simply left out.  Dict keys and str values come
    from _random_str; ints come from rand.randint(0, 2 ** 8).

    Raises _NoMoreException when current_depth has reached max_depth.
    NOTE: the RNG is drawn from in a fixed order (type, count, all keys,
    then children); keep it so a given seed keeps producing the same data.
    """
    if current_depth >= max_depth:
        raise _NoMoreException()
    kind = dict if current_depth == 0 else rand.choice(TYPES)
    if kind is None:
        return None
    if kind is int:
        return rand.randint(0, 2 ** 8)
    if kind is str:
        return _random_str(rand)
    if kind not in (dict, list, tuple):
        # Any other entry in TYPES produces nothing.
        return None
    child_depth = current_depth + 1
    if kind is dict:
        count = rand.randint(0, 5)
        # Every key is drawn before any value is generated.
        names = [_random_str(rand) for _ in range(count)]
        result = {}
        for name in names:
            try:
                result[name] = _make_dict(child_depth, max_depth, rand)
            except _NoMoreException:
                continue
        return result
    items = []
    for _ in range(rand.randint(0, 5)):
        try:
            items.append(_make_dict(child_depth, max_depth, rand))
        except _NoMoreException:
            continue
    return tuple(items) if kind is tuple else items
91
92
def make_dict(max_depth, seed=None):
    """Build a random top-level dict for merge fuzzing.

    C{max_depth} is clamped to at least 1, so the top-level call never
    raises _NoMoreException.  Generation is driven by a fresh
    random.Random(seed), so for a non-None seed the same
    (max_depth, seed) pair always yields the same dict.
    """
    return _make_dict(0, max(1, max_depth), random.Random(seed))
97
98
class TestSimpleRun(helpers.ResourceUsingTestCase):
    """Compare util.mergemanydict and the cloud-config part handler against
    the legacy merge helpers and the 'merge_sources' fixture files.

    The legacy helpers (_old_mergedict and _old_mergemanydict) can modify
    their arguments in place, so every comparison gives them deep copies.
    Without the copies, util.mergemanydict would afterwards run on inputs
    that already contain the merged result, and the equality checks would
    pass trivially.
    """

    def _load_merge_files(self):
        """Collect the merge fixtures from the 'merge_sources' resources.

        Source files are named source<id>-<n>.yaml and grouped by <id>.
        Every group needs a matching expected<id>.yaml.

        :returns: one [[[fn, contents], ...], [expected, expected_fn]]
            entry per group, in ascending <id> order.  The sources of a
            group are sorted by path, and ``expected`` is already
            YAML-loaded.
        :raises IOError: if a source name does not follow that scheme or
            its expected file is missing.
        """
        merge_root = helpers.resourceLocation('merge_sources')
        tests = []
        source_ids = collections.defaultdict(list)
        expected_files = {}
        for fn in glob.glob(os.path.join(merge_root, SOURCE_PAT)):
            base_fn = os.path.basename(fn)
            # Anchor the end: the glob only guarantees a "yaml" suffix, so a
            # name like "source1-1.yaml.bak.yaml" must be rejected rather
            # than accepted through a prefix match.
            file_id = re.match(r"source(\d+)\-(\d+)[.]yaml$", base_fn)
            if not file_id:
                raise IOError("File %s does not have a numeric identifier"
                              % (fn))
            file_id = int(file_id.group(1))
            source_ids[file_id].append(fn)
            expected_fn = os.path.join(merge_root, EXPECTED_PAT % (file_id))
            if not os.path.isfile(expected_fn):
                raise IOError("No expected file found at %s" % (expected_fn))
            expected_files[file_id] = expected_fn
        for i in sorted(source_ids.keys()):
            source_file_contents = []
            for fn in sorted(source_ids[i]):
                source_file_contents.append([fn, util.load_file(fn)])
            expected = util.load_yaml(util.load_file(expected_files[i]))
            entry = [source_file_contents, [expected, expected_files[i]]]
            tests.append(entry)
        return tests

    def test_seed_runs(self):
        """Seeded random dicts merge identically with old and new code."""
        test_dicts = []
        for i in range(1, 10):
            base_dicts = []
            for j in range(1, 10):
                base_dicts.append(make_dict(5, i * j))
            test_dicts.append(base_dicts)
        for test in test_dicts:
            # Deep copy: the legacy merge may mutate nested input dicts,
            # which would leak its result into util.mergemanydict's input.
            c = _old_mergemanydict(*copy.deepcopy(test))
            d = util.mergemanydict(test)
            self.assertEqual(c, d)

    def test_merge_cc_samples(self):
        """Each fixture group fed through the cloud-config handler must
        produce its expected<id>.yaml result."""
        tests = self._load_merge_files()
        paths = c_helpers.Paths({})
        cc_handler = cloud_config.CloudConfigPartHandler(paths)
        # NOTE(review): presumably unset so that CONTENT_END does not write
        # the merged config anywhere -- confirm against the handler.
        cc_handler.cloud_fn = None
        for (payloads, (expected_merge, expected_fn)) in tests:
            cc_handler.handle_part(None, CONTENT_START, None,
                                   None, None, None)
            merging_fns = []
            for (fn, contents) in payloads:
                cc_handler.handle_part(None, None, "%s.yaml" % (fn),
                                       contents, None, {})
                merging_fns.append(fn)
            # Snapshot the buffer before CONTENT_END, which may reset it.
            merged_buf = cc_handler.cloud_buf
            cc_handler.handle_part(None, CONTENT_END, None,
                                   None, None, None)
            fail_msg = "Equality failure on checking %s with %s: %s != %s"
            fail_msg = fail_msg % (expected_fn,
                                   ",".join(merging_fns), merged_buf,
                                   expected_merge)
            self.assertEqual(expected_merge, merged_buf, msg=fail_msg)

    def test_compat_merges_dict(self):
        """Top-level key conflicts resolve like the legacy merge."""
        a = {
            '1': '2',
            'b': 'c',
        }
        b = {
            'b': 'e',
        }
        # The legacy merge works in place; keep ``a`` pristine for util.
        c = _old_mergedict(copy.deepcopy(a), b)
        d = util.mergemanydict([a, b])
        self.assertEqual(c, d)

    def test_compat_merges_dict2(self):
        """Differently typed values for the same key resolve like legacy."""
        a = {
            'Blah': 1,
            'Blah2': 2,
            'Blah3': 3,
        }
        b = {
            'Blah': 1,
            'Blah2': 2,
            'Blah3': [1],
        }
        c = _old_mergedict(copy.deepcopy(a), b)
        d = util.mergemanydict([a, b])
        self.assertEqual(c, d)

    def test_compat_merges_list(self):
        """List values for one key across three dicts resolve like legacy."""
        a = {'b': [1, 2, 3]}
        b = {'b': [4, 5]}
        c = {'b': [6, 7]}
        e = _old_mergemanydict(*copy.deepcopy([a, b, c]))
        f = util.mergemanydict([a, b, c])
        self.assertEqual(e, f)

    def test_compat_merges_str(self):
        """String values for one key across three dicts resolve like legacy."""
        a = {'b': "hi"}
        b = {'b': "howdy"}
        c = {'b': "hallo"}
        e = _old_mergemanydict(*copy.deepcopy([a, b, c]))
        f = util.mergemanydict([a, b, c])
        self.assertEqual(e, f)

    def test_compat_merge_sub_dict(self):
        """Nested dicts (two levels deep) merge like the legacy merge."""
        a = {
            '1': '2',
            'b': {
                'f': 'g',
                'e': 'c',
                'h': 'd',
                'hh': {
                    '1': 2,
                },
            }
        }
        b = {
            'b': {
                'e': 'c',
                'hh': {
                    '3': 4,
                }
            }
        }
        # Without the copy the legacy merge would add '3' to a['b']['hh'].
        c = _old_mergedict(copy.deepcopy(a), b)
        d = util.mergemanydict([a, b])
        self.assertEqual(c, d)

    def test_compat_merge_sub_dict2(self):
        """A nested dict gains keys from the other side like legacy."""
        a = {
            '1': '2',
            'b': {
                'f': 'g',
            }
        }
        b = {
            'b': {
                'e': 'c',
            }
        }
        c = _old_mergedict(copy.deepcopy(a), b)
        d = util.mergemanydict([a, b])
        self.assertEqual(c, d)

    def test_compat_merge_sub_list(self):
        """A nested list is kept over an empty one, like the legacy merge."""
        a = {
            '1': '2',
            'b': {
                'f': ['1'],
            }
        }
        b = {
            'b': {
                'f': [],
            }
        }
        c = _old_mergedict(copy.deepcopy(a), b)
        d = util.mergemanydict([a, b])
        self.assertEqual(c, d)
258
259# vi: ts=4 expandtab
260