1# Copyright: (c) 2017, Ansible Project
2# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)
3
4# Make coding more python3-ish
5from __future__ import (absolute_import, division, print_function)
6__metaclass__ = type
7import pytest
8
9from jinja2 import Environment
10
11import ansible.plugins.filter.mathstuff as ms
12from ansible.errors import AnsibleFilterError, AnsibleFilterTypeError
13
14
# Each entry is (input sequence, expected unique elements in sorted order).
UNIQUE_DATA = (
    ([1, 3, 4, 2], [1, 2, 3, 4]),
    ([1, 3, 2, 4, 2, 3], [1, 2, 3, 4]),
    (['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd']),
    (['a', 'a', 'd', 'b', 'a', 'd', 'c', 'b'], ['a', 'b', 'c', 'd']),
)

# Each entry is (first input, second input, expected results). The expected
# results are sorted lists indexed as:
#   [0] intersection, [1] difference (first minus second),
#   [2] symmetric difference, [3] union of both inputs.
TWO_SETS_DATA = (
    ([1, 2], [3, 4],
     ([], [1, 2], [1, 2, 3, 4], [1, 2, 3, 4])),
    ([1, 2, 3], [5, 3, 4],
     ([3], [1, 2], [1, 2, 4, 5], [1, 2, 3, 4, 5])),
    (['a', 'b', 'c'], ['d', 'c', 'e'],
     (['c'], ['a', 'b'], ['a', 'b', 'd', 'e'], ['a', 'b', 'c', 'd', 'e'])),
)
25
# Shared Jinja2 environment passed as the first argument to the filters
# below that take one (unique, intersect, difference, symmetric_difference).
env = Environment()
27
28
@pytest.mark.parametrize('data, expected', UNIQUE_DATA)
class TestUnique:
    """Check ``ms.unique`` with both list and tuple input.

    Every case from UNIQUE_DATA runs twice: once as a list (unhashable) and
    once as a tuple (hashable). The result is sorted before comparison, so
    element order is not asserted.
    """

    def test_unhashable(self, data, expected):
        as_list = list(data)
        deduplicated = ms.unique(env, as_list)
        assert sorted(deduplicated) == expected

    def test_hashable(self, data, expected):
        as_tuple = tuple(data)
        deduplicated = ms.unique(env, as_tuple)
        assert sorted(deduplicated) == expected
36
37
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestIntersect:
    """Check ``ms.intersect`` against expected[0] of each TWO_SETS_DATA entry."""

    def test_unhashable(self, dataset1, dataset2, expected):
        # Lists are unhashable.
        left, right = list(dataset1), list(dataset2)
        common = ms.intersect(env, left, right)
        assert sorted(common) == expected[0]

    def test_hashable(self, dataset1, dataset2, expected):
        # Tuples are hashable.
        left, right = tuple(dataset1), tuple(dataset2)
        common = ms.intersect(env, left, right)
        assert sorted(common) == expected[0]
45
46
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestDifference:
    """Check ``ms.difference`` against expected[1] of each TWO_SETS_DATA entry."""

    def test_unhashable(self, dataset1, dataset2, expected):
        first, second = list(dataset1), list(dataset2)
        remaining = ms.difference(env, first, second)
        assert sorted(remaining) == expected[1]

    def test_hashable(self, dataset1, dataset2, expected):
        first, second = tuple(dataset1), tuple(dataset2)
        remaining = ms.difference(env, first, second)
        assert sorted(remaining) == expected[1]
54
55
@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestSymmetricDifference:
    """Check ``ms.symmetric_difference`` against expected[2] of TWO_SETS_DATA."""

    def test_unhashable(self, dataset1, dataset2, expected):
        assert sorted(ms.symmetric_difference(env, list(dataset1), list(dataset2))) == expected[2]

    def test_hashable(self, dataset1, dataset2, expected):
        assert sorted(ms.symmetric_difference(env, tuple(dataset1), tuple(dataset2))) == expected[2]


@pytest.mark.parametrize('dataset1, dataset2, expected', TWO_SETS_DATA)
class TestUnion:
    """Check ``ms.union`` against expected[3] of TWO_SETS_DATA.

    expected[3] holds the union of both inputs. This class covers it the
    same way the intersect/difference/symmetric_difference tests cover
    indices 0-2, so that column of the fixture is actually exercised.
    """

    def test_unhashable(self, dataset1, dataset2, expected):
        assert sorted(ms.union(env, list(dataset1), list(dataset2))) == expected[3]

    def test_hashable(self, dataset1, dataset2, expected):
        assert sorted(ms.union(env, tuple(dataset1), tuple(dataset2))) == expected[3]
63
64
class TestMin:
    """``ms.min`` picks the smallest element of a sequence."""

    # (input sequence, expected smallest element)
    CASES = (
        ((1, 2), 1),
        ((2, 1), 1),
        (('p', 'a', 'w', 'b', 'p'), 'a'),
    )

    def test_min(self):
        for values, smallest in self.CASES:
            assert ms.min(values) == smallest
70
71
class TestMax:
    """``ms.max`` picks the largest element of a sequence."""

    # (input sequence, expected largest element)
    CASES = (
        ((1, 2), 2),
        ((2, 1), 2),
        (('p', 'a', 'w', 'b', 'p'), 'w'),
    )

    def test_max(self):
        for values, largest in self.CASES:
            assert ms.max(values) == largest
77
78
class TestLogarithm:
    """Tests for ``ms.logarithm``."""

    # The underlying TypeError text changed in Python 3.6, so either wording
    # is accepted after the fixed prefix.
    NON_NUMBER_MSG = r'log\(\) can only be used on numbers: (a float is required|must be real number, not str)'

    @staticmethod
    def _scaled_floor(value):
        # Scale by 1000 and floor, so only the integer part plus the first
        # three decimals are compared (avoids exact float equality).
        return value * 1000 // 1

    def test_log_non_number(self):
        bad_calls = (
            (('a',), {}),
            ((10,), {'base': 'a'}),
        )
        for args, kwargs in bad_calls:
            with pytest.raises(AnsibleFilterTypeError, match=self.NON_NUMBER_MSG):
                ms.logarithm(*args, **kwargs)

    def test_log_ten(self):
        assert ms.logarithm(10, 10) == 1.0
        assert self._scaled_floor(ms.logarithm(69, 10)) == 1838

    def test_log_natural(self):
        assert self._scaled_floor(ms.logarithm(69)) == 4234

    def test_log_two(self):
        assert self._scaled_floor(ms.logarithm(69, 2)) == 6108
96
97
class TestPower:
    """Tests for ``ms.power``."""

    # Python 3.6 reworded the underlying TypeError; accept both variants.
    NON_NUMBER_MSG = r'pow\(\) can only be used on numbers: (a float is required|must be real number, not str)'

    def test_power_non_number(self):
        # Either argument being non-numeric must be rejected.
        for base, exponent in (('a', 10), (10, 'a')):
            with pytest.raises(AnsibleFilterTypeError, match=self.NON_NUMBER_MSG):
                ms.power(base, exponent)

    def test_power_squared(self):
        squared = ms.power(10, 2)
        assert squared == 100

    def test_power_cubed(self):
        cubed = ms.power(10, 3)
        assert cubed == 1000
112
113
class TestInversePower:
    """Tests for ``ms.inversepower`` (value first, root degree second)."""

    # A non-numeric root surfaces a float() conversion error. Its wording
    # differs between python-2.6, python-2.7-3.5 and python-3.6+.
    BAD_ROOT_MSG = (r"root\(\) can only be used on numbers:"
                    r" (invalid literal for float\(\): a"
                    r"|could not convert string to float: a"
                    r"|could not convert string to float: 'a')")
    # A non-numeric value surfaces a "real number" TypeError, also reworded
    # across Python versions.
    BAD_VALUE_MSG = r"root\(\) can only be used on numbers: (a float is required|must be real number, not str)"

    def test_root_non_number(self):
        with pytest.raises(AnsibleFilterTypeError, match=self.BAD_ROOT_MSG):
            ms.inversepower(10, 'a')

        with pytest.raises(AnsibleFilterTypeError, match=self.BAD_VALUE_MSG):
            ms.inversepower('a', 10)

    def test_square_root(self):
        # Omitting the root must give the same result as an explicit 2.
        for args in ((100,), (100, 2)):
            assert ms.inversepower(*args) == 10

    def test_cube_root(self):
        cube_root = ms.inversepower(27, 3)
        assert cube_root == 3
132
133
class TestRekeyOnMember:
    """Tests for ``ms.rekey_on_member``.

    It turns a list of dicts (or a dict of dicts) into a dict keyed on one
    member of each inner dict.
    """

    # (Input data structure, member to rekey on, expected return)
    VALID_ENTRIES = (
        ([{"proto": "eigrp", "state": "enabled"}, {"proto": "ospf", "state": "enabled"}],
         'proto',
         {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}),
        ({'eigrp': {"proto": "eigrp", "state": "enabled"}, 'ospf': {"proto": "ospf", "state": "enabled"}},
         'proto',
         {'eigrp': {'state': 'enabled', 'proto': 'eigrp'}, 'ospf': {'state': 'enabled', 'proto': 'ospf'}}),
    )

    # (Expected exception type, input data structure, member to rekey on,
    #  expected error message)
    INVALID_ENTRIES = (
        # Fail when key is not found
        (AnsibleFilterError, [{"proto": "eigrp", "state": "enabled"}], 'invalid_key', "Key invalid_key was not found"),
        (AnsibleFilterError, {"eigrp": {"proto": "eigrp", "state": "enabled"}}, 'invalid_key', "Key invalid_key was not found"),
        # Fail when key is duplicated (the default duplicates strategy)
        (AnsibleFilterError, [{"proto": "eigrp"}, {"proto": "ospf"}, {"proto": "ospf"}],
         'proto', 'Key ospf is not unique, cannot correctly turn into dict'),
        # Fail when value is not a dict
        (AnsibleFilterTypeError, ["string"], 'proto', "List item is not a valid dict"),
        (AnsibleFilterTypeError, [123], 'proto', "List item is not a valid dict"),
        (AnsibleFilterTypeError, [[{'proto': 1}]], 'proto', "List item is not a valid dict"),
        # Fail when we do not send a dict or list
        (AnsibleFilterTypeError, "string", 'proto', "Type is not a valid list, set, or dict"),
        (AnsibleFilterTypeError, 123, 'proto', "Type is not a valid list, set, or dict"),
    )

    @pytest.mark.parametrize("list_original, key, expected", VALID_ENTRIES)
    def test_rekey_on_member_success(self, list_original, key, expected):
        assert ms.rekey_on_member(list_original, key) == expected

    @pytest.mark.parametrize("expected_exception_type, list_original, key, expected", INVALID_ENTRIES)
    def test_fail_rekey_on_member(self, expected_exception_type, list_original, key, expected):
        with pytest.raises(expected_exception_type) as err:
            ms.rekey_on_member(list_original, key)

        assert err.value.message == expected

    def test_duplicate_strategy_error(self):
        # Explicitly requesting duplicates='error' must raise the same
        # "not unique" error as the default duplicate case in INVALID_ENTRIES.
        list_original = ({'proto': 'eigrp', 'id': 1}, {'proto': 'ospf', 'id': 2}, {'proto': 'eigrp', 'id': 3})
        with pytest.raises(AnsibleFilterError) as err:
            ms.rekey_on_member(list_original, 'proto', duplicates='error')

        assert err.value.message == 'Key eigrp is not unique, cannot correctly turn into dict'

    def test_duplicate_strategy_overwrite(self):
        # With duplicates='overwrite' the last item sharing a key wins.
        list_original = ({'proto': 'eigrp', 'id': 1}, {'proto': 'ospf', 'id': 2}, {'proto': 'eigrp', 'id': 3})
        expected = {'eigrp': {'proto': 'eigrp', 'id': 3}, 'ospf': {'proto': 'ospf', 'id': 2}}
        assert ms.rekey_on_member(list_original, 'proto', duplicates='overwrite') == expected
177