1"""Test utilities."""
2from importlib import reload as reload_module
3import io
4import logging
5import multiprocessing
6from multiprocessing import synchronize
7import shutil
8import sys
9import tempfile
10from typing import Any
11from typing import Callable
12from typing import cast
13from typing import IO
14from typing import Iterable
15from typing import List
16from typing import Optional
17from typing import Union
18import unittest
19import warnings
20
21from cryptography.hazmat.backends import default_backend
22from cryptography.hazmat.primitives import serialization
23import josepy as jose
24from OpenSSL import crypto
25import pkg_resources
26
27from certbot import configuration
28from certbot import util
29from certbot._internal import constants
30from certbot._internal import lock
31from certbot._internal import storage
32from certbot._internal.display import obj as display_obj
33from certbot.compat import filesystem
34from certbot.compat import os
35from certbot.display import util as display_util
36from certbot.plugins import common
37
try:
    # When we remove this deprecated import, we should also remove the
    # "external-mock" test environment and the mock dependency listed in
    # tools/pinning/pyproject.toml.
    # Prefer the third-party ``mock`` package when it is installed, and warn
    # that the tests are moving to the standard library ``unittest.mock``.
    import mock
    warnings.warn(
        "The external mock module is being used for backwards compatibility "
        "since it is available, however, future versions of Certbot's tests will "
        "use unittest.mock. Be sure to update your code accordingly.",
        PendingDeprecationWarning
    )
except ImportError:  # pragma: no cover
    # The external package is not installed: use the standard library module.
    from unittest import mock  # type: ignore
51
52
class DummyInstaller(common.Installer):
    """Dummy installer plugin for test purpose.

    Every hook below is a no-op whose body is only a docstring, so each one
    implicitly returns ``None`` regardless of its annotated return type.
    """
    def get_all_names(self) -> Iterable[str]:
        """Do nothing (returns ``None``)."""

    def deploy_cert(self, domain: str, cert_path: str, key_path: str, chain_path: str,
                    fullchain_path: str) -> None:
        """Ignore the certificate deployment request."""

    def enhance(self, domain: str, enhancement: str,
                options: Optional[Union[List[str], str]] = None) -> None:
        """Ignore the enhancement request."""

    def supported_enhancements(self) -> List[str]:
        """Do nothing (returns ``None``)."""

    def save(self, title: Optional[str] = None, temporary: bool = False) -> None:
        """Ignore the save request."""

    def config_test(self) -> None:
        """Skip configuration testing."""

    def restart(self) -> None:
        """Skip restarting."""

    @classmethod
    def add_parser_arguments(cls, add: Callable[..., None]) -> None:
        """Register no command line arguments."""

    def prepare(self) -> None:
        """Skip preparation."""

    def more_info(self) -> str:
        """Do nothing (returns ``None``)."""
87
88
def vector_path(*names: str) -> str:
    """Return the filesystem path of a test vector.

    :param str names: path components relative to the ``testdata``
        resource directory of this module
    :returns: path resolved by ``pkg_resources.resource_filename``
    :rtype: str
    """
    relative = os.path.join('testdata', *names)
    return pkg_resources.resource_filename(__name__, relative)
93
94
def load_vector(*names: str) -> bytes:
    """Return the contents of a test vector as bytes.

    ``resource_string`` reads the resource in binary mode. When the bytes
    decode as UTF-8 text, CRLF line endings are rewritten to LF; otherwise
    (most likely a binary file) the bytes are returned untouched.

    :param str names: path components relative to the ``testdata`` directory
    :rtype: bytes
    """
    raw = pkg_resources.resource_string(__name__, os.path.join('testdata', *names))
    try:
        text = raw.decode()
    except ValueError:
        # Not decodable with the standard encoding: leave the bytes alone.
        return raw
    return text.replace('\r\n', '\n').encode()
107
108
109def _guess_loader(filename: str, loader_pem: int, loader_der: int) -> int:
110    _, ext = os.path.splitext(filename)
111    if ext.lower() == '.pem':
112        return loader_pem
113    elif ext.lower() == '.der':
114        return loader_der
115    raise ValueError("Loader could not be recognized based on extension")  # pragma: no cover
116
117
def load_cert(*names: str) -> crypto.X509:
    """Load a certificate test vector.

    PEM or DER (ASN.1) decoding is selected from the extension of the last
    path component.
    """
    filetype = _guess_loader(names[-1], crypto.FILETYPE_PEM, crypto.FILETYPE_ASN1)
    contents = load_vector(*names)
    return crypto.load_certificate(filetype, contents)
123
124
def load_csr(*names: str) -> crypto.X509Req:
    """Load a certificate request test vector.

    PEM or DER (ASN.1) decoding is selected from the extension of the last
    path component.
    """
    filetype = _guess_loader(names[-1], crypto.FILETYPE_PEM, crypto.FILETYPE_ASN1)
    contents = load_vector(*names)
    return crypto.load_certificate_request(filetype, contents)
130
131
def load_comparable_csr(*names: str) -> jose.ComparableX509:
    """Load a certificate request test vector wrapped in ``jose.ComparableX509``."""
    request = load_csr(*names)
    return jose.ComparableX509(request)
135
136
def load_rsa_private_key(*names: str) -> jose.ComparableRSAKey:
    """Load an unencrypted RSA private key test vector.

    The ``cryptography`` PEM or DER loader is chosen from the extension of
    the last path component, and the key is wrapped in
    ``jose.ComparableRSAKey``.
    """
    filetype = _guess_loader(names[-1], crypto.FILETYPE_PEM, crypto.FILETYPE_ASN1)
    loader_fn: Callable[..., Any] = (
        serialization.load_pem_private_key if filetype == crypto.FILETYPE_PEM
        else serialization.load_der_private_key
    )
    private_key = loader_fn(load_vector(*names), password=None, backend=default_backend())
    return jose.ComparableRSAKey(private_key)
147
148
def load_pyopenssl_private_key(*names: str) -> crypto.PKey:
    """Load a private key test vector as a pyOpenSSL ``PKey``.

    PEM or DER (ASN.1) decoding is selected from the extension of the last
    path component.
    """
    filetype = _guess_loader(names[-1], crypto.FILETYPE_PEM, crypto.FILETYPE_ASN1)
    contents = load_vector(*names)
    return crypto.load_privatekey(filetype, contents)
154
155
def make_lineage(config_dir: str, testfile: str, ec: bool = False) -> str:
    """Creates a lineage defined by testfile.

    This creates the archive, live, and renewal directories if
    necessary and creates a simple lineage.

    :param str config_dir: path to the configuration directory
    :param str testfile: configuration file to base the lineage on; its
        trailing ``.conf`` (5 characters) is stripped to form the lineage name
    :param bool ec: True if we generate the lineage with an ECDSA key

    :returns: path to the renewal conf file for the created lineage
    :rtype: str

    """
    lineage_name = testfile[:-len('.conf')]

    conf_dir = os.path.join(config_dir, constants.RENEWAL_CONFIGS_DIR)
    archive_dir = os.path.join(config_dir, constants.ARCHIVE_DIR, lineage_name)
    live_dir = os.path.join(config_dir, constants.LIVE_DIR, lineage_name)

    for directory in (archive_dir, conf_dir, live_dir):
        if not os.path.exists(directory):
            filesystem.makedirs(directory)

    # Populate the archive from the bundled sample (RSA or ECDSA variant).
    sample_archive = vector_path('sample-archive{}'.format('-ec' if ec else ''))
    for kind in os.listdir(sample_archive):
        shutil.copyfile(os.path.join(sample_archive, kind),
                        os.path.join(archive_dir, kind))

    # Point each live/<kind>.pem at the first archived version.
    for kind in storage.ALL_FOUR:
        os.symlink(os.path.join(archive_dir, '{0}1.pem'.format(kind)),
                   os.path.join(live_dir, '{0}.pem'.format(kind)))

    # conf_dir already starts with config_dir; joining config_dir in front of
    # it again would duplicate the prefix whenever config_dir is relative.
    conf_path = os.path.join(conf_dir, testfile)
    with open(vector_path(testfile)) as src, open(conf_path, 'w') as dst:
        dst.writelines(line.replace('MAGICDIR', config_dir) for line in src)

    return conf_path
199
200
def patch_get_utility(target: str = 'zope.component.getUtility') -> mock.MagicMock:
    """Deprecated, patch certbot.display.util directly or use patch_display_util instead.

    Emits a warning, then returns an (unstarted) patcher that replaces
    ``target`` with a freshly built display utility mock.

    :param str target: path to patch

    :returns: mock zope.component.getUtility
    :rtype: mock.MagicMock

    """
    warnings.warn('Decorator certbot.tests.util.patch_get_utility is deprecated. You should now '
                  'patch certbot.display.util yourself directly or use '
                  'certbot.tests.util.patch_display_util as a temporary workaround.')
    patcher = mock.patch(target, new_callable=_create_display_util_mock)
    return cast(mock.MagicMock, patcher)
214
215
def patch_get_utility_with_stdout(target: str = 'zope.component.getUtility',
                                  stdout: Optional[IO] = None) -> mock.MagicMock:
    """Deprecated, patch certbot.display.util directly
    or use patch_display_util_with_stdout instead.

    :param str target: path to patch
    :param object stdout: object to write standard output to; it is
        expected to have a `write` method. A new ``io.StringIO`` is used
        when it is not provided (or is falsy).

    :returns: mock zope.component.getUtility
    :rtype: mock.MagicMock

    """
    # The message previously read "... or use use certbot.tests.util...".
    warnings.warn('Decorator certbot.tests.util.patch_get_utility_with_stdout is deprecated. You '
                  'should now patch certbot.display.util yourself directly or '
                  'use certbot.tests.util.patch_display_util_with_stdout as a temporary '
                  'workaround.')
    stdout = stdout if stdout else io.StringIO()
    freezable_mock = _create_display_util_mock_with_stdout(stdout)
    return cast(mock.MagicMock, mock.patch(target, new=freezable_mock))
236
237
def patch_display_util() -> mock.MagicMock:
    """Patch certbot.display.util to use a special mock display utility.

    The mock display utility works like a regular mock object, except it also
    asserts that methods are called with valid arguments.

    The mock created by this patch mocks out Certbot internals so this can be
    used like the old patch_get_utility function. That is, the mock object will
    be called by the certbot.display.util functions and the mock returned by
    that call will be used as the display utility. This was done to simplify
    the transition from zope.component and mocking certbot.display.util
    functions directly in test code should be preferred over using this
    function in the future.

    See https://github.com/certbot/certbot/issues/8948

    :returns: patch on the function used internally by certbot.display.util to
        get a display utility instance
    :rtype: mock.MagicMock

    """
    target = 'certbot._internal.display.obj.get_display'
    patcher = mock.patch(target, new_callable=_create_display_util_mock)
    return cast(mock.MagicMock, patcher)
261
262
def patch_display_util_with_stdout(
        stdout: Optional[IO] = None) -> mock.MagicMock:
    """Patch certbot.display.util to use a special mock display utility.

    The mock display utility works like a regular mock object, except it also
    asserts that methods are called with valid arguments.

    The mock created by this patch mocks out Certbot internals so this can be
    used like the old patch_get_utility function. That is, the mock object will
    be called by the certbot.display.util functions and the mock returned by
    that call will be used as the display utility. This was done to simplify
    the transition from zope.component and mocking certbot.display.util
    functions directly in test code should be preferred over using this
    function in the future.

    See https://github.com/certbot/certbot/issues/8948

    The `message` argument passed to the display utility methods is passed to
    stdout's write method.

    :param object stdout: object to write standard output to; it is
        expected to have a `write` method. A new ``io.StringIO`` is used
        when it is not provided (or is falsy).
    :returns: patch on the function used internally by certbot.display.util to
        get a display utility instance
    :rtype: mock.MagicMock

    """
    output = stdout if stdout else io.StringIO()
    replacement = _create_display_util_mock_with_stdout(output)
    patcher = mock.patch('certbot._internal.display.obj.get_display', new=replacement)
    return cast(mock.MagicMock, patcher)
294
295
class FreezableMock:
    """Mock object with the ability to freeze attributes.

    This class works like a regular mock.MagicMock object, except
    attributes and behavior set before the object is frozen cannot
    be changed during tests.

    If a func argument is provided to the constructor, this function
    is called first when an instance of FreezableMock is called,
    followed by the usual behavior defined by MagicMock. The return
    value of func is ignored.

    """
    def __init__(self, frozen: bool = False, func: Optional[Callable[..., Any]] = None,
                 return_value: Any = mock.sentinel.DEFAULT) -> None:
        # An unfrozen instance puts ``freeze`` in the frozen set so attribute
        # lookup resolves it to the method below; an instance created frozen
        # forwards ``freeze`` to the underlying MagicMock like any other name.
        self._frozen_set = set() if frozen else {'freeze', }
        self._func = func
        self._mock = mock.MagicMock()
        # Identity check against the sentinel: ``!=`` would call the supplied
        # value's __ne__, which may raise or return a non-bool (array-likes).
        if return_value is not mock.sentinel.DEFAULT:
            self.return_value = return_value
        self._frozen = frozen

    def freeze(self) -> None:
        """Freeze object preventing further changes."""
        self._frozen = True

    def __call__(self, *args: Any, **kwargs: Any) -> mock.MagicMock:
        """Run ``func`` (if any) with the call's arguments, then call the mock."""
        if self._func is not None:
            self._func(*args, **kwargs)
        return self._mock(*args, **kwargs)

    def __getattribute__(self, name: str) -> Any:
        """Resolve frozen-set names on the instance, everything else on ``_mock``."""
        if name == '_frozen':
            try:
                return object.__getattribute__(self, name)
            except AttributeError:
                # Not assigned yet (still inside __init__): treat as unfrozen.
                return False
        elif name in ('return_value', 'side_effect',):
            return getattr(object.__getattribute__(self, '_mock'), name)
        elif name == '_frozen_set' or name in self._frozen_set:
            return object.__getattribute__(self, name)
        else:
            return getattr(object.__getattribute__(self, '_mock'), name)

    def __setattr__(self, name: str, value: Any) -> None:
        """ Before it is frozen, attributes are set on the FreezableMock
        instance and added to the _frozen_set. Attributes in the _frozen_set
        cannot be changed after the FreezableMock is frozen. In this case,
        they are set on the underlying _mock.

        In cases of return_value and side_effect, these attributes are always
        passed through to the instance's _mock and added to the _frozen_set
        before the object is frozen.

        :raises AttributeError: when changing a frozen attribute after freezing

        """
        if self._frozen:
            if name in self._frozen_set:
                raise AttributeError('Cannot change frozen attribute ' + name)
            return setattr(self._mock, name, value)

        if name != '_frozen_set':
            self._frozen_set.add(name)

        if name in ('return_value', 'side_effect'):
            return setattr(self._mock, name, value)

        return object.__setattr__(self, name, value)
363
364
def _create_display_util_mock() -> FreezableMock:
    """Build a frozen factory mock whose return value is a fake display utility.

    Every public callable of ``display_obj.FileDisplay`` except
    ``notification`` is replaced on the fake display by a frozen mock that
    first runs ``_assert_valid_call`` with the call's arguments.
    ``notification`` is not set explicitly, so after freezing it is served by
    the fake display's underlying MagicMock.
    """
    fake_display = FreezableMock()
    for attr_name in dir(display_obj.FileDisplay):
        member = getattr(display_obj.FileDisplay, attr_name)
        if not callable(member) or attr_name.startswith("__"):
            continue
        if attr_name == 'notification':
            continue
        validating_mock = FreezableMock(frozen=True, func=_assert_valid_call)
        setattr(fake_display, attr_name, validating_mock)
    fake_display.freeze()
    return FreezableMock(frozen=True, return_value=fake_display)
377
378
def _create_display_util_mock_with_stdout(stdout: IO) -> FreezableMock:
    """Build a frozen factory mock whose return value is a fake display that echoes.

    Every public callable of ``display_obj.FileDisplay`` is replaced on the
    fake display by a frozen mock. ``notification`` only writes its message to
    ``stdout``. Every other method first validates the call with
    ``_assert_valid_call`` and then writes its message.

    :param stdout: object with a ``write`` method that receives each message
    :returns: frozen mock whose return value is the fake display utility
    """
    def _write_msg(message: str, *unused_args: Any, **unused_kwargs: Any) -> None:
        """Write message to stdout if it is non-empty."""
        if message:
            stdout.write(message)

    def mock_method(*args: Any, **kwargs: Any) -> None:
        """Validate a display utility call, then echo its message."""
        # Forward the call's own arguments. Passing the ``args`` tuple and the
        # ``kwargs`` dict as two positional values would hide the real
        # default/cli_flag/force_interactive from the validator. This matches how
        # _create_display_util_mock invokes _assert_valid_call.
        _assert_valid_call(*args, **kwargs)
        _write_msg(*args, **kwargs)

    display = FreezableMock()
    method_list = [func for func in dir(display_obj.FileDisplay)
                   if callable(getattr(display_obj.FileDisplay, func))
                   and not func.startswith("__")]
    for method in method_list:
        handler = _write_msg if method == 'notification' else mock_method
        setattr(display, method, FreezableMock(frozen=True, func=handler))
    display.freeze()
    return FreezableMock(frozen=True, return_value=display)
408
409
def _assert_valid_call(*args: Any, **kwargs: Any) -> None:
    """Validate a display call with ``display_util.assert_valid_call``.

    The prompt is the first positional argument, or the ``message`` keyword
    when there are no positional arguments. ``default``, ``cli_flag`` and
    ``force_interactive`` are read from keywords only, defaulting to
    ``None``, ``None`` and ``False`` respectively.

    :raises KeyError: if there are no positional arguments and no ``message``
    """
    prompt = args[0] if args else kwargs['message']
    display_util.assert_valid_call(
        prompt,
        default=kwargs.get('default', None),
        cli_flag=kwargs.get('cli_flag', None),
        force_interactive=kwargs.get('force_interactive', False),
    )
419
420
class TempDirTestCase(unittest.TestCase):
    """Base test class that creates ``self.tempdir`` before each test and
    deletes it afterwards."""

    def setUp(self) -> None:
        """Create a fresh temporary directory for the test."""
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self) -> None:
        """Close logging and release locks, then delete the temporary directory."""
        # Certbot normally closes these resources through atexit handlers, but
        # under tests those handlers only run when the whole test process exits,
        # long after tearDown. Windows will not delete resources that are still
        # open, so they are closed explicitly before removing the directory.
        logging.shutdown()
        # Drop the now-closed handlers so later tests cannot accidentally use them.
        root_logger = logging.getLogger()
        root_logger.handlers = []
        util._release_locks()  # pylint: disable=protected-access

        shutil.rmtree(self.tempdir)
441
442
class ConfigTestCase(TempDirTestCase):
    """Test class which sets up a NamespaceConfig object.

    ``self.config`` wraps a MagicMock built from ``constants.CLI_DEFAULTS``,
    with its config, work and logs directories placed inside ``self.tempdir``.
    """
    def setUp(self) -> None:
        super().setUp()
        self.config = configuration.NamespaceConfig(
            mock.MagicMock(**constants.CLI_DEFAULTS))
        namespace = self.config.namespace
        namespace.verb = "certonly"
        namespace.config_dir = os.path.join(self.tempdir, 'config')
        namespace.work_dir = os.path.join(self.tempdir, 'work')
        namespace.logs_dir = os.path.join(self.tempdir, 'logs')
        namespace.cert_path = constants.CLI_DEFAULTS['auth_cert_path']
        # NOTE(review): fullchain_path reuses the chain default, presumably
        # because CLI_DEFAULTS has no dedicated fullchain entry — confirm.
        namespace.fullchain_path = constants.CLI_DEFAULTS['auth_chain_path']
        namespace.chain_path = constants.CLI_DEFAULTS['auth_chain_path']
        namespace.server = "https://example.com"
458
459
def _handle_lock(event_in: synchronize.Event, event_out: synchronize.Event, path: str) -> None:
    """
    Acquire a file lock on given path, then wait to release it. This worker is coordinated
    using events to signal when the lock should be acquired and released.
    :param multiprocessing.Event event_in: event object to signal when to release the lock
    :param multiprocessing.Event event_out: event object to signal when the lock is acquired
    :param path: the path to lock
    """
    # Directories are locked with lock_dir, anything else with LockFile.
    my_lock = lock.lock_dir(path) if os.path.isdir(path) else lock.LockFile(path)
    try:
        # Tell the coordinating process that the lock is held.
        event_out.set()
        release_requested = event_in.wait(timeout=20)
        assert release_requested, 'Timeout while waiting to release the lock.'
    finally:
        my_lock.release()
477
478
def lock_and_call(callback: Callable[[], Any], path_to_lock: str) -> None:
    """
    Grab a lock on path_to_lock from a foreign process then execute the callback.

    The helper process is always told to release the lock and is always
    joined, even when acquiring the lock times out or the callback raises,
    so a failing test does not leave a process behind holding the lock.

    :param callable callback: object to call after acquiring the lock
    :param str path_to_lock: path to file or directory to lock
    :raises AssertionError: if the helper process does not acquire the lock in
        time or does not exit with code 0
    """
    # Reload certbot.util module to reset internal _LOCKS dictionary.
    reload_module(util)

    emit_event = multiprocessing.Event()
    receive_event = multiprocessing.Event()
    process = multiprocessing.Process(target=_handle_lock,
                                      args=(emit_event, receive_event, path_to_lock))
    process.start()

    try:
        # Wait confirmation that lock is acquired
        assert receive_event.wait(timeout=10), 'Timeout while waiting to acquire the lock.'
        # Execute the callback
        callback()
    finally:
        # Trigger unlock from foreign process, even if the callback raised.
        emit_event.set()
        # Wait for process termination
        process.join(timeout=10)
        if process.is_alive():
            # Do not leak a helper that failed to exit in time.
            process.terminate()
            process.join()

    assert process.exitcode == 0
504
505
def skip_on_windows(reason: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator to skip permanently a test on Windows. A reason is required.

    :param str reason: explanation reported by unittest for the skip
    :returns: decorator that applies ``unittest.skipIf`` for ``win32``
    """
    def decorate(test_item: Callable[..., Any]) -> Callable[..., Any]:
        """Apply the Windows skip condition to ``test_item``."""
        running_on_windows = sys.platform == 'win32'
        skip_decorator = unittest.skipIf(running_on_windows, reason)
        return skip_decorator(test_item)
    return decorate
512
513
def temp_join(path: str) -> str:
    """
    Return the given path joined to the tempdir path for the current platform
    Eg.: 'cert' => /tmp/cert (Linux) or 'C:\\Users\\currentuser\\AppData\\Temp\\cert' (Windows)
    """
    temp_root = tempfile.gettempdir()
    return os.path.join(temp_root, path)
520