1from __future__ import absolute_import
2
3import shutil
4import os
5import functools
6
7from mock import MagicMock, patch, call
8
9
def rm_f(path):
    """
    Remove ``path`` whether it is a file or a directory tree (like ``rm -rf``).

    Directories are removed best-effort (errors during the tree walk are
    ignored, as before). Files and symlinks are unlinked. A path that does not
    exist is silently ignored; any other error removing a file is raised.

    :param path: filesystem path to remove
    :type path: :class:`str`
    """
    import errno

    # ``shutil.rmtree(..., ignore_errors=True)`` swallows every error,
    # including "not a directory", so it cannot be used as a probe for
    # files: decide up front which removal applies.
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path, ignore_errors=True)
        return

    try:
        os.remove(path)
    except OSError as exc:
        # Mirror ``rm -f``: a missing path is not an error.
        if exc.errno != errno.ENOENT:
            raise
17
18
def track(**mocks):
    """
    Build a parent ``MagicMock`` with every given mock attached as a child.

    Each keyword argument is attached under its keyword name via
    ``attach_mock``. Calls made on the children are then recorded on the
    returned parent, so their relative order can be asserted in one place.

    :param mocks: child mocks keyed by the attribute name to attach them as
    :returns: the parent mock
    :rtype: :class:`MagicMock`
    """
    parent = MagicMock()
    for attr_name in mocks:
        parent.attach_mock(mocks[attr_name], attr_name)
    return parent
26
27
def intercept(obj, methodname, wrapper):
    """
    Wraps an existing method on an object with the provided generator, which
    will be "sent" the value when it yields control.

    On every call of the wrapped method, ``wrapper`` is called with the same
    arguments and advanced to its first ``yield`` *before* the original method
    runs. The original method's return value is then sent into the generator,
    which must finish at that point. The original return value is passed back
    to the caller unchanged.

    ::

        >>> def ensure_primary_key_is_set():
        ...     assert model.pk is None
        ...     saved = yield
        ...     assert model is saved
        ...     assert model.pk is not None
        ...
        >>> intercept(model, 'save', ensure_primary_key_is_set)
        >>> model.save()

    The installed replacement exposes an ``unwrap()`` function that restores
    the original method: ``model.save.unwrap()``.

    :param obj: the object that has the method to be wrapped
    :type obj: :class:`object`
    :param methodname: the name of the method that will be wrapped
    :type methodname: :class:`str`
    :param wrapper: the wrapper
    :type wrapper: generator callable
    :raises AssertionError: (from the wrapped method) if the generator yields
        again instead of stopping after receiving the return value
    """
    original = getattr(obj, methodname)

    def replacement(*args, **kwargs):
        wrapfn = wrapper(*args, **kwargs)
        # Run the "before" half of the wrapper up to its first ``yield``.
        wrapfn.send(None)
        try:
            result = original(*args, **kwargs)
        except BaseException:
            # Don't leave the generator suspended mid-body; close it so its
            # ``finally`` blocks run now rather than whenever it is collected.
            wrapfn.close()
            raise
        try:
            # Run the "after" half; a well-behaved wrapper stops here.
            wrapfn.send(result)
        except StopIteration:
            return result
        wrapfn.close()
        raise AssertionError('Generator did not stop')

    def unwrap():
        """
        Restores the method to its original (unwrapped) state.
        """
        setattr(obj, methodname, original)

    replacement.unwrap = unwrap

    setattr(obj, methodname, replacement)
73
74
class mock_import(patch.dict):
    """
    Patch ``sys.modules`` so that a dotted module path resolves to a mock.

    Every prefix of ``path`` is registered. ``'a'`` maps to a fresh root
    ``MagicMock``, and ``'a.b'`` / ``'a.b.c'`` map to the matching child
    attributes of that root, so ``sys.modules['a'].b.c`` is
    ``sys.modules['a.b.c']``.

    Usable as a context manager, which yields the mock registered for the full
    path::

        with mock_import('a.b.c') as c:
            ...

    It can also be used as a decorator. The mock for the full path is inserted
    at index 1 of the positional arguments (right after ``self`` for a method),
    and the decorated callable's return value is passed through.
    """

    # Binary step used to walk an attribute chain: (obj, name) -> obj.name.
    FROM_X_GET_Y = lambda s, x, y: getattr(x, y)

    def __init__(self, path):
        self.mock = MagicMock()
        self.path = path
        self.modules = {self.base: self.mock}

        # Walk the dotted remainder once, registering each intermediate
        # module path against the corresponding child attribute of the root.
        key, node = self.base, self.mock
        for part in self.remainder:
            key = key + '.' + part
            node = self.FROM_X_GET_Y(node, part)
            self.modules[key] = node

        super(mock_import, self).__init__('sys.modules', self.modules)

    @property
    def base(self):
        """The top-level package name (first dotted component of ``path``)."""
        return self.path.split('.')[0]

    @property
    def remainder(self):
        """The dotted components of ``path`` after the top-level package."""
        return self.path.split('.')[1:]

    def __enter__(self):
        """Apply the patch and return the mock for the full ``path``."""
        super(mock_import, self).__enter__()
        return self.modules[self.path]

    def __call__(self, func):
        """Decorate ``func`` so it runs with the patch applied."""
        # NOTE(review): the wrapper returned here is discarded. The call is
        # kept only in case ``patch.dict.__call__`` has side effects for some
        # inputs (e.g. classes) — confirm whether it is needed.
        super(mock_import, self).__call__(func)

        @functools.wraps(func)
        def inner(*args, **kwargs):
            args = list(args)
            args.insert(1, self.modules[self.path])

            with self:
                return func(*args, **kwargs)

        return inner
117
118
class effect(list):
    """
    A callable list of ``(call, return_value)`` pairs, handy as a Mock
    ``side_effect``.

    Calling an ``effect`` builds a call object from the arguments. It then
    returns the value paired with the first configured call that compares
    equal to it. If none matches, it raises ``TypeError``:

    >>> from exam.objects import call, effect
    >>> side_effect = effect((call(1), 'with 1'), (call(2), 'with 2'))
    >>> side_effect(1)
    'with 1'
    >>> side_effect(2)
    'with 2'

    Matching uses ``==`` between the configured call object (item 0 of each
    configuration tuple) and the call built from the actual arguments. Calls
    are built with ``call_class``, which defaults to ``mock.call``.

    To change how calls are built or compared, subclass ``effect`` and
    override the ``call_class`` class attribute, e.g.::

        class myeffect(effect):
            call_class = my_call_class
    """

    call_class = call

    def __init__(self, *calls):
        """
        :param calls: two-item tuples of ``(call object, return value)``
        :type calls: :class:`tuple`
        """
        super(effect, self).__init__(calls)

    def __call__(self, *args, **kwargs):
        wanted = self.call_class(*args, **kwargs)
        missing = object()
        found = next(
            (value for configured, value in self if configured == wanted),
            missing,
        )
        if found is missing:
            raise TypeError('Unknown effect for: %r, %r' % (args, kwargs))
        return found
163