1from __future__ import unicode_literals
2
3import inspect
4import sys
5
6import pytest
7
8from _pytest_mock_version import version
9
10__version__ = version
11
# Minimal "six"-style shim: ``text_type`` is the unicode text type of the
# running interpreter. If more compatibility helpers are ever needed, depend on
# six instead of growing this.
text_type = str if sys.version_info[0] != 2 else unicode  # noqa
17
18
19def _get_mock_module(config):
20    """
21    Import and return the actual "mock" module. By default this is "mock" for Python 2 and
22    "unittest.mock" for Python 3, but the user can force to always use "mock" on Python 3 using
23    the mock_use_standalone_module ini option.
24    """
25    if not hasattr(_get_mock_module, "_module"):
26        use_standalone_module = parse_ini_boolean(
27            config.getini("mock_use_standalone_module")
28        )
29        if sys.version_info[0] == 2 or use_standalone_module:
30            import mock
31
32            _get_mock_module._module = mock
33        else:
34            import unittest.mock
35
36            _get_mock_module._module = unittest.mock
37
38    return _get_mock_module._module
39
40
class MockFixture(object):
    """
    Fixture that provides the same interface to functions in the mock module,
    ensuring that they are uninstalled at the end of each test.
    """

    def __init__(self, config):
        self._patches = []  # list of mock._patch objects, in start order
        # objects returned by started patches that have a reset_mock() method
        self._mocks = []
        self.mock_module = mock_module = _get_mock_module(config)
        self.patch = self._Patcher(self._patches, self._mocks, mock_module)
        # aliases for convenience
        self.Mock = mock_module.Mock
        self.MagicMock = mock_module.MagicMock
        self.NonCallableMock = mock_module.NonCallableMock
        self.PropertyMock = mock_module.PropertyMock
        self.call = mock_module.call
        self.ANY = mock_module.ANY
        self.DEFAULT = mock_module.DEFAULT
        self.create_autospec = mock_module.create_autospec
        self.sentinel = mock_module.sentinel
        self.mock_open = mock_module.mock_open

    def resetall(self):
        """
        Call reset_mock() on every mock object returned by patches started
        through this fixture.
        """
        for m in self._mocks:
            m.reset_mock()

    def stopall(self):
        """
        Stop all patchers started by this fixture, most recently started
        first. Can be safely called multiple times.
        """
        for p in reversed(self._patches):
            p.stop()
        self._patches[:] = []
        self._mocks[:] = []

    def spy(self, obj, name):
        """
        Creates a spy of method. It will run method normally, but it is now
        possible to use `mock` call features with it, like call count.

        :param object obj: An object.
        :param unicode name: A method in object.
        :rtype: mock.MagicMock
        :return: Spy object.
        """
        method = getattr(obj, name)

        autospec = inspect.ismethod(method) or inspect.isfunction(method)
        # Can't use autospec classmethod or staticmethod objects
        # see: https://bugs.python.org/issue23078
        if inspect.isclass(obj):
            # Read the raw attribute from the class dicts instead of going
            # through the descriptor protocol. Walk the whole MRO so that a
            # classmethod/staticmethod inherited from a base class is detected
            # too; looking only at ``obj`` itself would miss it and autospec
            # a bound/plain function that breaks when called via an instance.
            for klass in inspect.getmro(obj):
                if name in klass.__dict__:
                    raw_attribute = klass.__dict__[name]
                    if isinstance(raw_attribute, (classmethod, staticmethod)):
                        autospec = False
                    break

        result = self.patch.object(obj, name, side_effect=method, autospec=autospec)
        return result

    def stub(self, name=None):
        """
        Creates a stub method. It accepts any arguments. Ideal to register to
        callbacks in tests.

        :param name: the constructed stub's name as used in repr
        :rtype: mock.MagicMock
        :return: Stub object.
        """
        return self.mock_module.MagicMock(spec=lambda *args, **kwargs: None, name=name)

    class _Patcher(object):
        """
        Object to provide the same interface as mock.patch, mock.patch.object,
        etc. We need this indirection to keep the same API of the mock package.
        """

        def __init__(self, patches, mocks, mock_module):
            # ``patches`` and ``mocks`` are the owning fixture's lists; they
            # are shared (not copied) so the fixture can stop/reset later.
            self._patches = patches
            self._mocks = mocks
            self.mock_module = mock_module

        def _start_patch(self, mock_func, *args, **kwargs):
            """Patches something by calling the given function from the mock
            module, registering the patch to stop it later and returns the
            mock object resulting from the mock call.
            """
            p = mock_func(*args, **kwargs)
            mocked = p.start()
            self._patches.append(p)
            # patch.dict (for example) returns a non-mock; only keep what can
            # be reset by resetall().
            if hasattr(mocked, "reset_mock"):
                self._mocks.append(mocked)
            return mocked

        def object(self, *args, **kwargs):
            """API to mock.patch.object"""
            return self._start_patch(self.mock_module.patch.object, *args, **kwargs)

        def multiple(self, *args, **kwargs):
            """API to mock.patch.multiple"""
            return self._start_patch(self.mock_module.patch.multiple, *args, **kwargs)

        def dict(self, *args, **kwargs):
            """API to mock.patch.dict"""
            return self._start_patch(self.mock_module.patch.dict, *args, **kwargs)

        def __call__(self, *args, **kwargs):
            """API to mock.patch"""
            return self._start_patch(self.mock_module.patch, *args, **kwargs)
159
160
@pytest.yield_fixture
def mocker(pytestconfig):
    """
    Provide a MockFixture bound to the running pytest config. It exposes the
    same interface as the `mock` module, and every patch started through it is
    undone (via stopall) during the fixture's teardown.
    """
    fixture = MockFixture(pytestconfig)
    yield fixture
    # Teardown: undo every patch this fixture started.
    fixture.stopall()
170
171
@pytest.fixture
def mock(mocker):
    """
    Deprecated alias of the "mocker" fixture, kept only for backward
    compatibility. Requesting it emits a DeprecationWarning and yields the
    very same object as "mocker".
    """
    import warnings

    message = '"mock" fixture has been deprecated, use "mocker" instead'
    warnings.warn(message, category=DeprecationWarning)
    return mocker
183
184
185_mock_module_patches = []
186_mock_module_originals = {}
187
188
def assert_wrapper(__wrapped_mock_method__, *args, **kwargs):
    """
    Call a mock ``assert_*`` method. If it fails, re-raise the AssertionError
    with pytest's comparison of the expected vs. actual call arguments
    appended.

    :param __wrapped_mock_method__: the original (unwrapped) assert method.
        The dunder-style name keeps it from colliding with keyword arguments
        forwarded through ``**kwargs``.
    :param args: positional arguments for the assert method. ``args[0]`` is
        the mock instance; the rest are the expected positional call args.
    :param kwargs: expected keyword call args, forwarded unchanged.
    :raises AssertionError: a new error flagged with
        ``_mock_introspection_applied``, so an enclosing assert_wrapper call
        does not append the introspection a second time.
    """
    __tracebackhide__ = True
    try:
        __wrapped_mock_method__(*args, **kwargs)
        return
    except AssertionError as e:
        msg = text_type(e)
        if not getattr(e, "_mock_introspection_applied", 0):
            __mock_self = args[0]
            if __mock_self.call_args is not None:
                actual_args, actual_kwargs = __mock_self.call_args
                introspection = ""
                # Distinct names so the outer ``e`` is not shadowed.
                try:
                    assert actual_args == args[1:]
                except AssertionError as e_args:
                    introspection += "\nArgs:\n" + text_type(e_args)
                try:
                    assert actual_kwargs == kwargs
                except AssertionError as e_kwargs:
                    introspection += "\nKwargs:\n" + text_type(e_kwargs)
                # Only add the header when there is something to show under it.
                if introspection:
                    msg += "\n\npytest introspection follows:\n" + introspection
    # Raised after leaving the except block, so (on Python 3) the original
    # failure is not attached as implicit exception context.
    e = AssertionError(msg)
    e._mock_introspection_applied = True
    raise e
214
215
def wrap_assert_not_called(*args, **kwargs):
    """Replacement for NonCallableMock.assert_not_called: delegates to the
    saved original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_not_called"]
    assert_wrapper(original, *args, **kwargs)
219
220
def wrap_assert_called_with(*args, **kwargs):
    """Replacement for NonCallableMock.assert_called_with: delegates to the
    saved original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_called_with"]
    assert_wrapper(original, *args, **kwargs)
224
225
def wrap_assert_called_once(*args, **kwargs):
    """Replacement for NonCallableMock.assert_called_once: delegates to the
    saved original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_called_once"]
    assert_wrapper(original, *args, **kwargs)
229
230
def wrap_assert_called_once_with(*args, **kwargs):
    """Replacement for NonCallableMock.assert_called_once_with: delegates to
    the saved original through assert_wrapper, hidden from pytest
    tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_called_once_with"]
    assert_wrapper(original, *args, **kwargs)
234
235
def wrap_assert_has_calls(*args, **kwargs):
    """Replacement for NonCallableMock.assert_has_calls: delegates to the
    saved original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_has_calls"]
    assert_wrapper(original, *args, **kwargs)
239
240
def wrap_assert_any_call(*args, **kwargs):
    """Replacement for NonCallableMock.assert_any_call: delegates to the
    saved original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_any_call"]
    assert_wrapper(original, *args, **kwargs)
244
245
def wrap_assert_called(*args, **kwargs):
    """Replacement for NonCallableMock.assert_called: delegates to the saved
    original through assert_wrapper, hidden from pytest tracebacks."""
    __tracebackhide__ = True
    original = _mock_module_originals["assert_called"]
    assert_wrapper(original, *args, **kwargs)
249
250
def wrap_assert_methods(config):
    """
    Replace the assert_* methods of the mock module's NonCallableMock with
    wrappers that hide their traceback and add introspection of the expected
    vs. actual call arguments.

    Does nothing if originals have already been recorded. Otherwise it
    records each original, starts a patcher for it, and registers
    unwrap_assert_methods as a cleanup on ``config``.
    """
    # Originals already recorded means the methods are already wrapped.
    if _mock_module_originals:
        return

    mock_module = _get_mock_module(config)
    target_class = mock_module.NonCallableMock

    replacements = {
        "assert_called": wrap_assert_called,
        "assert_called_once": wrap_assert_called_once,
        "assert_called_with": wrap_assert_called_with,
        "assert_called_once_with": wrap_assert_called_once_with,
        "assert_any_call": wrap_assert_any_call,
        "assert_has_calls": wrap_assert_has_calls,
        "assert_not_called": wrap_assert_not_called,
    }
    for method_name, replacement in replacements.items():
        try:
            original = getattr(target_class, method_name)
        except AttributeError:  # pragma: no cover
            # presumably some mock versions lack some of these methods; skip
            continue
        _mock_module_originals[method_name] = original
        patcher = mock_module.patch.object(target_class, method_name, replacement)
        patcher.start()
        _mock_module_patches.append(patcher)

    try:
        add_cleanup = config.add_cleanup
    except AttributeError:
        # pytest 2.7 compatibility
        add_cleanup = config._cleanup.append
    add_cleanup(unwrap_assert_methods)
287
288
def unwrap_assert_methods():
    """
    Stop every patcher started by wrap_assert_methods, then empty both the
    patcher list and the recorded originals.

    A RuntimeError whose message is exactly "stop called on unstarted
    patcher" is tolerated; any other error propagates.
    """
    for patcher in _mock_module_patches:
        try:
            patcher.stop()
        except RuntimeError as err:
            # User code may already have stopped one of these patchers
            # (#137). mock offers no public way to ask whether a patcher is
            # active, so this specific error message is matched and ignored.
            if text_type(err) != "stop called on unstarted patcher":
                raise
    del _mock_module_patches[:]
    _mock_module_originals.clear()
304
305
def pytest_addoption(parser):
    """Register the ini options understood by this plugin, with their help
    text and default values."""
    ini_options = (
        (
            "mock_traceback_monkeypatch",
            "Monkeypatch the mock library to improve reporting of the "
            "assert_called_... methods",
            True,
        ),
        (
            "mock_use_standalone_module",
            'Use standalone "mock" (from PyPI) instead of builtin "unittest.mock" '
            "on Python 3",
            False,
        ),
    )
    for option_name, help_text, default_value in ini_options:
        parser.addini(option_name, help_text, default=default_value)
319
320
def parse_ini_boolean(value):
    """
    Convert an ini-file value to a bool.

    Values equal to True/False (such as the bool defaults given to
    ``parser.addini``) are returned unchanged. Strings are matched
    case-insensitively, ignoring surrounding whitespace, against
    "true"/"yes"/"y"/"t"/"on"/"1" and "false"/"no"/"n"/"f"/"off"/"0".

    :raises ValueError: for any other value, including non-string values
        (previously these leaked an AttributeError).
    """
    if value in (True, False):
        return value
    try:
        normalized = value.strip().lower()
    except AttributeError:
        # Not a string at all: fall through to the ValueError below.
        normalized = None
    if normalized in ("true", "yes", "y", "t", "on", "1"):
        return True
    if normalized in ("false", "no", "n", "f", "off", "0"):
        return False
    # Wrap in a 1-tuple so tuple values are formatted instead of unpacked.
    raise ValueError("unknown string for bool: %r" % (value,))
328
329
def pytest_configure(config):
    """
    Wrap the mock module's assert_* methods, unless the
    "mock_traceback_monkeypatch" ini option is false or native tracebacks
    (--tb=native) were requested.
    """
    traceback_style = config.getoption("--tb", default="auto")
    monkeypatch_enabled = parse_ini_boolean(
        config.getini("mock_traceback_monkeypatch")
    )
    if not monkeypatch_enabled or traceback_style == "native":
        return
    wrap_assert_methods(config)
337