1# This file is part of Ansible
2# -*- coding: utf-8 -*-
3#
4#
5# Ansible is free software: you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation, either version 3 of the License, or
8# (at your option) any later version.
9#
10# Ansible is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19# Make coding more python3-ish
20from __future__ import (absolute_import, division, print_function)
21__metaclass__ = type
22
23from collections import defaultdict
24import pprint
25
26# for testing
27from units.compat import unittest
28
29from ansible.module_utils.facts import collector
30
31from ansible.module_utils.facts import default_collectors
32
33
class TestFindCollectorsForPlatform(unittest.TestCase):
    """Tests for collector.find_collectors_for_platform()."""

    def _assert_platforms(self, compat_platforms, allowed_platforms):
        """Assert every collector found for compat_platforms has an allowed _platform.

        allowed_platforms must be a real tuple: a parenthesized bare string such
        as ('Generic') is just a str, and assertIn would then perform a substring
        check rather than a membership check.
        """
        res = collector.find_collectors_for_platform(default_collectors.collectors,
                                                     compat_platforms)
        # Guard against a vacuous pass when no collector matches at all.
        self.assertTrue(res, 'no collectors found for %s' % (compat_platforms,))
        for coll_class in res:
            self.assertIn(coll_class._platform, allowed_platforms)

    def test(self):
        self._assert_platforms([{'system': 'Generic'}], ('Generic',))

    def test_linux(self):
        self._assert_platforms([{'system': 'Linux'}], ('Linux',))

    def test_linux_or_generic(self):
        self._assert_platforms([{'system': 'Generic'}, {'system': 'Linux'}],
                               ('Generic', 'Linux'))
55
56
class TestSelectCollectorNames(unittest.TestCase):
    """Tests for collector.select_collector_classes() and collector ordering."""

    def _assert_equal_detail(self, obj1, obj2, msg=None):
        """assertEqual whose failure message pretty-prints both objects.

        If msg is given, it is prepended to the pretty-printed detail instead of
        being silently discarded.
        """
        detail = 'objects are not equal\n%s\n\n!=\n\n%s' % (pprint.pformat(obj1), pprint.pformat(obj2))
        if msg:
            detail = '%s\n%s' % (msg, detail)
        return self.assertEqual(obj1, obj2, detail)

    def test(self):
        # 'all_ipv4_addresses' and 'local' have no entry in the fact subsets
        # built by _all_fact_subsets(), so only the distribution and pkg_mgr
        # collectors are expected, in the order their names were requested.
        collector_names = ['distribution', 'all_ipv4_addresses',
                           'local', 'pkg_mgr']
        all_fact_subsets = self._all_fact_subsets()
        res = collector.select_collector_classes(collector_names,
                                                 all_fact_subsets)

        expected = [default_collectors.DistributionFactCollector,
                    default_collectors.PkgMgrFactCollector]

        self._assert_equal_detail(res, expected)

    def test_default_collectors(self):
        # Run the full name -> deps -> tsort -> classes pipeline for the
        # Generic platform and check the resulting collector order.
        platform_info = {'system': 'Generic'}
        compat_platforms = [platform_info]
        collectors_for_platform = collector.find_collectors_for_platform(default_collectors.collectors,
                                                                         compat_platforms)

        all_fact_subsets, aliases_map = collector.build_fact_id_to_collector_map(collectors_for_platform)

        all_valid_subsets = frozenset(all_fact_subsets.keys())
        collector_names = collector.get_collector_names(valid_subsets=all_valid_subsets,
                                                        aliases_map=aliases_map,
                                                        platform_info=platform_info)
        complete_collector_names = collector._solve_deps(collector_names, all_fact_subsets)

        dep_map = collector.build_dep_data(complete_collector_names, all_fact_subsets)

        ordered_deps = collector.tsort(dep_map)
        ordered_collector_names = [x[0] for x in ordered_deps]

        res = collector.select_collector_classes(ordered_collector_names,
                                                 all_fact_subsets)

        # ServiceMgrFactCollector must come after the distribution and
        # platform collectors.
        service_mgr_index = res.index(default_collectors.ServiceMgrFactCollector)
        self.assertGreater(service_mgr_index,
                           res.index(default_collectors.DistributionFactCollector),
                           res)
        self.assertGreater(service_mgr_index,
                           res.index(default_collectors.PlatformFactCollector),
                           res)

    def _all_fact_subsets(self, data=None):
        """Return a defaultdict(list) mapping fact subset names to collector classes.

        A small built-in mapping is used when data is None; an explicitly
        passed (even empty) mapping is used as-is.
        """
        _data = {'pkg_mgr': [default_collectors.PkgMgrFactCollector],
                 'distribution': [default_collectors.DistributionFactCollector],
                 'network': [default_collectors.LinuxNetworkCollector]}
        if data is None:
            data = _data
        all_fact_subsets = defaultdict(list)
        all_fact_subsets.update(data)
        return all_fact_subsets
113
114
class TestGetCollectorNames(unittest.TestCase):
    """Tests for collector.get_collector_names().

    get_collector_names() turns a gather_subset spec (subset names, 'all', and
    '!'-prefixed exclusions) into a set of subset names, drawing on
    valid_subsets and minimal_gather_subset.
    """

    def test_none(self):
        # With no arguments at all, nothing is selected.
        res = collector.get_collector_names()
        self.assertIsInstance(res, set)
        self.assertEqual(res, set([]))

    def test_empty_sets(self):
        res = collector.get_collector_names(valid_subsets=frozenset([]),
                                            minimal_gather_subset=frozenset([]),
                                            gather_subset=[])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set([]))

    def test_empty_valid_and_min_with_all_gather_subset(self):
        # 'all' has nothing to expand to when valid_subsets is empty.
        res = collector.get_collector_names(valid_subsets=frozenset([]),
                                            minimal_gather_subset=frozenset([]),
                                            gather_subset=['all'])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set([]))

    def test_one_valid_with_all_gather_subset(self):
        valid_subsets = frozenset(['my_fact'])
        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=frozenset([]),
                                            gather_subset=['all'])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set(['my_fact']))

    def _compare_res(self, gather_subset1, gather_subset2,
                     valid_subsets=None, min_subset=None):
        """Return the get_collector_names() results for two gather_subset specs
        evaluated against the same valid and minimal subsets."""

        valid_subsets = valid_subsets or frozenset()
        minimal_gather_subset = min_subset or frozenset()

        res1 = collector.get_collector_names(valid_subsets=valid_subsets,
                                             minimal_gather_subset=minimal_gather_subset,
                                             gather_subset=gather_subset1)

        res2 = collector.get_collector_names(valid_subsets=valid_subsets,
                                             minimal_gather_subset=minimal_gather_subset,
                                             gather_subset=gather_subset2)

        return res1, res2

    def test_not_all_other_order(self):
        # The order of '!all' and an explicit include does not change the result.
        valid_subsets = frozenset(['min_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['min_fact'])

        res1, res2 = self._compare_res(['!all', 'whatever'],
                                       ['whatever', '!all'],
                                       valid_subsets=valid_subsets,
                                       min_subset=minimal_gather_subset)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, set(['min_fact', 'whatever']))

    def test_not_all_other_order_min(self):
        # Excluding a minimal subset by name removes it, in either order.
        valid_subsets = frozenset(['min_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['min_fact'])

        res1, res2 = self._compare_res(['!min_fact', 'whatever'],
                                       ['whatever', '!min_fact'],
                                       valid_subsets=valid_subsets,
                                       min_subset=minimal_gather_subset)
        self.assertEqual(res1, res2)
        self.assertEqual(res1, set(['whatever']))

    def test_one_minimal_with_all_gather_subset(self):
        my_fact = 'my_fact'
        valid_subsets = frozenset([my_fact])
        minimal_gather_subset = valid_subsets

        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['all'])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set(['my_fact']))

    def test_with_all_gather_subset(self):
        valid_subsets = frozenset(['my_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['my_fact'])

        # 'all' selects every valid subset, the minimal one included.
        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['all'])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set(['my_fact', 'something_else', 'whatever']))

    def test_one_minimal_with_not_all_gather_subset(self):
        valid_subsets = frozenset(['my_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['my_fact'])

        # even with '!all', the minimal_gather_subset should be returned
        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['!all'])
        self.assertIsInstance(res, set)
        self.assertEqual(res, set(['my_fact']))

    def test_gather_subset_excludes(self):
        valid_subsets = frozenset(['my_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['min_fact', 'min_another'])

        # gather_subset holds only exclusions, so none of valid_subsets is
        # selected; '!min_fact' removes one of the minimal subsets and
        # '!whatever' has nothing to remove.
        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['!min_fact', '!whatever'])
        self.assertIsInstance(res, set)
        # min_another is the only minimal subset that was not excluded
        self.assertEqual(res, set(['min_another']))

    def test_gather_subset_excludes_ordering(self):
        valid_subsets = frozenset(['my_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['my_fact'])

        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['!all', 'whatever'])
        self.assertIsInstance(res, set)
        # '!all' drops every non-minimal subset, but the explicitly requested
        # 'whatever' is still included, so the result is the minimal subset
        # plus 'whatever' (i.e. NOT the same as plain '!all').
        self.assertEqual(res, set(['my_fact', 'whatever']))

    def test_gather_subset_excludes_min(self):
        valid_subsets = frozenset(['min_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['min_fact'])

        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['whatever', '!min'])
        self.assertIsInstance(res, set)
        # '!min' removes the minimal_gather_subset, leaving only the
        # explicitly requested 'whatever'.
        self.assertEqual(res, set(['whatever']))

    def test_gather_subset_excludes_min_and_all(self):
        valid_subsets = frozenset(['min_fact', 'something_else', 'whatever'])
        minimal_gather_subset = frozenset(['min_fact'])

        res = collector.get_collector_names(valid_subsets=valid_subsets,
                                            minimal_gather_subset=minimal_gather_subset,
                                            gather_subset=['whatever', '!all', '!min'])
        self.assertIsInstance(res, set)
        # '!all' drops the non-minimal subsets and '!min' drops the minimal
        # one; only the explicitly requested 'whatever' remains.
        self.assertEqual(res, set(['whatever']))

    def test_invalid_gather_subset(self):
        # Requesting a subset that is not in valid_subsets raises TypeError
        # whose message lists the allowed subsets.
        valid_subsets = frozenset(['my_fact', 'something_else'])
        minimal_gather_subset = frozenset(['my_fact'])

        self.assertRaisesRegexp(TypeError,
                                r'Bad subset .* given to Ansible.*allowed\:.*all,.*my_fact.*',
                                collector.get_collector_names,
                                valid_subsets=valid_subsets,
                                minimal_gather_subset=minimal_gather_subset,
                                gather_subset=['my_fact', 'not_a_valid_gather_subset'])
274
275
class TestFindUnresolvedRequires(unittest.TestCase):
    """Tests for collector.find_unresolved_requires()."""

    def test(self):
        # 'platform' and 'distribution' are required by the requested
        # collectors but are not among the requested names themselves.
        fact_subsets = dict(env=[default_collectors.EnvFactCollector],
                            network=[default_collectors.LinuxNetworkCollector],
                            virtual=[default_collectors.LinuxVirtualCollector])
        missing = collector.find_unresolved_requires(['network', 'virtual', 'env'],
                                                     fact_subsets)

        self.assertIsInstance(missing, set)
        self.assertEqual(missing, set(['platform', 'distribution']))

    def test_resolved(self):
        # Once the required names are requested too, nothing is left unresolved.
        fact_subsets = dict(env=[default_collectors.EnvFactCollector],
                            network=[default_collectors.LinuxNetworkCollector],
                            distribution=[default_collectors.DistributionFactCollector],
                            platform=[default_collectors.PlatformFactCollector],
                            virtual=[default_collectors.LinuxVirtualCollector])
        requested = ['network', 'virtual', 'env', 'platform', 'distribution']
        missing = collector.find_unresolved_requires(requested, fact_subsets)

        self.assertIsInstance(missing, set)
        self.assertEqual(missing, set())
300
301
class TestBuildDepData(unittest.TestCase):
    """Tests for collector.build_dep_data()."""

    def test(self):
        requested = ['network', 'virtual', 'env']
        fact_subsets = dict(env=[default_collectors.EnvFactCollector],
                            network=[default_collectors.LinuxNetworkCollector],
                            virtual=[default_collectors.LinuxVirtualCollector])
        dep_data = collector.build_dep_data(requested, fact_subsets)

        # Each requested name maps to the set of fact names its collector needs.
        expected = {'network': set(['platform', 'distribution']),
                    'virtual': set(),
                    'env': set()}
        self.assertIsInstance(dep_data, defaultdict)
        self.assertEqual(dict(dep_data), expected)
316
317
class TestSolveDeps(unittest.TestCase):
    """Tests for collector._solve_deps()."""

    def test_no_solution(self):
        # None of the requested names has an entry in all_fact_subsets.
        unresolved = set(['required_thing1', 'required_thing2'])
        all_fact_subsets = {'env': [default_collectors.EnvFactCollector],
                            'network': [default_collectors.LinuxNetworkCollector],
                            'virtual': [default_collectors.LinuxVirtualCollector]}

        self.assertRaises(collector.CollectorNotFoundError,
                          collector._solve_deps,
                          unresolved,
                          all_fact_subsets)

    def test(self):
        unresolved = set(['env', 'network'])
        all_fact_subsets = {'env': [default_collectors.EnvFactCollector],
                            'network': [default_collectors.LinuxNetworkCollector],
                            'virtual': [default_collectors.LinuxVirtualCollector],
                            'platform': [default_collectors.PlatformFactCollector],
                            'distribution': [default_collectors.DistributionFactCollector]}

        # Every requested name has a collector, so resolve_requires() must
        # resolve all of them (its result was previously computed but never checked).
        resolved = collector.resolve_requires(unresolved, all_fact_subsets)
        for goal in unresolved:
            self.assertIn(goal, resolved)

        res = collector._solve_deps(unresolved, all_fact_subsets)

        self.assertIsInstance(res, set)
        for goal in unresolved:
            self.assertIn(goal, res)
        # The network collector's requirements are pulled into the solution too.
        for required in ('platform', 'distribution'):
            self.assertIn(required, res)
344
345
class TestResolveRequires(unittest.TestCase):
    """Tests for collector.resolve_requires()."""

    def test_no_resolution(self):
        unresolved = ['required_thing1', 'required_thing2']
        all_fact_subsets = {'env': [default_collectors.EnvFactCollector],
                            'network': [default_collectors.LinuxNetworkCollector],
                            'virtual': [default_collectors.LinuxVirtualCollector]}
        self.assertRaisesRegexp(collector.UnresolvedFactDep,
                                'unresolved fact dep.*required_thing2',
                                collector.resolve_requires,
                                unresolved, all_fact_subsets)

    def test(self):
        unresolved = ['env', 'network']
        all_fact_subsets = {'env': [default_collectors.EnvFactCollector],
                            'network': [default_collectors.LinuxNetworkCollector],
                            'virtual': [default_collectors.LinuxVirtualCollector]}
        res = collector.resolve_requires(unresolved, all_fact_subsets)
        for goal in unresolved:
            self.assertIn(goal, res)

    def test_exception(self):
        # The exception message must name the unresolved dependency. Using
        # assertRaises as a context manager makes the test fail when nothing
        # is raised, instead of passing silently.
        unresolved = ['required_thing1']
        all_fact_subsets = {}
        with self.assertRaises(collector.UnresolvedFactDep) as ctx:
            collector.resolve_requires(unresolved, all_fact_subsets)
        self.assertIn(unresolved[0], '%s' % ctx.exception)
373
374
class TestTsort(unittest.TestCase):
    """Tests for collector.tsort().

    tsort() takes a map of node name -> set of names it depends on and
    returns a list whose items carry the node name at index 0. A node must
    appear after every dependency that is itself a key in the map.
    """

    def _sorted_names(self, dep_map):
        """Run tsort(dep_map), check that it returns a list, and return the node names in order."""
        res = collector.tsort(dep_map)
        self.assertIsInstance(res, list)
        return [x[0] for x in res]

    def _assert_deps_ordered(self, names, dep_map):
        """Assert every node comes after each of its dependencies present in dep_map."""
        for node, deps in dep_map.items():
            for dep in deps:
                # Dependencies that are not keys in dep_map never appear in the output.
                if dep in dep_map:
                    self.assertGreater(names.index(node), names.index(dep),
                                       '%s should sort after %s: %s' % (node, dep, names))

    def test(self):
        dep_map = {'network': set(['distribution', 'platform']),
                   'virtual': set(),
                   'platform': set(['what_platform_wants']),
                   'what_platform_wants': set(),
                   'network_stuff': set(['network'])}

        names = self._sorted_names(dep_map)
        self.assertGreater(names.index('network_stuff'), names.index('network'))
        self.assertGreater(names.index('platform'), names.index('what_platform_wants'))
        self.assertGreater(names.index('network'), names.index('platform'))

    def test_cycles(self):
        dep_map = {'leaf1': set(),
                   'leaf2': set(),
                   'node1': set(['node2']),
                   'node2': set(['node3']),
                   'node3': set(['node1'])}

        self.assertRaises(collector.CycleFoundInFactDeps,
                          collector.tsort,
                          dep_map)

    def test_just_nodes(self):
        dep_map = {'leaf1': set(),
                   'leaf4': set(),
                   'leaf3': set(),
                   'leaf2': set()}

        names = self._sorted_names(dep_map)
        # not a lot to assert here, any order of the
        # results is valid
        self.assertEqual(set(names), set(dep_map.keys()))

    def test_self_deps(self):
        dep_map = {'node1': set(['node1']),
                   'node2': set(['node2'])}
        self.assertRaises(collector.CycleFoundInFactDeps,
                          collector.tsort,
                          dep_map)

    def test_unsolvable(self):
        # 'leaf2' is not a key in dep_map; node2 is still sorted and only the
        # map's own keys are returned.
        dep_map = {'leaf1': set(),
                   'node2': set(['leaf2'])}

        names = self._sorted_names(dep_map)
        self.assertEqual(set(names), set(dep_map.keys()))

    def test_chain(self):
        dep_map = {'leaf1': set(['leaf2']),
                   'leaf2': set(['leaf3']),
                   'leaf3': set(['leaf4']),
                   'leaf4': set(),
                   'leaf5': set(['leaf1'])}
        names = self._sorted_names(dep_map)
        self.assertEqual(set(names), set(dep_map.keys()))
        # A chain admits exactly one valid order; check it, not just membership.
        self._assert_deps_ordered(names, dep_map)

    def test_multi_pass(self):
        dep_map = {'leaf1': set(),
                   'leaf2': set(['leaf3', 'leaf1', 'leaf4', 'leaf5']),
                   'leaf3': set(['leaf4', 'leaf1']),
                   'leaf4': set(['leaf1']),
                   'leaf5': set(['leaf1'])}
        names = self._sorted_names(dep_map)
        self.assertEqual(set(names), set(dep_map.keys()))
        # leaf1 has no deps and everything depends on it, so it must come first.
        for leaf in ('leaf2', 'leaf3', 'leaf4', 'leaf5'):
            self.assertLess(names.index('leaf1'), names.index(leaf))
        self._assert_deps_ordered(names, dep_map)
456
457
class TestCollectorClassesFromGatherSubset(unittest.TestCase):
    """Tests for collector.collector_classes_from_gather_subset()."""
    maxDiff = None

    def _classes(self,
                 all_collector_classes=None,
                 valid_subsets=None,
                 minimal_gather_subset=None,
                 gather_subset=None,
                 gather_timeout=None,
                 platform_info=None):
        """Call collector_classes_from_gather_subset(), defaulting platform_info to Linux."""
        platform_info = platform_info or {'system': 'Linux'}
        return collector.collector_classes_from_gather_subset(all_collector_classes=all_collector_classes,
                                                              valid_subsets=valid_subsets,
                                                              minimal_gather_subset=minimal_gather_subset,
                                                              gather_subset=gather_subset,
                                                              gather_timeout=gather_timeout,
                                                              platform_info=platform_info)

    def test_no_args(self):
        res = self._classes()
        self.assertIsInstance(res, list)
        self.assertEqual(res, [])

    def test_not_all(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['!all'])
        self.assertIsInstance(res, list)
        self.assertEqual(res, [])

    def test_all(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['all'])
        self.assertIsInstance(res, list)
        # 'all' must actually select collectors, the generic platform one included.
        self.assertIn(default_collectors.PlatformFactCollector, res)

    def test_hardware(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['hardware'])
        self.assertIsInstance(res, list)
        self.assertIn(default_collectors.PlatformFactCollector, res)
        self.assertIn(default_collectors.LinuxHardwareCollector, res)

        self.assertGreater(res.index(default_collectors.LinuxHardwareCollector),
                           res.index(default_collectors.PlatformFactCollector))

    def test_network(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['network'])
        self.assertIsInstance(res, list)
        self.assertIn(default_collectors.DistributionFactCollector, res)
        self.assertIn(default_collectors.PlatformFactCollector, res)
        self.assertIn(default_collectors.LinuxNetworkCollector, res)

        # The network collector must be ordered after the platform and
        # distribution collectors.
        self.assertGreater(res.index(default_collectors.LinuxNetworkCollector),
                           res.index(default_collectors.PlatformFactCollector))
        self.assertGreater(res.index(default_collectors.LinuxNetworkCollector),
                           res.index(default_collectors.DistributionFactCollector))

    def test_env(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['env'])
        self.assertIsInstance(res, list)
        self.assertEqual(res, [default_collectors.EnvFactCollector])

    def test_facter(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=set(['env', 'facter']))
        self.assertIsInstance(res, list)
        self.assertEqual(set(res),
                         set([default_collectors.EnvFactCollector,
                              default_collectors.FacterFactCollector]))

    def test_facter_ohai(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=set(['env', 'facter', 'ohai']))
        self.assertIsInstance(res, list)
        self.assertEqual(set(res),
                         set([default_collectors.EnvFactCollector,
                              default_collectors.FacterFactCollector,
                              default_collectors.OhaiFactCollector]))

    def test_just_facter(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=set(['facter']))
        self.assertIsInstance(res, list)
        self.assertEqual(set(res),
                         set([default_collectors.FacterFactCollector]))

    def test_collector_specified_multiple_times(self):
        res = self._classes(all_collector_classes=default_collectors.collectors,
                            gather_subset=['platform', 'all', 'machine'])
        self.assertIsInstance(res, list)
        self.assertIn(default_collectors.PlatformFactCollector,
                      res)

    def test_unknown_collector(self):
        # something claims 'unknown_collector' is a valid gather_subset, but there is
        # no FactCollector mapped to 'unknown_collector'
        self.assertRaisesRegexp(TypeError,
                                r'Bad subset.*unknown_collector.*given to Ansible.*allowed\:.*all,.*env.*',
                                self._classes,
                                all_collector_classes=default_collectors.collectors,
                                gather_subset=['env', 'unknown_collector'])
564