1# Copyright 2008 Google Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# This is a fork of the pymox library intended to work with Python 3.
16# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com
17
18"""Mox, an object-mocking framework for Python.
19
20Mox works in the record-replay-verify paradigm.  When you first create
21a mock object, it is in record mode.  You then programmatically set
22the expected behavior of the mock object (what methods are to be
23called on it, with what parameters, what they should return, and in
24what order).
25
26Once you have set up the expected mock behavior, you put it in replay
27mode.  Now the mock responds to method calls just as you told it to.
28If an unexpected method (or an expected method with unexpected
29parameters) is called, then an exception will be raised.
30
31Once you are done interacting with the mock, you need to verify that
all the expected interactions occurred.  (Maybe your code exited
33prematurely without calling some cleanup method!)  The verify phase
34ensures that every expected method was called; otherwise, an exception
35will be raised.
36
WARNING! Mock objects created by Mox are not thread-safe.  If you
call a mock from multiple threads, it should be guarded by a mutex.
39
40TODO(stevepm): Add the option to make mocks thread-safe!
41
42Suggested usage / workflow:
43
44    # Create Mox factory
45    my_mox = Mox()
46
47    # Create a mock data access object
48    mock_dao = my_mox.CreateMock(DAOClass)
49
50    # Set up expected behavior
51    mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
52    mock_dao.DeletePerson(person)
53
54    # Put mocks in replay mode
55    my_mox.ReplayAll()
56
57    # Inject mock object and run test
58    controller.SetDao(mock_dao)
59    controller.DeletePersonById('1')
60
61    # Verify all methods were called as expected
62    my_mox.VerifyAll()
63"""
64
65import collections
66import difflib
67import inspect
68import re
69import types
70import unittest
71
72from mox3 import stubout
73
74
class Error(AssertionError):
    """Root of the exception hierarchy raised by this module.

    Derives from AssertionError; it adds no behavior of its own.
    """
79
80
class ExpectedMethodCallsError(Error):
    """Signals that one or more expected method calls never happened.

    This is what you get when Verify() runs while expectations are still
    outstanding.
    """

    def __init__(self, expected_methods):
        """Remember the expectations that were left unfulfilled.

        Args:
            expected_methods: a non-empty sequence of MockMethod objects
                              that should have been called.

        Raises:
            ValueError: if expected_methods is empty.
        """
        if not expected_methods:
            raise ValueError("There must be at least one expected method")
        super(ExpectedMethodCallsError, self).__init__()
        self._expected_methods = expected_methods

    def __str__(self):
        """Return a numbered, one-per-line listing of the missing calls."""
        numbered = []
        for position, method in enumerate(self._expected_methods):
            numbered.append("%3d.  %s" % (position, method))
        listing = "\n".join(numbered)
        return "Verify: Expected methods never called:\n%s" % (listing,)
109
110
class UnexpectedMethodCallError(Error):
    """Signals a call that the mock was not expecting.

    Typical causes are a call made with the wrong parameters or a call made
    out of the recorded order.
    """

    def __init__(self, unexpected_method, expected):
        """Build the error message once, at construction time.

        Args:
            unexpected_method: the MockMethod that was called but was not at
                               the head of the expected-method queue.
            expected: the MockMethod or UnorderedGroup the call should have
                      matched, or None when there is nothing to compare
                      against.
        """
        Error.__init__(self)
        if expected is not None:
            # Show a line-by-line diff between what was called and what was
            # expected.
            actual_lines = str(unexpected_method).splitlines(True)
            wanted_lines = str(expected).splitlines(True)
            delta = difflib.Differ().compare(actual_lines, wanted_lines)
            rendered = "\n".join(entry.rstrip() for entry in delta)
            self._str = ("Unexpected method call."
                         "  unexpected:-  expected:+\n%s" % (rendered,))
        else:
            self._str = "Unexpected method call %s" % (unexpected_method,)

    def __str__(self):
        """Return the message prepared by __init__."""
        return self._str
143
144
class UnknownMethodCallError(Error):
    """Signals a request for a method the mocked object does not have."""

    def __init__(self, unknown_method_name):
        """Remember the offending method name.

        Args:
            unknown_method_name: str, the name of the method that is not
                                 part of the mocked class's interface.
        """
        super(UnknownMethodCallError, self).__init__()
        self._unknown_method_name = unknown_method_name

    def __str__(self):
        """Return a message naming the unknown method."""
        template = "Method called is not a member of the object: %s"
        return template % self._unknown_method_name
163
164
class PrivateAttributeError(Error):
    """Raised if a MockObject is passed a private additional attribute name."""

    def __init__(self, attr):
        """Init exception.

        Args:
            attr: str, the name of the private attribute that was supplied.
        """
        Error.__init__(self)
        self._attr = attr

    def __str__(self):
        """Return a message naming the rejected attribute."""
        # The two literals are concatenated by the parser, so the first one
        # must end with a space to keep the words apart.
        return ("Attribute '%s' is private and should not be available "
                "in a mock object." % (self._attr,))
175
176
class ExpectedMockCreationError(Error):
    """Raised if mocks should have been created by StubOutClassWithMocks."""

    def __init__(self, expected_mocks):
        """Init exception.

        Args:
            expected_mocks: a non-empty sequence of MockObjects that should
                            have been created.

        Raises:
            ValueError: if expected_mocks contains no mocks.
        """
        if not expected_mocks:
            # This error is about mocks, not methods; say so in the message.
            raise ValueError("There must be at least one expected mock")
        Error.__init__(self)
        self._expected_mocks = expected_mocks

    def __str__(self):
        """Return a numbered, one-per-line listing of the missing mocks."""
        mocks = "\n".join("%3d.  %s" % (i, m)
                          for i, m in enumerate(self._expected_mocks))
        return "Verify: Expected mocks never created:\n%s" % (mocks,)
200
201
class UnexpectedMockCreationError(Error):
    """Raised if too many mocks were created by StubOutClassWithMocks."""

    def __init__(self, instance, *params, **named_params):
        """Init exception.

        Args:
            instance: the type of object that was created.
            params: positional parameters given during instantiation.
            named_params: named parameters given during instantiation.
        """
        Error.__init__(self)
        self._instance = instance
        self._params = params
        self._named_params = named_params

    def __str__(self):
        """Return the offending instantiation rendered as a call expression.

        Positional and named arguments are joined into one list so that the
        separator is only emitted between actual arguments (no leading
        ", " when only named arguments were supplied).
        """
        # Wrap each value in a 1-tuple so that tuple arguments are formatted
        # as a single value instead of being unpacked by the % operator.
        rendered = ["%s" % (value,) for value in self._params]
        rendered.extend("%s=%s" % (key, value)
                        for key, value in self._named_params.items())
        return "Unexpected mock creation: %s(%s)" % (self._instance,
                                                     ", ".join(rendered))
229
230
class Mox(object):
    """Factory that creates mock objects and drives them as a group.

    Every mock created through this factory is remembered so that
    ReplayAll, VerifyAll and ResetAll can act on all of them at once.
    """

    # Attribute types that StubOutWithMock considers candidates for a
    # MockObject (as opposed to a MockAnything).
    _USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType]

    def __init__(self):
        """Start with no registered mocks and a fresh stub manager."""
        self._mock_objects = []
        self.stubs = stubout.StubOutForTesting()

    def CreateMock(self, class_to_mock, attrs=None, bounded_to=None):
        """Build a MockObject for class_to_mock and register it.

        Args:
            class_to_mock: the class (or other object) to be mocked.
            attrs: optional dict of attribute names to values that will be
                   set on the mock object. Only public attributes may be
                   set.
            bounded_to: optionally, when class_to_mock is not a class, the
                        real class object to which the attribute is bound.

        Returns:
            MockObject that can be used as the class_to_mock would be.
        """
        initial_attrs = {} if attrs is None else attrs
        mock = MockObject(class_to_mock, attrs=initial_attrs,
                          class_to_bind=bounded_to)
        self._mock_objects.append(mock)
        return mock

    def CreateMockAnything(self, description=None):
        """Build a MockAnything, which accepts any method call, and register it.

        No interface is enforced on the returned mock.

        Args:
            description: str. Optionally, a descriptive name for the mock
                         object being created, for debugging output
                         purposes.

        Returns:
            The new MockAnything.
        """
        mock = MockAnything(description=description)
        self._mock_objects.append(mock)
        return mock

    def ReplayAll(self):
        """Switch every registered mock to replay mode."""
        for registered in self._mock_objects:
            registered._Replay()

    def VerifyAll(self):
        """Verify every registered mock."""
        for registered in self._mock_objects:
            registered._Verify()

    def ResetAll(self):
        """Reset every registered mock.  Stubs are left in place."""
        for registered in self._mock_objects:
            registered._Reset()

    def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
        """Replace obj.attr_name with a mock.

        A MockObject is created when use_mock_anything is False and the
        current attribute either has a type listed in _USE_MOCK_OBJECT, is a
        class, or is an instance of object; otherwise a MockAnything is
        created.  NOTE: every Python object is an instance of object, so in
        practice a MockAnything is only used when use_mock_anything is True.

        Args:
            obj: A Python object (class, module, instance, callable).
            attr_name: str. The name of the attribute to replace with a mock.
            use_mock_anything: bool. True if a MockAnything should be used
                               regardless of the type of attribute.

        Raises:
            TypeError: if the attribute is already a MockAnything or
                       MockObject.
        """
        # Methods looked up on a class are bound to that class in the mock.
        class_to_bind = obj if inspect.isclass(obj) else None

        original = getattr(obj, attr_name)
        original_type = type(original)

        if original_type == MockAnything or original_type == MockObject:
            raise TypeError('Cannot mock a MockAnything! Did you remember to '
                            'call UnsetStubs in your previous test?')

        prefers_mock_object = (
            original_type in self._USE_MOCK_OBJECT or
            inspect.isclass(original) or
            isinstance(original, object))
        if use_mock_anything or not prefers_mock_object:
            replacement = self.CreateMockAnything(
                description='Stub for %s' % original)
            replacement.__name__ = attr_name
        else:
            replacement = self.CreateMock(original, bounded_to=class_to_bind)

        self.stubs.Set(obj, attr_name, replacement)

    def StubOutClassWithMocks(self, obj, attr_name):
        """Replace a class with a "mock factory" that will create mock objects.

        This is useful if the code-under-test directly instantiates
        dependencies.  Without it, some boilerplate is needed to create a
        mock that acts as a factory.  With StubOutClassWithMocks, the
        stubbed class is used like any other mock created by mox: during
        the record phase each instantiation creates a new mock instance,
        and during replay the recorded mocks are handed back.

        # Example using StubOutWithMock (the old, clunky way):

        mock1 = mox.CreateMock(my_import.FooClass)
        mock2 = mox.CreateMock(my_import.FooClass)
        mox.StubOutWithMock(my_import, 'FooClass', use_mock_anything=True)
        my_import.FooClass(1, 2).AndReturn(mock1)
        my_import.FooClass(9, 10).AndReturn(mock2)
        mox.ReplayAll()

        my_import.FooClass(1, 2)     # Returns mock1.
        my_import.FooClass(9, 10)    # Returns mock2.
        mox.VerifyAll()

        # Example using StubOutClassWithMocks:

        mox.StubOutClassWithMocks(my_import, 'FooClass')
        mock1 = my_import.FooClass(1, 2)     # Returns a new mock of FooClass
        mock2 = my_import.FooClass(9, 10)    # Returns another mock instance
        mox.ReplayAll()

        my_import.FooClass(1, 2)     # Returns mock1 again.
        my_import.FooClass(9, 10)    # Returns mock2 again.
        mox.VerifyAll()

        Args:
            obj: the object (e.g. a module) holding the class.
            attr_name: str. The name of the class attribute to replace.

        Raises:
            TypeError: if the attribute is already a MockAnything or
                       MockObject, or if it is not a class.
        """
        original = getattr(obj, attr_name)
        original_type = type(original)

        if original_type == MockAnything or original_type == MockObject:
            raise TypeError('Cannot mock a MockAnything! Did you remember to '
                            'call UnsetStubs in your previous test?')

        if not inspect.isclass(original):
            raise TypeError('Given attr is not a Class. Use StubOutWithMock.')

        mock_factory = _MockObjectFactory(original, self)
        self._mock_objects.append(mock_factory)
        self.stubs.Set(obj, attr_name, mock_factory)

    def UnsetStubs(self):
        """Put every stubbed attribute back to its original value."""
        self.stubs.UnsetAll()
393
394
def Replay(*args):
    """Switch each of the given mocks to replay mode.

    Args:
        *args: any number of mocks; _Replay() is called on each, in order.
    """
    for mock_obj in args:
        mock_obj._Replay()
404
405
def Verify(*args):
    """Verify each of the given mocks.

    Args:
        *args: any number of mocks; _Verify() is called on each, in order.
    """
    for mock_obj in args:
        mock_obj._Verify()
415
416
def Reset(*args):
    """Reset each of the given mocks.

    Args:
        *args: any number of mocks; _Reset() is called on each, in order.
    """
    for mock_obj in args:
        mock_obj._Reset()
426
427
class MockAnything(object):
    """Interface-free mock: any method requested of it becomes a MockMethod.

    Handy for standing in for classes that do not provide a public
    interface to mock against.
    """

    def __init__(self, description=None):
        """Create the mock in record mode with an empty expectation queue.

        Args:
            description: str. Optionally, a descriptive name for the mock
                         object being created, for debugging output purposes.
        """
        self._description = description
        self._Reset()

    def __repr__(self):
        """Return a short tag, including the description when one was given."""
        if not self._description:
            return '<MockAnything instance>'
        return '<MockAnything instance of %s>' % self._description

    def __getattr__(self, method_name):
        """Turn a lookup of a missing attribute into a mock method.

        The returned MockMethod knows whether this mock is recording or
        replaying; the actual recording/replaying happens when the
        MockMethod is called.

        Args:
            method_name: str, the name of the method being requested.

        Returns:
            A new MockMethod aware of MockAnything's state (record or
            replay), or the bound __dir__ when '__dir__' is requested.
        """
        if method_name != '__dir__':
            return self._CreateMockMethod(method_name)

        owner = self.__class__
        return owner.__dir__.__get__(self, owner)

    def __str__(self):
        """Record or replay a '__str__' call on this mock."""
        str_method = self._CreateMockMethod('__str__')
        return str_method()

    def __call__(self, *args, **kwargs):
        """Record or replay a direct call of this mock."""
        call_method = self._CreateMockMethod('__call__')
        return call_method(*args, **kwargs)

    def __getitem__(self, i):
        """Record or replay a subscript access on this mock."""
        getitem_method = self._CreateMockMethod('__getitem__')
        return getitem_method(i)

    def _CreateMockMethod(self, method_name, method_to_mock=None,
                          class_to_bind=object):
        """Build a MockMethod tied to this mock's queue and mode.

        Args:
            method_name: str, the name of the method being called.
            method_to_mock: the actual method being mocked, used for
                            introspection.
            class_to_bind: class to which the method is bound (object by
                           default).

        Returns:
            A new MockMethod aware of MockAnything's state (record or replay).
        """
        queue = self._expected_calls_queue
        replaying = self._replay_mode
        return MockMethod(method_name, queue, replaying,
                          method_to_mock=method_to_mock,
                          description=self._description,
                          class_to_bind=class_to_bind)

    def __nonzero__(self):
        """Return 1 so the mock can be used as a conditional."""
        return 1

    def __bool__(self):
        """Return True so the mock can be used as a conditional."""
        return True

    def __eq__(self, rhs):
        """Compare by mode and expectation queue.

        Two mocks are equal when rhs is a MockAnything, both are in the same
        mode, and their expected-call queues compare equal.
        """
        if not isinstance(rhs, MockAnything):
            return False
        if self._replay_mode != rhs._replay_mode:
            return False
        return self._expected_calls_queue == rhs._expected_calls_queue

    def __ne__(self, rhs):
        """Return the negation of __eq__."""
        return not self == rhs

    def _Replay(self):
        """Start replaying expected method calls."""
        self._replay_mode = True

    def _Verify(self):
        """Check that no expected call is still pending.

        Raises:
            ExpectedMethodCallsError: if there are still more method calls in
                                      the expected queue.
        """
        pending = self._expected_calls_queue
        if not pending:
            return

        # The last MultipleTimesGroup is not popped from the queue, so a
        # queue holding just one satisfied group counts as fully consumed.
        if (len(pending) == 1 and
                isinstance(pending[0], MultipleTimesGroup) and
                pending[0].IsSatisfied()):
            return

        raise ExpectedMethodCallsError(pending)

    def _Reset(self):
        """Return to record mode and discard all recorded expectations."""
        # Fresh queue of the method calls we are expecting.
        self._expected_calls_queue = collections.deque()

        # Setup (record) mode, not replay mode.
        self._replay_mode = False
553
554
class MockObject(MockAnything):
    """Mock object that simulates the public/protected interface of a class."""

    def __init__(self, class_to_mock, attrs=None, class_to_bind=None):
        """Initialize a mock object.

        Determines the methods and properties of the class and stores them.

        Args:
            class_to_mock: the class (or other object) to be mocked.
            attrs: dict of attribute names to values that will be set on the
                   mock object. Only public attributes may be set.
            class_to_bind: optionally, when class_to_mock is not a class at
                           all, it points to a real class.

        Raises:
            PrivateAttributeError: if a supplied attribute is not public.
            ValueError: if an attribute would mask an existing method.
        """
        if attrs is None:
            attrs = {}

        # Used to hack around the mixin/inheritance of MockAnything, which
        # is not a proper object (it can be anything. :-)
        MockAnything.__dict__['__init__'](self)

        # Get a list of all the public and special methods we should mock.
        self._known_methods = set()
        self._known_vars = set()
        self._class_to_mock = class_to_mock

        # A real class binds to itself; anything else binds to the class the
        # caller supplied (possibly None).
        if inspect.isclass(class_to_mock):
            self._class_to_bind = self._class_to_mock
        else:
            self._class_to_bind = class_to_bind

        # Best effort: if no name can be determined, the description set by
        # MockAnything's __init__ (None) is kept.
        try:
            if inspect.isclass(self._class_to_mock):
                self._description = class_to_mock.__name__
            else:
                self._description = type(class_to_mock).__name__
        except Exception:
            pass

        for method in dir(class_to_mock):
            attr = getattr(class_to_mock, method)
            if callable(attr):
                self._known_methods.add(method)
            elif not (type(attr) is property):
                # treating properties as class vars makes little sense.
                self._known_vars.add(method)

        # Set additional attributes at instantiation time; this is quicker
        # than manually setting attributes that are normally created in
        # __init__.
        for attr, value in attrs.items():
            if attr.startswith("_"):
                raise PrivateAttributeError(attr)
            elif attr in self._known_methods:
                raise ValueError("'%s' is a method of '%s' objects." % (attr,
                                 class_to_mock))
            else:
                setattr(self, attr, value)

    def _CreateMockMethod(self, *args, **kwargs):
        """Overridden to provide self._class_to_bind as class_to_bind."""
        kwargs.setdefault("class_to_bind", self._class_to_bind)
        return super(MockObject, self)._CreateMockMethod(*args, **kwargs)

    def __getattr__(self, name):
        """Intercept attribute request on this object.

        If the attribute is a public class variable, it will be returned and
        not recorded as a call.

        If the attribute is not a variable, it is handled like a method
        call. The method name is checked against the set of mockable
        methods, and a new MockMethod is returned that is aware of the
        MockObject's state (record or replay).  The call will be recorded
        or replayed by the MockMethod's __call__.

        Args:
            name: str, the name of the attribute being requested.

        Returns:
            Either a class variable or a new MockMethod that is aware of the
            state of the mock (record or replay).

        Raises:
            UnknownMethodCallError if the MockObject does not mock the
            requested method.
        """

        if name in self._known_vars:
            return getattr(self._class_to_mock, name)

        if name in self._known_methods:
            return self._CreateMockMethod(
                name,
                method_to_mock=getattr(self._class_to_mock, name))

        raise UnknownMethodCallError(name)

    def __eq__(self, rhs):
        """Compare by mocked class, mode and expectation queue."""

        return (isinstance(rhs, MockObject) and
                self._class_to_mock == rhs._class_to_mock and
                self._replay_mode == rhs._replay_mode and
                self._expected_calls_queue == rhs._expected_calls_queue)

    def __setitem__(self, key, value):
        """Custom logic for mocking classes that support item assignment.

        Args:
            key: Key to set the value for.
            value: Value to set.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __setitem__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class does not support item assignment.
            UnexpectedMethodCallError if the object does not expect the call to
                __setitem__.

        """
        # Verify the class supports item assignment.
        if '__setitem__' not in dir(self._class_to_mock):
            raise TypeError('object does not support item assignment')

        # If we are in replay mode then simply call the mock __setitem__ method
        if self._replay_mode:
            return MockMethod('__setitem__', self._expected_calls_queue,
                              self._replay_mode)(key, value)

        # Otherwise, create a mock method __setitem__.
        return self._CreateMockMethod('__setitem__')(key, value)

    def __getitem__(self, key):
        """Provide custom logic for mocking classes that are subscriptable.

        Args:
            key: Key to return the value for.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __getitem__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class is not subscriptable.
            UnexpectedMethodCallError if the object does not expect the call to
                __getitem__.

        """
        # Verify the class is subscriptable.
        if '__getitem__' not in dir(self._class_to_mock):
            raise TypeError('unsubscriptable object')

        # If we are in replay mode then simply call the mock __getitem__ method
        if self._replay_mode:
            return MockMethod('__getitem__', self._expected_calls_queue,
                              self._replay_mode)(key)

        # Otherwise, create a mock method __getitem__.
        return self._CreateMockMethod('__getitem__')(key)

    def __iter__(self):
        """Provide custom logic for mocking classes that are iterable.

        When the mocked object has no __iter__ but does have __getitem__ and
        the mock is in replay mode, items are fetched with self[0], self[1],
        ... until an IndexError is raised, and an iterator over the collected
        results is returned.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __iter__ method that has already been called if not in replay mode.

        Raises:
            TypeError if the underlying class is not iterable.
            UnexpectedMethodCallError if the object does not expect the call to
                __iter__.

        """
        methods = dir(self._class_to_mock)

        # Verify the class supports iteration.
        if '__iter__' not in methods:
            # If it doesn't have iter method and we are in replay method,
            # then try to iterate using subscripts.
            if '__getitem__' not in methods or not self._replay_mode:
                raise TypeError('not iterable object')
            else:
                results = []
                index = 0
                try:
                    while True:
                        results.append(self[index])
                        index += 1
                except IndexError:
                    return iter(results)

        # If we are in replay mode then simply call the mock __iter__ method.
        if self._replay_mode:
            return MockMethod('__iter__', self._expected_calls_queue,
                              self._replay_mode)()

        # Otherwise, create a mock method __iter__.
        return self._CreateMockMethod('__iter__')()

    def __contains__(self, key):
        """Provide custom logic for mocking classes that contain items.

        Args:
            key: Key to look in container for.

        Returns:
            Expected return value in replay mode. A MockMethod object for the
            __contains__ method that has already been called if not in replay
            mode.

        Raises:
            TypeError if the underlying class does not implement __contains__
            UnexpectedMethodCallError if the object does not expect the call
            to __contains__.

        """
        # Use dir(), like __getitem__/__setitem__/__iter__ do, so that an
        # inherited __contains__ is found and so that mocked objects without
        # a __dict__ produce the documented TypeError rather than an
        # AttributeError.
        if '__contains__' not in dir(self._class_to_mock):
            raise TypeError('unsubscriptable object')

        if self._replay_mode:
            return MockMethod('__contains__', self._expected_calls_queue,
                              self._replay_mode)(key)

        return self._CreateMockMethod('__contains__')(key)

    def __call__(self, *params, **named_params):
        """Provide custom logic for mocking classes that are callable.

        Args:
            *params: positional arguments of the call.
            **named_params: keyword arguments of the call.

        Returns:
            Whatever calling the '__call__' MockMethod returns.

        Raises:
            TypeError: if the mocked object has no __call__ attribute.
        """

        # Verify the class we are mocking is callable.
        is_callable = hasattr(self._class_to_mock, '__call__')
        if not is_callable:
            raise TypeError('Not callable')

        # Because the call is happening directly on this object instead of
        # a method, the call on the mock method is made right here

        # If we are mocking a Function, then use the function, and not the
        # __call__ method
        method = None
        if type(self._class_to_mock) in (types.FunctionType, types.MethodType):
            method = self._class_to_mock
        else:
            method = getattr(self._class_to_mock, '__call__')
        mock_method = self._CreateMockMethod('__call__', method_to_mock=method)

        return mock_method(*params, **named_params)

    @property
    def __name__(self):
        """Return the name that is being mocked."""
        return self._description

    # TODO(dejw): this property stopped to work after I introduced changes with
    #     binding classes. Fortunately I found a solution in the form of
    #     __getattribute__ method below, but this issue should be investigated
    @property
    def __class__(self):
        """Return the mocked class instead of MockObject."""
        return self._class_to_mock

    def __dir__(self):
        """Return only attributes of a class to mock."""
        return dir(self._class_to_mock)

    def __getattribute__(self, name):
        """Return _class_to_mock on __class__ attribute."""
        if name == "__class__":
            return super(MockObject, self).__getattribute__("_class_to_mock")

        return super(MockObject, self).__getattribute__(name)
838
839
class _MockObjectFactory(MockObject):
    """Stand-in for a class that hands out mocks and checks __init__ params.

    It removes the boilerplate that was previously needed to stub out direct
    instantiation of a class.

    In record mode every call verifies the __init__ params and creates a new
    MockObject.  In replay mode the params are verified again and the
    previously created mocks are returned, oldest first.

    See StubOutWithMock vs StubOutClassWithMocks for more detail.
    """

    def __init__(self, class_to_mock, mox_instance):
        """Set up the factory.

        Args:
            class_to_mock: the class whose instantiation is being mocked.
            mox_instance: the Mox used to create the per-instance mocks.
        """
        MockObject.__init__(self, class_to_mock)
        self._mox = mox_instance
        self._instance_queue = collections.deque()

    def __call__(self, *params, **named_params):
        """Instantiate and record that a new mock has been created.

        Returns:
            A newly created mock in record mode; the next recorded mock in
            replay mode.

        Raises:
            UnexpectedMockCreationError: in replay mode, when no recorded
                                         mock is left to hand out.
        """
        init_method = getattr(self._class_to_mock, '__init__')
        mock_init = self._CreateMockMethod('__init__',
                                           method_to_mock=init_method)

        if not self._replay_mode:
            mock_init(*params, **named_params)
            created = self._mox.CreateMock(self._class_to_mock)
            self._instance_queue.appendleft(created)
            return created

        # Check for an exhausted queue before replaying the __init__ call,
        # so that this error is the one reported.
        if not self._instance_queue:
            raise UnexpectedMockCreationError(self._class_to_mock, *params,
                                              **named_params)

        mock_init(*params, **named_params)
        return self._instance_queue.pop()

    def _Verify(self):
        """Verify that all mocks have been created.

        Raises:
            ExpectedMockCreationError: if recorded mocks were never handed
                                       out during replay.
        """
        leftover = self._instance_queue
        if leftover:
            raise ExpectedMockCreationError(leftover)
        super(_MockObjectFactory, self)._Verify()
886
887
class MethodSignatureChecker(object):
    """Ensures that methods are called correctly.

    The signature of the method is inspected once, at construction time;
    Check() then validates positional and keyword arguments against it.
    """

    _NEEDED, _DEFAULT, _GIVEN = range(3)

    def __init__(self, method, class_to_bind=None):
        """Creates a checker.

        Args:
            method: the function or method whose signature is checked.
            class_to_bind: optionally, a class used to type check the first
                           method parameter; only used with unbound methods.

        Raises:
            ValueError: method could not be inspected, so checks aren't
                        possible. Some methods and functions like built-ins
                        can't be inspected.
        """
        try:
            (args, varargs, varkw, defaults,
             kwonlyargs, kwonlydefaults) = self._GetArgSpec(method)
        except TypeError:
            raise ValueError('Could not get argument specification for %r'
                             % (method,))

        self._args = list(args)
        if (inspect.ismethod(method) or class_to_bind or
                (self._args and self._args[0] == 'self')):
            self._args = self._args[1:]    # Skip 'self'.
        self._method = method
        # May contain the instance this is bound to.
        self._instance = getattr(method, "__self__", None)

        # _bounded_to determines whether the method is bound or not.
        # Compare against None: an instance may legitimately be falsy
        # (e.g. an empty container) and must still count as bound.
        if self._instance is not None:
            self._bounded_to = self._instance.__class__
        else:
            self._bounded_to = class_to_bind or getattr(method, "im_class",
                                                        None)

        self._has_varargs = varargs is not None
        self._has_varkw = varkw is not None
        if not defaults:
            self._required_args = self._args
            self._default_args = []
        else:
            default_count = len(defaults)
            self._required_args = self._args[:-default_count]
            self._default_args = self._args[-default_count:]

        # Keyword-only parameters can never be filled positionally, so they
        # are tracked apart from self._args.
        self._required_kwonly = [name for name in kwonlyargs
                                 if name not in kwonlydefaults]
        self._default_kwonly = [name for name in kwonlyargs
                                if name in kwonlydefaults]

    @staticmethod
    def _GetArgSpec(method):
        """Return (args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults).

        inspect.getargspec() was removed in Python 3.11, so
        inspect.getfullargspec() is used whenever it exists; getargspec() is
        only a fallback for interpreters that lack getfullargspec().

        Raises:
            TypeError: the callable cannot be introspected.
        """
        getfullargspec = getattr(inspect, 'getfullargspec', None)
        if getfullargspec is None:
            args, varargs, varkw, defaults = inspect.getargspec(method)
            return args, varargs, varkw, defaults, [], {}
        spec = getfullargspec(method)
        return (spec.args, spec.varargs, spec.varkw, spec.defaults,
                spec.kwonlyargs or [], spec.kwonlydefaults or {})

    def _RecordArgumentGiven(self, arg_name, arg_status):
        """Mark an argument as being given.

        Args:
            arg_name: the name of the argument to mark in arg_status.
            arg_status: dict mapping argument names to one of
                        _NEEDED, _DEFAULT, _GIVEN.

        Raises:
            AttributeError: arg_name is already marked as _GIVEN.
        """
        if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN:
            raise AttributeError('%s provided more than once' % (arg_name,))
        arg_status[arg_name] = MethodSignatureChecker._GIVEN

    def _StripBoundInstance(self, params):
        """Drop a leading instance argument given to an unbound method.

        WARNING: Suspect hack ahead.

        We try to determine whether the first argument (params[0]) is an
        instance of the class the method belongs to, or is equal to that
        class (used to account for Comparators).

        NOTE: If a Func() comparator is used, and the signature is not
        correct, this will cause extra executions of the function.

        Args:
            params: sequence of positional parameters.

        Returns:
            params, without its first element when that element stands for
            the instance.
        """
        if not (inspect.ismethod(self._method) or self._bounded_to):
            return params

        # The extra param accounts for the bound instance.
        if len(params) <= len(self._required_args):
            return params

        expected = self._bounded_to
        first = params[0]

        # The first parameter can be a Comparator, and the comparison may
        # raise an exception, which is OK.
        try:
            param_equality = (first == expected)
        except Exception:
            param_equality = False

        if isinstance(first, expected) or param_equality:
            return params[1:]

        # If the IsA() comparator is being used, we need to check the
        # inverse of the usual case - that the given instance is a subclass
        # of the expected class. For example, the code under test does late
        # binding to a subclass.
        if isinstance(first, IsA) and first._IsSubClass(expected):
            return params[1:]

        return params

    def Check(self, params, named_params):
        """Ensures that the parameters used while recording a call are valid.

        Args:
            params: sequence of positional parameters.
            named_params: dict of named parameters.

        Raises:
            AttributeError: the given parameters don't work with the given
                            method.
        """
        arg_status = {}
        for arg in self._required_args + self._required_kwonly:
            arg_status[arg] = MethodSignatureChecker._NEEDED
        for arg in self._default_args + self._default_kwonly:
            arg_status[arg] = MethodSignatureChecker._DEFAULT

        params = self._StripBoundInstance(params)

        # Check that each positional param is valid; surplus positionals
        # are only acceptable when the method takes *args.
        for arg_name in self._args[:len(params)]:
            self._RecordArgumentGiven(arg_name, arg_status)
        if len(params) > len(self._args) and not self._has_varargs:
            raise AttributeError(
                '%s does not take %d or more positional '
                'arguments' % (self._method.__name__, len(self._args)))

        # Check each keyword argument.
        for arg_name in named_params:
            if arg_name not in arg_status and not self._has_varkw:
                raise AttributeError('%s is not expecting keyword argument %s'
                                     % (self._method.__name__, arg_name))
            self._RecordArgumentGiven(arg_name, arg_status)

        # Ensure all the required arguments have been given.
        still_needed = [k for k, v in arg_status.items()
                        if v == MethodSignatureChecker._NEEDED]
        if still_needed:
            raise AttributeError('No values given for arguments: %s'
                                 % (' '.join(sorted(still_needed))))
1032
1033
class MockMethod(object):
    """Callable mock method.

    A MockMethod should act exactly like the method it mocks, accepting
    parameters and returning a value, or throwing an exception (as specified).
    When this method is called, it can optionally verify whether the called
    method (name and signature) matches the expected method.
    """

    def __init__(self, method_name, call_queue, replay_mode,
                 method_to_mock=None, description=None, class_to_bind=None):
        """Construct a new mock method.

        Args:
            method_name: str, the name of the method.
            call_queue: list or deque of calls; verify this call against the
                        head, or add this call to the queue. Anything that
                        is not a deque is copied into a new deque.
            replay_mode: bool, False if we are recording, True if we are
                         verifying calls against the call queue.
            method_to_mock: the actual method being mocked, used for
                            introspection.
            description: optionally, a descriptive name for this method.
                         Typically this is equal to the descriptive name of
                         the method's class.
            class_to_bind: optionally, a class that is used for unbound
                           methods (or functions in Python3) to which method
                           is bound, in order not to lose binding
                           information. If given, it will be used for
                           checking the type of first method parameter.
        """

        self._name = method_name
        self.__name__ = method_name
        self._call_queue = call_queue
        if not isinstance(call_queue, collections.deque):
            self._call_queue = collections.deque(self._call_queue)
        self._replay_mode = replay_mode
        self._description = description

        self._params = None
        self._named_params = None
        self._return_value = None
        self._exception = None
        self._side_effects = None

        # A method that cannot be introspected is simply left unchecked.
        try:
            self._checker = MethodSignatureChecker(method_to_mock,
                                                   class_to_bind=class_to_bind)
        except ValueError:
            self._checker = None

    def __call__(self, *params, **named_params):
        """Log parameters and return the specified return value.

        If the Mock(Anything/Object) associated with this call is in record
        mode, this MockMethod will be pushed onto the expected call queue.
        If the mock is in replay mode, this will pop a MockMethod off the
        top of the queue and verify this call is equal to the expected call.

        Returns:
            In record mode, this MockMethod. In replay mode, the return
            value configured on the expected method or, when none was
            configured (it is None), the result of its side effects.

        Raises:
            UnexpectedMethodCall if this call is supposed to match an expected
                method call and it does not.
        """

        self._params = params
        self._named_params = named_params

        if not self._replay_mode:
            if self._checker is not None:
                self._checker.Check(params, named_params)
            self._call_queue.append(self)
            return self

        expected_method = self._VerifyMethodCall()

        # Keep the side-effect result local: storing it on expected_method
        # would make every later call of a repeated expectation return the
        # result of the first call.
        return_value = expected_method._return_value
        if expected_method._side_effects:
            result = expected_method._side_effects(*params, **named_params)
            if return_value is None:
                return_value = result

        if expected_method._exception:
            raise expected_method._exception

        return return_value

    def __getattr__(self, name):
        """Raise an AttributeError with a helpful message."""

        raise AttributeError(
            'MockMethod has no attribute "%s". '
            'Did you remember to put your mocks in replay mode?' % name)

    def __iter__(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def next(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def __next__(self):
        """Raise a TypeError with a helpful message."""
        raise TypeError(
            'MockMethod cannot be iterated. '
            'Did you remember to put your mocks in replay mode?')

    def _PopNextMethod(self):
        """Pop the next method from our call queue.

        Raises:
            UnexpectedMethodCallError: the call queue is empty.
        """
        try:
            return self._call_queue.popleft()
        except IndexError:
            raise UnexpectedMethodCallError(self, None)

    def _VerifyMethodCall(self):
        """Verify the called method is expected.

        This can be an ordered method, or part of an unordered set.

        Returns:
            The expected mock method.

        Raises:
            UnexpectedMethodCall if the method called was not expected.
        """

        expected = self._PopNextMethod()

        # Loop here, because we might have a MethodGroup followed by another
        # group.
        while isinstance(expected, MethodGroup):
            expected, method = expected.MethodCalled(self)
            if method is not None:
                return method

        # This is a mock method, so just check equality.
        if expected != self:
            raise UnexpectedMethodCallError(self, expected)

        return expected

    def __str__(self):
        """Format as "[description.]name(params) -> return_value"."""
        params = ', '.join(
            [repr(p) for p in self._params or []] +
            ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
        full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
        if self._description:
            full_desc = "%s.%s" % (self._description, full_desc)
        return full_desc

    def __hash__(self):
        """Hash by identity, although equality is by name and parameters."""
        return id(self)

    def __eq__(self, rhs):
        """Test whether this MockMethod is equivalent to another MockMethod.

        Two mock methods are equivalent when they have the same name and
        were called with equal positional and named parameters.

        Args:
            rhs: the right hand side of the test.
        """

        return (isinstance(rhs, MockMethod) and
                self._name == rhs._name and
                self._params == rhs._params and
                self._named_params == rhs._named_params)

    def __ne__(self, rhs):
        """Test if this MockMethod is not equivalent to another MockMethod.

        Args:
            rhs: the right hand side of the test.
        """

        return not self == rhs

    def GetPossibleGroup(self):
        """Returns a possible group from the end of the call queue.

        This method is first removed from the tail of the queue; whatever
        is then last in the queue is returned.

        Return None if no other methods are on the stack.
        """

        # Remove this method from the tail of the queue so we can add it
        # to a group.
        this_method = self._call_queue.pop()
        assert this_method == self

        # Determine if the tail of the queue is a group, or just a regular
        # ordered mock method.
        group = None
        try:
            group = self._call_queue[-1]
        except IndexError:
            pass

        return group

    def _CheckAndCreateNewGroup(self, group_name, group_class):
        """Checks if the last method (a possible group) is an instance of our
        group_class. Adds the current method to this group or creates a
        new one.

        Args:
            group_name: the name of the group.
            group_class: the class used to create instance of this new group

        Returns:
            self
        """
        group = self.GetPossibleGroup()

        # If this is a group, and it is the correct group, add the method.
        if isinstance(group, group_class) and group.group_name() == group_name:
            group.AddMethod(self)
            return self

        # Create a new group and add the method.
        new_group = group_class(group_name)
        new_group.AddMethod(self)
        self._call_queue.append(new_group)
        return self

    def InAnyOrder(self, group_name="default"):
        """Move this method into a group of unordered calls.

        A group of unordered calls must be defined together, and must be
        executed in full before the next expected method can be called.
        There can be multiple groups that are expected serially, if they are
        given different group names. The same group name can be reused if there
        is a standard method call, or a group with a different name, spliced
        between usages.

        Args:
            group_name: the name of the unordered group.

        Returns:
            self
        """
        return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)

    def MultipleTimes(self, group_name="default"):
        """Move method into group of calls which may be called multiple times.

        A group of repeating calls must be defined together, and must be
        executed in full before the next expected method can be called.

        Args:
            group_name: the name of the unordered group.

        Returns:
            self
        """
        return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)

    def AndReturn(self, return_value):
        """Set the value to return when this method is called.

        Args:
            return_value: can be anything.

        Returns:
            return_value itself (not this MockMethod).
        """

        self._return_value = return_value
        return return_value

    def AndRaise(self, exception):
        """Set the exception to raise when this method is called.

        Args:
            exception: the exception to raise when this method is called.
        """

        self._exception = exception

    def WithSideEffects(self, side_effects):
        """Set the side effects that are simulated when this method is called.

        Args:
            side_effects: A callable which modifies the parameters or other
                          relevant state which a given test case depends on.
                          It is called with the parameters of the replayed
                          call.

        Returns:
            Self for chaining with AndReturn and AndRaise.
        """
        self._side_effects = side_effects
        return self
1327
1328
class Comparator:
    """Base class for all Mox comparators.

    A Comparator can be used as a parameter to a mocked method when the exact
    value is not known. For example, the code you are testing might build up
    a long SQL string that is passed to your mock DAO. You're only interested
    that the IN clause contains the proper primary keys, so you can set your
    mock up as follows:

    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)

    Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.

    A Comparator may replace one or more parameters, for example:
    # return at most 10 rows
    mock_dao.RunQuery(StrContains('SELECT'), 10)

    or

    # Return some non-deterministic number of rows
    mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
    """

    def equals(self, rhs):
        """Special equals method that all comparators must implement.

        Args:
            rhs: any python object

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """

        raise NotImplementedError('method must be implemented by a subclass.')

    def __eq__(self, rhs):
        """Delegate ``==`` to equals(), returning its result unchanged."""
        outcome = self.equals(rhs)
        return outcome

    def __ne__(self, rhs):
        """Delegate ``!=`` to equals(), returning the negated result."""
        if self.equals(rhs):
            return False
        return True
1366
1367
class Is(Comparator):
    """Comparison class used to check identity, instead of equality."""

    def __init__(self, obj):
        """Remember the object that the parameter must be identical to."""
        self._obj = obj

    def equals(self, rhs):
        """Return True only when rhs is the very object given at creation."""
        wanted = self._obj
        return rhs is wanted

    def __repr__(self):
        wanted = self._obj
        return "<is %r (%s)>" % (wanted, id(wanted))
1379
1380
class IsA(Comparator):
    """This class wraps a basic Python type or class. It is used to verify
    that a parameter is of the given type or class.

    Example:
    mock_dao.Connect(IsA(DbConnectInfo))
    """

    def __init__(self, class_name):
        """Initialize IsA

        Args:
            class_name: basic python type or a class
        """

        self._class_name = class_name

    def equals(self, rhs):
        """Check to see if the RHS is an instance of class_name.

        Args:
            rhs: the right hand side of the test

        Returns:
            bool
        """

        wanted = self._class_name
        try:
            matches = isinstance(rhs, wanted)
        except TypeError:
            # Check raw types if there was a type error. This is helpful for
            # things like cStringIO.StringIO.
            matches = type(rhs) == type(wanted)
        return matches

    def _IsSubClass(self, clazz):
        """Check to see if the IsA comparator's class is a subclass of clazz.

        Args:
            clazz: a class object

        Returns:
            bool
        """

        wanted = self._class_name
        try:
            matches = issubclass(wanted, clazz)
        except TypeError:
            # Check raw types if there was a type error. This is helpful for
            # things like cStringIO.StringIO.
            matches = type(clazz) == type(wanted)
        return matches

    def __repr__(self):
        class_text = str(self._class_name)
        return 'mox.IsA(%s) ' % class_text
1435
1436
class IsAlmost(Comparator):
    """Comparison class used to check whether a parameter is nearly equal
    to a given value. Generally useful for floating point numbers.

    Example mock_dao.SetTimeout((IsAlmost(3.9)))
    """

    def __init__(self, float_value, places=7):
        """Initialize IsAlmost.

        Args:
            float_value: The value for making the comparison.
            places: The number of decimal places to round to.
        """

        self._places = places
        self._float_value = float_value

    def equals(self, rhs):
        """Check to see if RHS is almost equal to float_value

        The difference between rhs and float_value, rounded to the
        configured number of places, must be zero.

        Args:
            rhs: the value to compare to float_value

        Returns:
            bool
        """

        try:
            difference = rhs - self._float_value
            rounded = round(difference, self._places)
            return rounded == 0
        except Exception:
            # Probably because either float_value or rhs is not a number.
            return False

    def __repr__(self):
        return str(self._float_value)
1473
1474
class StrContains(Comparator):
    """Comparison class used to check whether a substring exists in a
    string parameter. This can be useful in mocking a database with SQL
    passed in as a string parameter, for example.

    Example:
    mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
    """

    def __init__(self, search_string):
        """Initialize.

        Args:
            search_string: the string you are searching for
        """

        self._search_string = search_string

    def equals(self, rhs):
        """Check to see if the search_string is contained in the rhs string.

        Args:
            rhs: the right hand side of the test

        Returns:
            bool; False as well when calling rhs.find() fails.
        """

        needle = self._search_string
        try:
            position = rhs.find(needle)
            return position > -1
        except Exception:
            return False

    def __repr__(self):
        needle = self._search_string
        return '<str containing \'%s\'>' % needle
1512
1513
class Regex(Comparator):
    """Checks if a string matches a regular expression.

    This uses a given regular expression to determine equality.
    """

    def __init__(self, pattern, flags=0):
        """Initialize.

        Args:
            pattern: the regular expression to search for
            flags: passed to re.compile as its flags argument
        """
        self.flags = flags
        self.regex = re.compile(pattern, flags=flags)

    def equals(self, rhs):
        """Check to see if rhs matches regular expression pattern.

        The pattern may match anywhere in rhs (re search semantics).

        Returns:
            bool; False as well when searching rhs fails.
        """

        try:
            match = self.regex.search(rhs)
            return match is not None
        except Exception:
            return False

    def __repr__(self):
        pieces = ['<regular expression \'%s\'' % self.regex.pattern]
        if self.flags:
            pieces.append(', flags=%d' % self.flags)
        pieces.append('>')
        return ''.join(pieces)
1550
1551
class In(Comparator):
    """Checks whether an item (or key) is in a list (or dict) parameter.

    Example:
    mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
    """

    def __init__(self, key):
        """Initialize.

        Args:
            key: anything that could be in a list or a key in a dict
        """

        self._key = key

    def equals(self, rhs):
        """Check to see whether key is in rhs.

        Args:
            rhs: the container to search, e.g. a list or a dict

        Returns:
            bool; False as well when the membership test fails.
        """

        wanted = self._key
        try:
            found = wanted in rhs
        except Exception:
            return False
        return found

    def __repr__(self):
        key_text = str(self._key)
        return '<sequence or map containing \'%s\'>' % key_text
1585
1586
class Not(Comparator):
    """Checks whether a predicate is False.

    Example:
        mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm',
                                                  stevepm_user_info)))
    """

    def __init__(self, predicate):
        """Initialize.

        Args:
            predicate: a Comparator instance.
        """

        is_comparator = isinstance(predicate, Comparator)
        assert is_comparator, "predicate %r must be a Comparator." % predicate
        self._predicate = predicate

    def equals(self, rhs):
        """Check to see whether the predicate is False.

        Args:
            rhs: A value that will be given in argument of the predicate.

        Returns:
            bool; False as well when evaluating the predicate fails.
        """

        try:
            predicate_holds = self._predicate.equals(rhs)
            return not predicate_holds
        except Exception:
            return False

    def __repr__(self):
        return '<not \'%s\'>' % self._predicate
1623
1624
class ContainsKeyValue(Comparator):
    """Checks whether a key/value pair is in a dict parameter.

    Example:
    mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
    """

    def __init__(self, key, value):
        """Initialize.

        Args:
            key: a key in a dict
            value: the corresponding value
        """

        self._value = value
        self._key = key

    def equals(self, rhs):
        """Check whether the given key/value pair is in the rhs dict.

        Returns:
            bool; False as well when the lookup or the comparison fails.
        """

        try:
            stored = rhs[self._key]
            return stored == self._value
        except Exception:
            return False

    def __repr__(self):
        key_text = str(self._key)
        value_text = str(self._value)
        return '<map containing the entry \'%s: %s\'>' % (key_text,
                                                          value_text)
1658
1659
class ContainsAttributeValue(Comparator):
    """Checks whether passed parameter contains attributes with a given value.

    Example:
    mock_dao.UpdateSomething(ContainsAttributeValue('stevepm',
                                                    stevepm_user_info))
    """

    def __init__(self, key, value):
        """Initialize.

        Args:
            key: an attribute name of an object
            value: the corresponding value
        """

        self._key = key
        self._value = value

    def equals(self, rhs):
        """Check if the given attribute has a matching value in the rhs object.

        Returns:
            bool; False as well when the attribute is missing or the
            comparison fails.
        """

        try:
            return getattr(rhs, self._key) == self._value
        except Exception:
            return False

    def __repr__(self):
        # Give failure messages a readable description, like the other
        # comparators do, instead of the default object repr.
        return '<object with attribute \'%s\' equal to \'%s\'>' % (
            str(self._key), str(self._value))
1689
1690
class SameElementsAs(Comparator):
    """Checks whether sequences contain the same elements (ignoring order).

    Hashable elements are compared as sets, so duplicates are ignored in
    that case.

    Example:
    mock_dao.ProcessUsers(SameElementsAs(['stevepm', 'salomaki']))
    """

    def __init__(self, expected_seq):
        """Initialize.

        Args:
            expected_seq: a sequence
        """
        # Store in case expected_seq is an iterator.
        self._expected_list = list(expected_seq)

    def equals(self, actual_seq):
        """Check to see whether actual_seq has same elements as expected_seq.

        Args:
            actual_seq: sequence

        Returns:
            bool
        """
        try:
            # Store in case actual_seq is an iterator. We potentially iterate
            # several times: once to make the set, again in the fallback.
            actual_list = list(actual_seq)
        except TypeError:
            # actual_seq cannot be read as a sequence.
            #
            # This happens because Mox uses __eq__ both to check object
            # equality (in MethodSignatureChecker) and to invoke Comparators.
            return False

        expected_list = self._expected_list
        try:
            return set(expected_list) == set(actual_list)
        except TypeError:
            # Fall back to slower list-compare if any of the objects
            # are unhashable.
            pass

        if len(expected_list) != len(actual_list):
            return False
        # Membership must hold in both directions; checking only the actual
        # elements would accept a sequence that is missing expected ones.
        return (all(el in expected_list for el in actual_list) and
                all(el in actual_list for el in expected_list))

    def __repr__(self):
        return '<sequence with same elements as \'%s\'>' % self._expected_list
1741
1742
class And(Comparator):
    """Combines Comparators; matches only when every one of them matches."""

    def __init__(self, *args):
        """Store the Comparators to combine.

        Args:
            *args: One or more Comparator
        """
        self._comparators = args

    def equals(self, rhs):
        """Check rhs against each Comparator, stopping at the first failure.

        Args:
            rhs: can be anything

        Returns:
            bool: True when every Comparator accepts rhs, False otherwise.
        """
        return all(matcher.equals(rhs) for matcher in self._comparators)

    def __repr__(self):
        # Wrap in a 1-tuple so the tuple of comparators is rendered whole.
        return '<AND %s>' % (self._comparators,)
1774
1775
class Or(Comparator):
    """Combines Comparators; matches when at least one of them matches."""

    def __init__(self, *args):
        """Store the Comparators to combine.

        Args:
            *args: One or more Mox comparators
        """
        self._comparators = args

    def equals(self, rhs):
        """Check rhs against each Comparator, stopping at the first success.

        Args:
            rhs: can be anything

        Returns:
            bool: True when any Comparator accepts rhs, False otherwise.
        """
        return any(matcher.equals(rhs) for matcher in self._comparators)

    def __repr__(self):
        # Wrap in a 1-tuple so the tuple of comparators is rendered whole.
        return '<OR %s>' % (self._comparators,)
1806
1807
class Func(Comparator):
    """Delegate validation of a parameter to a user-supplied callable.

    Use this when a parameter needs more advanced checks than the other
    comparators offer. The callable receives the parameter and should
    return either True or False.

    Example:

    def myParamValidator(param):
        # Advanced logic here
        return True

    mock_dao.DoSomething(Func(myParamValidator), True)
    """

    def __init__(self, func):
        """Remember the validating callable.

        Args:
            func: callable that takes one parameter and returns a bool
        """
        self._func = func

    def equals(self, rhs):
        """Run the stored callable on rhs and hand back its verdict.

        Args:
            rhs: any python object

        Returns:
            the result of func(rhs)
        """
        verdict = self._func(rhs)
        return verdict

    def __repr__(self):
        validator = self._func
        return str(validator)
1850
1851
class IgnoreArg(Comparator):
    """Comparator that accepts any argument.

    Use it as a placeholder for an argument of a method call whose value
    does not matter.

    Example:
    # Check if CastMagic is called with 3 as first arg and
    # 'disappear' as third.
    mymock.CastMagic(3, IgnoreArg(), 'disappear')
    """

    def equals(self, unused_rhs):
        """Accept the argument without looking at it.

        Args:
            unused_rhs: any python object

        Returns:
            True, unconditionally
        """
        return True

    def __repr__(self):
        return '<IgnoreArg>'
1877
1878
class Value(Comparator):
    """Compares argument against a remembered value.

    To be used in conjunction with Remember comparator. See Remember()
    for example.
    """

    def __init__(self):
        """Start out empty; nothing matches until a value is stored."""
        self._value = None
        self._has_value = False

    def store_value(self, rhs):
        """Remember rhs as the value later arguments are compared against.

        Args:
            rhs: any python object
        """
        self._value = rhs
        self._has_value = True

    def equals(self, rhs):
        """Compare rhs with the remembered value.

        Args:
            rhs: any python object

        Returns:
            False if no value has been stored yet, otherwise the result of
            rhs == the stored value.
        """
        if not self._has_value:
            return False
        return rhs == self._value

    def __repr__(self):
        if self._has_value:
            # Wrap the operand in a 1-tuple: a bare tuple value would be
            # unpacked by % as multiple format arguments and raise
            # TypeError (or lose its parentheses when it has one element).
            return "<Value %r>" % (self._value,)
        return "<Value>"
1905
1906
class Remember(Comparator):
    """Captures the argument it is compared with into a Value store.

    To be used in conjunction with Value comparator.

    Example:
    # Remember the argument for one method call.
    users_list = Value()
    mock_dao.ProcessUsers(Remember(users_list))

    # Check argument against remembered value.
    mock_dao.ReportUsers(users_list)
    """

    def __init__(self, value_store):
        """Bind this comparator to a value store.

        Args:
            value_store: the Value instance that will receive the argument

        Raises:
            TypeError: if value_store is not a Value instance.
        """
        if isinstance(value_store, Value):
            self._value_store = value_store
        else:
            raise TypeError(
                "value_store is not an instance of the Value class")

    def equals(self, rhs):
        """Record rhs in the value store; every argument is accepted.

        Args:
            rhs: any python object

        Returns:
            always returns True
        """
        store = self._value_store
        store.store_value(rhs)
        return True

    def __repr__(self):
        store_id = id(self._value_store)
        return "<Remember %d>" % store_id
1933
1934
class MethodGroup(object):
    """Common base for groups of expected method calls.

    Stores the group name; AddMethod, MethodCalled and IsSatisfied raise
    NotImplementedError here and are meant to be overridden.
    """

    def __init__(self, group_name):
        """Remember the name of this group.

        Args:
            group_name: the name identifying this group
        """
        self._group_name = group_name

    def group_name(self):
        """Return the name given at construction time."""
        return self._group_name

    def __str__(self):
        parts = (self.__class__.__name__, self._group_name)
        return '<%s "%s">' % parts

    def AddMethod(self, mock_method):
        """Add a method to the group; not implemented in the base class."""
        raise NotImplementedError()

    def MethodCalled(self, mock_method):
        """Handle a call to a method; not implemented in the base class."""
        raise NotImplementedError()

    def IsSatisfied(self):
        """Report whether the group is done; not implemented here."""
        raise NotImplementedError()
1955
1956
class UnorderedGroup(MethodGroup):
    """A group of expected method calls that may arrive in any order.

    This construct is helpful for non-deterministic events, such as iterating
    over the keys of a dict.
    """

    def __init__(self, group_name):
        """Create an empty group.

        Args:
            group_name: the name identifying this group
        """
        super(UnorderedGroup, self).__init__(group_name)
        # Expected calls that have not been made yet.
        self._methods = []

    def __str__(self):
        pending = "\n".join(map(str, self._methods))
        return '%s "%s" pending calls:\n%s' % (
            self.__class__.__name__, self._group_name, pending)

    def AddMethod(self, mock_method):
        """Append an expected call to this group.

        Args:
            mock_method: A mock method to be added to this group.
        """
        self._methods.append(mock_method)

    def MethodCalled(self, mock_method):
        """Consume the pending call that matches mock_method.

        Args:
            mock_method: a mock method that should be equal to a method in the
                         group.

        Returns:
            A (group, method) tuple: this group and the matching method
            that was recorded in it.

        Raises:
            UnexpectedMethodCallError if the mock_method was not in the group.
        """
        for expected in self._methods:
            if expected == mock_method:
                break
        else:
            # Nothing pending matches the call that was made.
            raise UnexpectedMethodCallError(mock_method, self)

        # Removal is keyed on the called mock_method rather than on the
        # recorded method: the called method will match any comparators when
        # equality is checked during removal, whereas the recorded method
        # could end up passing a comparator to another comparator.
        self._methods.remove(mock_method)

        # Calls are still pending, so put the group back at the head of the
        # queue.
        if not self.IsSatisfied():
            mock_method._call_queue.appendleft(self)

        return self, expected

    def IsSatisfied(self):
        """Return True once no pending methods remain in this group."""
        return not self._methods
2023
2024
class MultipleTimesGroup(MethodGroup):
    """A group of methods that may each be called any number of times.

    Note: Each method must be called at least once.

    This is helpful, if you don't know or care how many times a method is
    called.
    """

    def __init__(self, group_name):
        """Create an empty group.

        Args:
            group_name: the name identifying this group
        """
        super(MultipleTimesGroup, self).__init__(group_name)
        # Every method recorded in the group.
        self._methods = set()
        # Methods that have not been called yet.
        self._methods_left = set()

    def AddMethod(self, mock_method):
        """Record a method in this group and mark it as not yet called.

        Args:
            mock_method: A mock method to be added to this group.
        """
        for bucket in (self._methods, self._methods_left):
            bucket.add(mock_method)

    def MethodCalled(self, mock_method):
        """Register a call against this group.

        Args:
            mock_method: a mock method that should be equal to a method in the
                         group.

        Returns:
            (group, method) when mock_method matches a method of this group.
            Otherwise, if every method has already been called at least
            once, (next_method, None) where next_method is obtained from
            mock_method._PopNextMethod().

        Raises:
            UnexpectedMethodCallError if mock_method matches nothing in the
            group and some methods have not been called yet.
        """
        for expected in self._methods:
            if expected == mock_method:
                self._methods_left.discard(expected)
                # We never know when the caller is done with this group, so
                # it always goes back on top of the queue.
                mock_method._call_queue.appendleft(self)
                return self, expected

        if not self.IsSatisfied():
            raise UnexpectedMethodCallError(mock_method, self)

        return mock_method._PopNextMethod(), None

    def IsSatisfied(self):
        """Return True once every method in the group has been called."""
        return not self._methods_left
2085
2086
class MoxMetaTestBase(type):
    """Metaclass to add mox cleanup and verification to every test.

    As the mox unit testing class is being constructed (MoxTestBase or a
    subclass), this metaclass will modify all test functions to call the
    cleanup code after they finish. This means that unstubbing and verifying
    will happen for every test with no additional code, and any failures
    will result in test failures as opposed to errors.
    """

    def __init__(cls, name, bases, d):
        """Wrap every callable attribute whose name starts with 'test'.

        Args:
            name: name of the class being created
            bases: tuple of base classes
            d: namespace dict of the class being created; it is not modified
        """
        type.__init__(cls, name, bases, d)

        # Collect candidates in a private copy so the caller's namespace
        # dict is left untouched; injecting every dir(base) entry into it
        # would corrupt any later reuse of that dict.
        candidates = dict(d)

        # also get all the attributes from the base classes to account
        # for a case when test class is not the immediate child of MoxTestBase
        for base in bases:
            for attr_name in dir(base):
                if attr_name not in candidates:
                    candidates[attr_name] = getattr(base, attr_name)

        for func_name, func in candidates.items():
            if func_name.startswith('test') and callable(func):
                setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))

    @staticmethod
    def CleanUpTest(cls, func):
        """Adds Mox cleanup code to any MoxTestBase method.

        Always unsets stubs after a test. Will verify all mocks for tests that
        otherwise pass.

        Args:
            cls: MoxTestBase or subclass; the class whose method we are
                                          altering.
            func: method; the method of the MoxTestBase test class we wish to
                          alter.

        Returns:
            The modified method. It does not return func's result.
        """
        def new_method(self, *args, **kwargs):
            mox_obj = getattr(self, 'mox', None)
            stubout_obj = getattr(self, 'stubs', None)
            # Only clean up attributes that really are Mox / stubout objects.
            cleanup_mox = bool(mox_obj and isinstance(mox_obj, Mox))
            cleanup_stubout = bool(
                stubout_obj and
                isinstance(stubout_obj, stubout.StubOutForTesting))
            try:
                func(self, *args, **kwargs)
            finally:
                # Stubs are removed whether or not the test raised.
                if cleanup_mox:
                    mox_obj.UnsetStubs()
                if cleanup_stubout:
                    stubout_obj.UnsetAll()
                    stubout_obj.SmartUnsetAll()
            # Reached only when the test itself did not raise.
            if cleanup_mox:
                mox_obj.VerifyAll()
        new_method.__name__ = func.__name__
        new_method.__doc__ = func.__doc__
        new_method.__module__ = func.__module__
        return new_method
2152
2153
# Intermediate base class created by calling the metaclass directly, so that
# unittest.TestCase subclasses deriving from it are built by MoxMetaTestBase.
# NOTE(review): presumably done this way to avoid version-specific metaclass
# syntax - confirm.
_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {})
2155
2156
class MoxTestBase(_MoxTestBase):
    """Test case base class that removes mox/stubout boilerplate.

    Each test gets a "mox" attribute holding a Mox instance (any mox tests
    will want this) and a "stubs" attribute holding a StubOutForTesting
    instance (needed at times). Stubs are automatically unset and all mock
    methods are verified to have been called at the end of each test.
    """

    def setUp(self):
        """Create fresh Mox and StubOutForTesting objects for the test."""
        super(MoxTestBase, self).setUp()
        self.stubs = stubout.StubOutForTesting()
        self.mox = Mox()
2171