1"""Tests for certbot.plugins.common."""
2import functools
3import shutil
4import unittest
5
6import josepy as jose
7try:
8    import mock
9except ImportError: # pragma: no cover
10    from unittest import mock
11
12from acme import challenges
13from certbot import achallenges
14from certbot import crypto_util
15from certbot import errors
16from certbot.compat import filesystem
17from certbot.compat import os
18from certbot.tests import acme_util
19from certbot.tests import util as test_util
20
AUTH_KEY = jose.JWKRSA.load(test_util.load_vector("rsa512_key.pem"))
# A pending HTTP-01 challenge annotated with AUTH_KEY; used by
# ChallengePerformerTest below.
ACHALL = achallenges.KeyAuthorizationAnnotatedChallenge(
    domain="encryption-example.demo",
    account_key=AUTH_KEY,
    challb=acme_util.chall_to_challb(
        challenges.HTTP01(token=b'token1'), "pending"),
)
26
27
class NamespaceFunctionsTest(unittest.TestCase):
    """Tests for the certbot.plugins.common.*_namespace helpers."""

    def test_option_namespace(self):
        from certbot.plugins.common import option_namespace
        namespace = option_namespace("foo")
        self.assertEqual("foo-", namespace)

    def test_dest_namespace(self):
        from certbot.plugins.common import dest_namespace
        namespace = dest_namespace("foo")
        self.assertEqual("foo_", namespace)

    def test_dest_namespace_with_dashes(self):
        # Dashes in the plugin name are translated to underscores in the
        # dest namespace.
        from certbot.plugins.common import dest_namespace
        namespace = dest_namespace("foo-bar")
        self.assertEqual("foo_bar_", namespace)
42
43
class PluginTest(unittest.TestCase):
    """Test for certbot.plugins.common.Plugin."""

    def setUp(self):
        from certbot.plugins.common import Plugin

        class MockPlugin(Plugin):  # pylint: disable=missing-docstring
            def prepare(self) -> None:
                pass

            def more_info(self) -> str:
                pass

            @classmethod
            def add_parser_arguments(cls, add):
                # The dest deliberately lacks the plugin's namespace prefix;
                # test_inject_parser_options checks it is passed through as-is.
                add("foo-bar", dest="different_to_foo_bar", x=1, y=None)

        config = mock.MagicMock()
        self.plugin_cls = MockPlugin
        self.config = config
        self.plugin = MockPlugin(config=config, name="mock")

    def test_init(self):
        plugin = self.plugin
        self.assertEqual("mock", plugin.name)
        self.assertEqual(self.config, plugin.config)

    def test_option_namespace(self):
        namespace = self.plugin.option_namespace
        self.assertEqual("mock-", namespace)

    def test_option_name(self):
        # option_name prepends the namespace; underscores are kept as-is.
        self.assertEqual("mock-foo_bar", self.plugin.option_name("foo_bar"))

    def test_dest_namespace(self):
        namespace = self.plugin.dest_namespace
        self.assertEqual("mock_", namespace)

    def test_dest(self):
        # Dashed and underscored spellings map to the same dest.
        for option in ("foo-bar", "foo_bar"):
            self.assertEqual("mock_foo_bar", self.plugin.dest(option))

    def test_conf(self):
        # conf() reads the namespaced attribute off the config object.
        expected = self.config.mock_foo_bar
        self.assertEqual(expected, self.plugin.conf("foo-bar"))

    def test_inject_parser_options(self):
        fake_parser = mock.MagicMock()
        self.plugin_cls.inject_parser_options(fake_parser, "mock")
        # note that inject_parser_options doesn't check if dest has
        # correct prefix
        fake_parser.add_argument.assert_called_once_with(
            "--mock-foo-bar", dest="different_to_foo_bar", x=1, y=None)

    def test_fallback_auth_hint(self):
        # Each challenge type is named once, even when repeated.
        dns_only_hint = self.plugin.auth_hint(
            [acme_util.DNS01_A, acme_util.DNS01_A])
        self.assertIn("the mock plugin completed the required dns-01 challenges",
                      dns_only_hint)
        mixed_hint = self.plugin.auth_hint(
            [acme_util.DNS01_A, acme_util.HTTP01_A, acme_util.DNS01_A])
        self.assertIn(
            "the mock plugin completed the required dns-01 and http-01 challenges",
            mixed_hint)
99
100
class InstallerTest(test_util.ConfigTestCase):
    """Tests for certbot.plugins.common.Installer."""

    def setUp(self):
        super().setUp()
        filesystem.mkdir(self.config.config_dir)
        from certbot.tests.util import DummyInstaller

        self.installer = DummyInstaller(config=self.config,
                                        name="Installer")
        # The installer's reverter; its methods are patched in the tests
        # below to check that the installer delegates to them.
        self.reverter = self.installer.reverter

    def test_add_to_real_checkpoint(self):
        # Called without ``temporary``: exercises the default, which must
        # delegate to the reverter's (non-temporary) add_to_checkpoint.
        files = {"foo.bar", "baz.qux",}
        save_notes = "foo bar baz qux"
        self._test_wrapped_method("add_to_checkpoint", files, save_notes)

    def test_add_to_real_checkpoint2(self):
        # Same expectation as above, but with ``temporary=False`` explicit.
        self._test_add_to_checkpoint_common(False)

    def test_add_to_temporary_checkpoint(self):
        self._test_add_to_checkpoint_common(True)

    def _test_add_to_checkpoint_common(self, temporary):
        """Check that add_to_checkpoint dispatches on ``temporary``.

        :param bool temporary: value passed as the ``temporary`` keyword;
            ``True`` expects the reverter's ``add_to_temp_checkpoint`` to be
            called, ``False`` its ``add_to_checkpoint``

        """
        files = {"foo.bar", "baz.qux",}
        save_notes = "foo bar baz qux"

        installer_func = functools.partial(self.installer.add_to_checkpoint,
                                           temporary=temporary)

        if temporary:
            reverter_func_name = "add_to_temp_checkpoint"
        else:
            reverter_func_name = "add_to_checkpoint"

        self._test_adapted_method(installer_func, reverter_func_name, files, save_notes)

    def test_finalize_checkpoint(self):
        self._test_wrapped_method("finalize_checkpoint", "foo")

    def test_recovery_routine(self):
        self._test_wrapped_method("recovery_routine")

    def test_revert_temporary_config(self):
        self._test_wrapped_method("revert_temporary_config")

    def test_rollback_checkpoints(self):
        self._test_wrapped_method("rollback_checkpoints", 42)

    def _test_wrapped_method(self, name, *args, **kwargs):
        """Test a wrapped reverter method.

        The installer method and the reverter method share the same name.

        :param str name: name of the method to test
        :param tuple args: position arguments to method
        :param dict kwargs: keyword arguments to method

        """
        installer_func = getattr(self.installer, name)
        self._test_adapted_method(installer_func, name, *args, **kwargs)

    def _test_adapted_method(self, installer_func,
                             reverter_func_name, *passed_args, **passed_kwargs):
        """Test an adapted reverter method.

        Checks that ``installer_func`` forwards its arguments unchanged to
        the named reverter method, and that a ``ReverterError`` raised by
        the reverter surfaces as a ``PluginError``.

        :param callable installer_func: installer method to test
        :param str reverter_func_name: name of the method on the
            reverter that should be called
        :param tuple passed_args: positional arguments passed from
            installer method to the reverter method
        :param dict passed_kwargs: keyword arguments passed from
            installer method to the reverter method

        """
        with mock.patch.object(self.reverter, reverter_func_name) as reverter_func:
            installer_func(*passed_args, **passed_kwargs)
            reverter_func.assert_called_once_with(*passed_args, **passed_kwargs)
            reverter_func.side_effect = errors.ReverterError
            self.assertRaises(
                errors.PluginError, installer_func, *passed_args, **passed_kwargs)

    def test_install_ssl_dhparams(self):
        self.installer.install_ssl_dhparams()
        self.assertTrue(os.path.isfile(self.installer.ssl_dhparams))
        # The config dir starts out empty, so the installed file should be
        # an exact copy of the bundled source.
        self.assertEqual(crypto_util.sha256sum(self.installer.ssl_dhparams),
                         self._current_ssl_dhparams_hash())

    def _current_ssl_dhparams_hash(self):
        """Return the sha256 hash of Certbot's bundled ssl_dhparams source."""
        from certbot._internal.constants import SSL_DHPARAMS_SRC
        return crypto_util.sha256sum(SSL_DHPARAMS_SRC)

    def test_current_file_hash_in_all_hashes(self):
        from certbot._internal.constants import ALL_SSL_DHPARAMS_HASHES
        self.assertIn(self._current_ssl_dhparams_hash(), ALL_SSL_DHPARAMS_HASHES,
            "constants.ALL_SSL_DHPARAMS_HASHES must be appended"
            " with the sha256 hash of constants.SSL_DHPARAMS_SRC when it is updated.")
194
195
class AddrTest(unittest.TestCase):
    """Tests for certbot.plugins.common.Addr."""

    def setUp(self):
        from certbot.plugins.common import Addr
        parse = Addr.fromstring
        self.addr1 = parse("192.168.1.1")
        self.addr2 = parse("192.168.1.1:*")
        self.addr3 = parse("192.168.1.1:80")
        self.addr4 = parse("[fe00::1]")
        self.addr5 = parse("[fe00::1]:*")
        self.addr6 = parse("[fe00::1]:80")
        self.addr7 = parse("[fe00::1]:5")
        # More than eight IPv6 groups; test_fromstring expects only the
        # first eight in the exploded form.
        self.addr8 = parse("[fe00:1:2:3:4:5:6:7:8:9]:8080")

    def test_fromstring(self):
        host_port_cases = (
            (self.addr1, "192.168.1.1", ""),
            (self.addr2, "192.168.1.1", "*"),
            (self.addr3, "192.168.1.1", "80"),
            (self.addr4, "[fe00::1]", ""),
            (self.addr5, "[fe00::1]", "*"),
            (self.addr6, "[fe00::1]", "80"),
        )
        for addr, host, port in host_port_cases:
            self.assertEqual(addr.get_addr(), host)
            self.assertEqual(addr.get_port(), port)
        self.assertEqual(self.addr6.get_ipv6_exploded(), "fe00:0:0:0:0:0:0:1")
        # An IPv4 address has no exploded IPv6 form.
        self.assertEqual(self.addr1.get_ipv6_exploded(), "")
        self.assertEqual(self.addr7.get_port(), "5")
        self.assertEqual(self.addr8.get_ipv6_exploded(), "fe00:1:2:3:4:5:6:7")

    def test_str(self):
        # str() reproduces the textual form the address was parsed from.
        round_trips = (
            (self.addr1, "192.168.1.1"),
            (self.addr2, "192.168.1.1:*"),
            (self.addr3, "192.168.1.1:80"),
            (self.addr4, "[fe00::1]"),
            (self.addr5, "[fe00::1]:*"),
            (self.addr6, "[fe00::1]:80"),
        )
        for addr, text in round_trips:
            self.assertEqual(str(addr), text)

    def test_get_addr_obj(self):
        # get_addr_obj keeps the host and swaps in the given port; an empty
        # port drops the ":port" suffix.
        cases = (
            (self.addr1, "443", "192.168.1.1:443"),
            (self.addr2, "", "192.168.1.1"),
            (self.addr1, "*", "192.168.1.1:*"),
            (self.addr4, "443", "[fe00::1]:443"),
            (self.addr5, "", "[fe00::1]"),
            (self.addr4, "*", "[fe00::1]:*"),
        )
        for addr, port, text in cases:
            self.assertEqual(str(addr.get_addr_obj(port)), text)

    def test_eq(self):
        from certbot.plugins.common import Addr
        # IPv4: equality includes the port; non-Addr values never compare equal.
        self.assertEqual(self.addr1, self.addr2.get_addr_obj(""))
        self.assertNotEqual(self.addr1, self.addr2)
        self.assertNotEqual(self.addr1, 3333)

        # IPv6: same rules, and differently-compressed spellings of one
        # address are equal.
        self.assertEqual(self.addr4, self.addr4.get_addr_obj(""))
        self.assertNotEqual(self.addr4, self.addr5)
        self.assertNotEqual(self.addr4, 3333)
        for spelling in ("[fe00:0:0::1]", "[fe00:0::0:0:1]"):
            self.assertEqual(self.addr4, Addr.fromstring(spelling))

    def test_set_inclusion(self):
        from certbot.plugins.common import Addr
        # Independently parsed but equal addresses must produce equal sets.
        self.assertEqual({self.addr1, self.addr2},
                         {Addr.fromstring("192.168.1.1"),
                          Addr.fromstring("192.168.1.1:*")})
        self.assertEqual({self.addr4, self.addr5},
                         {Addr.fromstring("[fe00::1]"),
                          Addr.fromstring("[fe00::1]:*")})
275
276
class ChallengePerformerTest(unittest.TestCase):
    """Tests for certbot.plugins.common.ChallengePerformer."""

    def setUp(self):
        from certbot.plugins.common import ChallengePerformer
        self.performer = ChallengePerformer(mock.MagicMock())

    def test_add_chall(self):
        performer = self.performer
        performer.add_chall(ACHALL, 0)
        # The challenge and its index are recorded in parallel lists.
        self.assertEqual(1, len(performer.achalls))
        self.assertEqual([0], performer.indices)

    def test_perform(self):
        # The base class leaves perform() for subclasses to implement.
        with self.assertRaises(NotImplementedError):
            self.performer.perform()
293
294
class InstallVersionControlledFileTest(test_util.TempDirTestCase):
    """Tests for certbot.plugins.common.install_version_controlled_file."""

    def setUp(self):
        super().setUp()
        # Known hashes: a dummy entry plus the hashes of the "src" (current)
        # and "old" versions written below.
        self.hashes = ["someotherhash"]
        self.dest_path = os.path.join(self.tempdir, "options-ssl-dest.conf")
        self.hash_path = os.path.join(self.tempdir, ".options-ssl-conf.txt")
        self.old_path = os.path.join(self.tempdir, "options-ssl-old.conf")
        self.source_path = os.path.join(self.tempdir, "options-ssl-src.conf")
        # Each file's content is its own path, so the versions hash differently.
        for path in (self.source_path, self.old_path,):
            with open(path, "w") as f:
                f.write(path)
            self.hashes.append(crypto_util.sha256sum(path))

    def _call(self):
        """Run install_version_controlled_file with this test's paths."""
        from certbot.plugins.common import install_version_controlled_file
        install_version_controlled_file(self.dest_path,
                                        self.hash_path,
                                        self.source_path,
                                        self.hashes)

    def _current_file_hash(self):
        """Return the sha256 hash of the current (source) version."""
        return crypto_util.sha256sum(self.source_path)

    def _assert_current_file(self):
        """Assert that dest exists and is identical to the current version."""
        self.assertTrue(os.path.isfile(self.dest_path))
        self.assertEqual(crypto_util.sha256sum(self.dest_path),
            self._current_file_hash())

    def test_no_file(self):
        self.assertFalse(os.path.isfile(self.dest_path))
        self._call()
        self._assert_current_file()

    def test_current_file(self):
        # 1st iteration installs the file, the 2nd checks if it needs updating
        for _ in range(2):
            self._call()
            self._assert_current_file()

    def test_prev_file_updates_to_current(self):
        shutil.copyfile(self.old_path, self.dest_path)
        self._call()
        self._assert_current_file()

    def test_manually_modified_current_file_does_not_update(self):
        self._call()
        with open(self.dest_path, "a") as mod_ssl_conf:
            mod_ssl_conf.write("a new line for the wrong hash\n")
        with mock.patch("certbot.plugins.common.logger") as mock_logger:
            self._call()
            # The current version was installed by the first call, so the
            # user's edit is kept without a warning.
            self.assertIs(mock_logger.warning.called, False)
        # The modified file must be left in place, not overwritten.
        self.assertTrue(os.path.isfile(self.dest_path))
        self.assertNotEqual(crypto_util.sha256sum(self.dest_path),
            self._current_file_hash())

    def test_manually_modified_past_file_warns(self):
        # A modified dest whose recorded hash belongs to an older version.
        with open(self.dest_path, "a") as mod_ssl_conf:
            mod_ssl_conf.write("a new line for the wrong hash\n")
        with open(self.hash_path, "w") as f:
            f.write("hashofanoldversion")
        with mock.patch("certbot.plugins.common.logger") as mock_logger:
            self._call()
            # Check the warning was emitted before reading call_args, so a
            # missing warning fails cleanly instead of raising TypeError.
            self.assertIs(mock_logger.warning.called, True)
            self.assertEqual(mock_logger.warning.call_args[0][0],
                "%s has been manually modified; updated file "
                "saved to %s. We recommend updating %s for security purposes.")
        # The user's modified file must not be overwritten.
        self.assertNotEqual(crypto_util.sha256sum(self.dest_path),
            self._current_file_hash())
        # only print warning once
        with mock.patch("certbot.plugins.common.logger") as mock_logger:
            self._call()
            self.assertIs(mock_logger.warning.called, False)
370
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()  # pragma: no cover
373