1# Copyright 2008 Google Inc. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# 15# This is a fork of the pymox library intended to work with Python 3. 16# The file was modified by quermit@gmail.com and dawid.fatyga@gmail.com 17 18"""Mox, an object-mocking framework for Python. 19 20Mox works in the record-replay-verify paradigm. When you first create 21a mock object, it is in record mode. You then programmatically set 22the expected behavior of the mock object (what methods are to be 23called on it, with what parameters, what they should return, and in 24what order). 25 26Once you have set up the expected mock behavior, you put it in replay 27mode. Now the mock responds to method calls just as you told it to. 28If an unexpected method (or an expected method with unexpected 29parameters) is called, then an exception will be raised. 30 31Once you are done interacting with the mock, you need to verify that 32all the expected interactions occured. (Maybe your code exited 33prematurely without calling some cleanup method!) The verify phase 34ensures that every expected method was called; otherwise, an exception 35will be raised. 36 37WARNING! Mock objects created by Mox are not thread-safe. If you are 38call a mock in multiple threads, it should be guarded by a mutex. 39 40TODO(stevepm): Add the option to make mocks thread-safe! 41 42Suggested usage / workflow: 43 44 # Create Mox factory 45 my_mox = Mox() 46 47 # Create a mock data access object 48 mock_dao = my_mox.CreateMock(DAOClass) 49 50 # Set up expected behavior 51 mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) 52 mock_dao.DeletePerson(person) 53 54 # Put mocks in replay mode 55 my_mox.ReplayAll() 56 57 # Inject mock object and run test 58 controller.SetDao(mock_dao) 59 controller.DeletePersonById('1') 60 61 # Verify all methods were called as expected 62 my_mox.VerifyAll() 63""" 64 65import collections 66import difflib 67import inspect 68import re 69import types 70import unittest 71 72from mox3 import stubout 73 74 75class Error(AssertionError): 76 """Base exception for this module.""" 77 78 pass 79 80 81class ExpectedMethodCallsError(Error): 82 """Raised when an expected method wasn't called. 83 84 This can occur if Verify() is called before all expected methods have been 85 called. 86 """ 87 88 def __init__(self, expected_methods): 89 """Init exception. 90 91 Args: 92 # expected_methods: A sequence of MockMethod objects that should 93 # have been called. 94 expected_methods: [MockMethod] 95 96 Raises: 97 ValueError: if expected_methods contains no methods. 98 """ 99 100 if not expected_methods: 101 raise ValueError("There must be at least one expected method") 102 Error.__init__(self) 103 self._expected_methods = expected_methods 104 105 def __str__(self): 106 calls = "\n".join(["%3d. %s" % (i, m) 107 for i, m in enumerate(self._expected_methods)]) 108 return "Verify: Expected methods never called:\n%s" % (calls,) 109 110 111class UnexpectedMethodCallError(Error): 112 """Raised when an unexpected method is called. 113 114 This can occur if a method is called with incorrect parameters, or out of 115 the specified order. 116 """ 117 118 def __init__(self, unexpected_method, expected): 119 """Init exception. 120 121 Args: 122 # unexpected_method: MockMethod that was called but was not at the 123 # head of the expected_method queue. 124 # expected: MockMethod or UnorderedGroup the method should have 125 # been in. 126 unexpected_method: MockMethod 127 expected: MockMethod or UnorderedGroup 128 """ 129 130 Error.__init__(self) 131 if expected is None: 132 self._str = "Unexpected method call %s" % (unexpected_method,) 133 else: 134 differ = difflib.Differ() 135 diff = differ.compare(str(unexpected_method).splitlines(True), 136 str(expected).splitlines(True)) 137 self._str = ("Unexpected method call." 138 " unexpected:- expected:+\n%s" 139 % ("\n".join(line.rstrip() for line in diff),)) 140 141 def __str__(self): 142 return self._str 143 144 145class UnknownMethodCallError(Error): 146 """Raised if an unknown method is requested of the mock object.""" 147 148 def __init__(self, unknown_method_name): 149 """Init exception. 150 151 Args: 152 # unknown_method_name: Method call that is not part of the mocked 153 # class's public interface. 154 unknown_method_name: str 155 """ 156 157 Error.__init__(self) 158 self._unknown_method_name = unknown_method_name 159 160 def __str__(self): 161 return ("Method called is not a member of the object: %s" % 162 self._unknown_method_name) 163 164 165class PrivateAttributeError(Error): 166 """Raised if a MockObject is passed a private additional attribute name.""" 167 168 def __init__(self, attr): 169 Error.__init__(self) 170 self._attr = attr 171 172 def __str__(self): 173 return ("Attribute '%s' is private and should not be available" 174 "in a mock object." % self._attr) 175 176 177class ExpectedMockCreationError(Error): 178 """Raised if mocks should have been created by StubOutClassWithMocks.""" 179 180 def __init__(self, expected_mocks): 181 """Init exception. 182 183 Args: 184 # expected_mocks: A sequence of MockObjects that should have been 185 # created 186 187 Raises: 188 ValueError: if expected_mocks contains no methods. 189 """ 190 191 if not expected_mocks: 192 raise ValueError("There must be at least one expected method") 193 Error.__init__(self) 194 self._expected_mocks = expected_mocks 195 196 def __str__(self): 197 mocks = "\n".join(["%3d. %s" % (i, m) 198 for i, m in enumerate(self._expected_mocks)]) 199 return "Verify: Expected mocks never created:\n%s" % (mocks,) 200 201 202class UnexpectedMockCreationError(Error): 203 """Raised if too many mocks were created by StubOutClassWithMocks.""" 204 205 def __init__(self, instance, *params, **named_params): 206 """Init exception. 207 208 Args: 209 # instance: the type of obejct that was created 210 # params: parameters given during instantiation 211 # named_params: named parameters given during instantiation 212 """ 213 214 Error.__init__(self) 215 self._instance = instance 216 self._params = params 217 self._named_params = named_params 218 219 def __str__(self): 220 args = ", ".join(["%s" % v for i, v in enumerate(self._params)]) 221 error = "Unexpected mock creation: %s(%s" % (self._instance, args) 222 223 if self._named_params: 224 error += ", " + ", ".join(["%s=%s" % (k, v) for k, v in 225 self._named_params.items()]) 226 227 error += ")" 228 return error 229 230 231class Mox(object): 232 """Mox: a factory for creating mock objects.""" 233 234 # A list of types that should be stubbed out with MockObjects (as 235 # opposed to MockAnythings). 236 _USE_MOCK_OBJECT = [types.FunctionType, types.ModuleType, types.MethodType] 237 238 def __init__(self): 239 """Initialize a new Mox.""" 240 241 self._mock_objects = [] 242 self.stubs = stubout.StubOutForTesting() 243 244 def CreateMock(self, class_to_mock, attrs=None, bounded_to=None): 245 """Create a new mock object. 246 247 Args: 248 # class_to_mock: the class to be mocked 249 class_to_mock: class 250 attrs: dict of attribute names to values that will be 251 set on the mock object. Only public attributes may be set. 252 bounded_to: optionally, when class_to_mock is not a class, 253 it points to a real class object, to which 254 attribute is bound 255 256 Returns: 257 MockObject that can be used as the class_to_mock would be. 258 """ 259 if attrs is None: 260 attrs = {} 261 new_mock = MockObject(class_to_mock, attrs=attrs, 262 class_to_bind=bounded_to) 263 self._mock_objects.append(new_mock) 264 return new_mock 265 266 def CreateMockAnything(self, description=None): 267 """Create a mock that will accept any method calls. 268 269 This does not enforce an interface. 270 271 Args: 272 description: str. Optionally, a descriptive name for the mock object 273 being created, for debugging output purposes. 274 """ 275 new_mock = MockAnything(description=description) 276 self._mock_objects.append(new_mock) 277 return new_mock 278 279 def ReplayAll(self): 280 """Set all mock objects to replay mode.""" 281 282 for mock_obj in self._mock_objects: 283 mock_obj._Replay() 284 285 def VerifyAll(self): 286 """Call verify on all mock objects created.""" 287 288 for mock_obj in self._mock_objects: 289 mock_obj._Verify() 290 291 def ResetAll(self): 292 """Call reset on all mock objects. This does not unset stubs.""" 293 294 for mock_obj in self._mock_objects: 295 mock_obj._Reset() 296 297 def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): 298 """Replace a method, attribute, etc. with a Mock. 299 300 This will replace a class or module with a MockObject, and everything 301 else (method, function, etc) with a MockAnything. This can be 302 overridden to always use a MockAnything by setting use_mock_anything 303 to True. 304 305 Args: 306 obj: A Python object (class, module, instance, callable). 307 attr_name: str. The name of the attribute to replace with a mock. 308 use_mock_anything: bool. True if a MockAnything should be used 309 regardless of the type of attribute. 310 """ 311 312 if inspect.isclass(obj): 313 class_to_bind = obj 314 else: 315 class_to_bind = None 316 317 attr_to_replace = getattr(obj, attr_name) 318 attr_type = type(attr_to_replace) 319 320 if attr_type == MockAnything or attr_type == MockObject: 321 raise TypeError('Cannot mock a MockAnything! Did you remember to ' 322 'call UnsetStubs in your previous test?') 323 324 type_check = ( 325 attr_type in self._USE_MOCK_OBJECT or 326 inspect.isclass(attr_to_replace) or 327 isinstance(attr_to_replace, object)) 328 if type_check and not use_mock_anything: 329 stub = self.CreateMock(attr_to_replace, bounded_to=class_to_bind) 330 else: 331 stub = self.CreateMockAnything( 332 description='Stub for %s' % attr_to_replace) 333 stub.__name__ = attr_name 334 335 self.stubs.Set(obj, attr_name, stub) 336 337 def StubOutClassWithMocks(self, obj, attr_name): 338 """Replace a class with a "mock factory" that will create mock objects. 339 340 This is useful if the code-under-test directly instantiates 341 dependencies. Previously some boilder plate was necessary to 342 create a mock that would act as a factory. Using 343 StubOutClassWithMocks, once you've stubbed out the class you may 344 use the stubbed class as you would any other mock created by mox: 345 during the record phase, new mock instances will be created, and 346 during replay, the recorded mocks will be returned. 347 348 In replay mode 349 350 # Example using StubOutWithMock (the old, clunky way): 351 352 mock1 = mox.CreateMock(my_import.FooClass) 353 mock2 = mox.CreateMock(my_import.FooClass) 354 foo_factory = mox.StubOutWithMock(my_import, 'FooClass', 355 use_mock_anything=True) 356 foo_factory(1, 2).AndReturn(mock1) 357 foo_factory(9, 10).AndReturn(mock2) 358 mox.ReplayAll() 359 360 my_import.FooClass(1, 2) # Returns mock1 again. 361 my_import.FooClass(9, 10) # Returns mock2 again. 362 mox.VerifyAll() 363 364 # Example using StubOutClassWithMocks: 365 366 mox.StubOutClassWithMocks(my_import, 'FooClass') 367 mock1 = my_import.FooClass(1, 2) # Returns a new mock of FooClass 368 mock2 = my_import.FooClass(9, 10) # Returns another mock instance 369 mox.ReplayAll() 370 371 my_import.FooClass(1, 2) # Returns mock1 again. 372 my_import.FooClass(9, 10) # Returns mock2 again. 373 mox.VerifyAll() 374 """ 375 attr_to_replace = getattr(obj, attr_name) 376 attr_type = type(attr_to_replace) 377 378 if attr_type == MockAnything or attr_type == MockObject: 379 raise TypeError('Cannot mock a MockAnything! Did you remember to ' 380 'call UnsetStubs in your previous test?') 381 382 if not inspect.isclass(attr_to_replace): 383 raise TypeError('Given attr is not a Class. Use StubOutWithMock.') 384 385 factory = _MockObjectFactory(attr_to_replace, self) 386 self._mock_objects.append(factory) 387 self.stubs.Set(obj, attr_name, factory) 388 389 def UnsetStubs(self): 390 """Restore stubs to their original state.""" 391 392 self.stubs.UnsetAll() 393 394 395def Replay(*args): 396 """Put mocks into Replay mode. 397 398 Args: 399 # args is any number of mocks to put into replay mode. 400 """ 401 402 for mock in args: 403 mock._Replay() 404 405 406def Verify(*args): 407 """Verify mocks. 408 409 Args: 410 # args is any number of mocks to be verified. 411 """ 412 413 for mock in args: 414 mock._Verify() 415 416 417def Reset(*args): 418 """Reset mocks. 419 420 Args: 421 # args is any number of mocks to be reset. 422 """ 423 424 for mock in args: 425 mock._Reset() 426 427 428class MockAnything(object): 429 """A mock that can be used to mock anything. 430 431 This is helpful for mocking classes that do not provide a public interface. 432 """ 433 434 def __init__(self, description=None): 435 """Initialize a new MockAnything. 436 437 Args: 438 description: str. Optionally, a descriptive name for the mock 439 object being created, for debugging output purposes. 440 """ 441 self._description = description 442 self._Reset() 443 444 def __repr__(self): 445 if self._description: 446 return '<MockAnything instance of %s>' % self._description 447 else: 448 return '<MockAnything instance>' 449 450 def __getattr__(self, method_name): 451 """Intercept method calls on this object. 452 453 A new MockMethod is returned that is aware of the MockAnything's 454 state (record or replay). The call will be recorded or replayed 455 by the MockMethod's __call__. 456 457 Args: 458 # method name: the name of the method being called. 459 method_name: str 460 461 Returns: 462 A new MockMethod aware of MockAnything's state (record or replay). 463 """ 464 if method_name == '__dir__': 465 return self.__class__.__dir__.__get__(self, self.__class__) 466 467 return self._CreateMockMethod(method_name) 468 469 def __str__(self): 470 return self._CreateMockMethod('__str__')() 471 472 def __call__(self, *args, **kwargs): 473 return self._CreateMockMethod('__call__')(*args, **kwargs) 474 475 def __getitem__(self, i): 476 return self._CreateMockMethod('__getitem__')(i) 477 478 def _CreateMockMethod(self, method_name, method_to_mock=None, 479 class_to_bind=object): 480 """Create a new mock method call and return it. 481 482 Args: 483 # method_name: the name of the method being called. 484 # method_to_mock: The actual method being mocked, used for 485 # introspection. 486 # class_to_bind: Class to which method is bounded 487 # (object by default) 488 method_name: str 489 method_to_mock: a method object 490 491 Returns: 492 A new MockMethod aware of MockAnything's state (record or replay). 493 """ 494 495 return MockMethod(method_name, self._expected_calls_queue, 496 self._replay_mode, method_to_mock=method_to_mock, 497 description=self._description, 498 class_to_bind=class_to_bind) 499 500 def __nonzero__(self): 501 """Return 1 for nonzero so the mock can be used as a conditional.""" 502 503 return 1 504 505 def __bool__(self): 506 """Return True for nonzero so the mock can be used as a conditional.""" 507 return True 508 509 def __eq__(self, rhs): 510 """Provide custom logic to compare objects.""" 511 512 return (isinstance(rhs, MockAnything) and 513 self._replay_mode == rhs._replay_mode and 514 self._expected_calls_queue == rhs._expected_calls_queue) 515 516 def __ne__(self, rhs): 517 """Provide custom logic to compare objects.""" 518 519 return not self == rhs 520 521 def _Replay(self): 522 """Start replaying expected method calls.""" 523 524 self._replay_mode = True 525 526 def _Verify(self): 527 """Verify that all of the expected calls have been made. 528 529 Raises: 530 ExpectedMethodCallsError: if there are still more method calls in 531 the expected queue. 532 """ 533 534 # If the list of expected calls is not empty, raise an exception 535 if self._expected_calls_queue: 536 # The last MultipleTimesGroup is not popped from the queue. 537 if (len(self._expected_calls_queue) == 1 and 538 isinstance(self._expected_calls_queue[0], 539 MultipleTimesGroup) and 540 self._expected_calls_queue[0].IsSatisfied()): 541 pass 542 else: 543 raise ExpectedMethodCallsError(self._expected_calls_queue) 544 545 def _Reset(self): 546 """Reset the state of this mock to record mode with an empty queue.""" 547 548 # Maintain a list of method calls we are expecting 549 self._expected_calls_queue = collections.deque() 550 551 # Make sure we are in setup mode, not replay mode 552 self._replay_mode = False 553 554 555class MockObject(MockAnything): 556 """Mock object that simulates the public/protected interface of a class.""" 557 558 def __init__(self, class_to_mock, attrs=None, class_to_bind=None): 559 """Initialize a mock object. 560 561 Determines the methods and properties of the class and stores them. 562 563 Args: 564 # class_to_mock: class to be mocked 565 class_to_mock: class 566 attrs: dict of attribute names to values that will be set on the 567 mock object. Only public attributes may be set. 568 class_to_bind: optionally, when class_to_mock is not a class at 569 all, it points to a real class 570 571 Raises: 572 PrivateAttributeError: if a supplied attribute is not public. 573 ValueError: if an attribute would mask an existing method. 574 """ 575 if attrs is None: 576 attrs = {} 577 578 # Used to hack around the mixin/inheritance of MockAnything, which 579 # is not a proper object (it can be anything. :-) 580 MockAnything.__dict__['__init__'](self) 581 582 # Get a list of all the public and special methods we should mock. 583 self._known_methods = set() 584 self._known_vars = set() 585 self._class_to_mock = class_to_mock 586 587 if inspect.isclass(class_to_mock): 588 self._class_to_bind = self._class_to_mock 589 else: 590 self._class_to_bind = class_to_bind 591 592 try: 593 if inspect.isclass(self._class_to_mock): 594 self._description = class_to_mock.__name__ 595 else: 596 self._description = type(class_to_mock).__name__ 597 except Exception: 598 pass 599 600 for method in dir(class_to_mock): 601 attr = getattr(class_to_mock, method) 602 if callable(attr): 603 self._known_methods.add(method) 604 elif not (type(attr) is property): 605 # treating properties as class vars makes little sense. 606 self._known_vars.add(method) 607 608 # Set additional attributes at instantiation time; this is quicker 609 # than manually setting attributes that are normally created in 610 # __init__. 611 for attr, value in attrs.items(): 612 if attr.startswith("_"): 613 raise PrivateAttributeError(attr) 614 elif attr in self._known_methods: 615 raise ValueError("'%s' is a method of '%s' objects." % (attr, 616 class_to_mock)) 617 else: 618 setattr(self, attr, value) 619 620 def _CreateMockMethod(self, *args, **kwargs): 621 """Overridden to provide self._class_to_mock to class_to_bind.""" 622 kwargs.setdefault("class_to_bind", self._class_to_bind) 623 return super(MockObject, self)._CreateMockMethod(*args, **kwargs) 624 625 def __getattr__(self, name): 626 """Intercept attribute request on this object. 627 628 If the attribute is a public class variable, it will be returned and 629 not recorded as a call. 630 631 If the attribute is not a variable, it is handled like a method 632 call. The method name is checked against the set of mockable 633 methods, and a new MockMethod is returned that is aware of the 634 MockObject's state (record or replay). The call will be recorded 635 or replayed by the MockMethod's __call__. 636 637 Args: 638 # name: the name of the attribute being requested. 639 name: str 640 641 Returns: 642 Either a class variable or a new MockMethod that is aware of the 643 state of the mock (record or replay). 644 645 Raises: 646 UnknownMethodCallError if the MockObject does not mock the 647 requested method. 648 """ 649 650 if name in self._known_vars: 651 return getattr(self._class_to_mock, name) 652 653 if name in self._known_methods: 654 return self._CreateMockMethod( 655 name, 656 method_to_mock=getattr(self._class_to_mock, name)) 657 658 raise UnknownMethodCallError(name) 659 660 def __eq__(self, rhs): 661 """Provide custom logic to compare objects.""" 662 663 return (isinstance(rhs, MockObject) and 664 self._class_to_mock == rhs._class_to_mock and 665 self._replay_mode == rhs._replay_mode and 666 self._expected_calls_queue == rhs._expected_calls_queue) 667 668 def __setitem__(self, key, value): 669 """Custom logic for mocking classes that support item assignment. 670 671 Args: 672 key: Key to set the value for. 673 value: Value to set. 674 675 Returns: 676 Expected return value in replay mode. A MockMethod object for the 677 __setitem__ method that has already been called if not in replay 678 mode. 679 680 Raises: 681 TypeError if the underlying class does not support item assignment. 682 UnexpectedMethodCallError if the object does not expect the call to 683 __setitem__. 684 685 """ 686 # Verify the class supports item assignment. 687 if '__setitem__' not in dir(self._class_to_mock): 688 raise TypeError('object does not support item assignment') 689 690 # If we are in replay mode then simply call the mock __setitem__ method 691 if self._replay_mode: 692 return MockMethod('__setitem__', self._expected_calls_queue, 693 self._replay_mode)(key, value) 694 695 # Otherwise, create a mock method __setitem__. 696 return self._CreateMockMethod('__setitem__')(key, value) 697 698 def __getitem__(self, key): 699 """Provide custom logic for mocking classes that are subscriptable. 700 701 Args: 702 key: Key to return the value for. 703 704 Returns: 705 Expected return value in replay mode. A MockMethod object for the 706 __getitem__ method that has already been called if not in replay 707 mode. 708 709 Raises: 710 TypeError if the underlying class is not subscriptable. 711 UnexpectedMethodCallError if the object does not expect the call to 712 __getitem__. 713 714 """ 715 # Verify the class supports item assignment. 716 if '__getitem__' not in dir(self._class_to_mock): 717 raise TypeError('unsubscriptable object') 718 719 # If we are in replay mode then simply call the mock __getitem__ method 720 if self._replay_mode: 721 return MockMethod('__getitem__', self._expected_calls_queue, 722 self._replay_mode)(key) 723 724 # Otherwise, create a mock method __getitem__. 725 return self._CreateMockMethod('__getitem__')(key) 726 727 def __iter__(self): 728 """Provide custom logic for mocking classes that are iterable. 729 730 Returns: 731 Expected return value in replay mode. A MockMethod object for the 732 __iter__ method that has already been called if not in replay mode. 733 734 Raises: 735 TypeError if the underlying class is not iterable. 736 UnexpectedMethodCallError if the object does not expect the call to 737 __iter__. 738 739 """ 740 methods = dir(self._class_to_mock) 741 742 # Verify the class supports iteration. 743 if '__iter__' not in methods: 744 # If it doesn't have iter method and we are in replay method, 745 # then try to iterate using subscripts. 746 if '__getitem__' not in methods or not self._replay_mode: 747 raise TypeError('not iterable object') 748 else: 749 results = [] 750 index = 0 751 try: 752 while True: 753 results.append(self[index]) 754 index += 1 755 except IndexError: 756 return iter(results) 757 758 # If we are in replay mode then simply call the mock __iter__ method. 759 if self._replay_mode: 760 return MockMethod('__iter__', self._expected_calls_queue, 761 self._replay_mode)() 762 763 # Otherwise, create a mock method __iter__. 764 return self._CreateMockMethod('__iter__')() 765 766 def __contains__(self, key): 767 """Provide custom logic for mocking classes that contain items. 768 769 Args: 770 key: Key to look in container for. 771 772 Returns: 773 Expected return value in replay mode. A MockMethod object for the 774 __contains__ method that has already been called if not in replay 775 mode. 776 777 Raises: 778 TypeError if the underlying class does not implement __contains__ 779 UnexpectedMethodCaller if the object does not expect the call to 780 __contains__. 781 782 """ 783 contains = self._class_to_mock.__dict__.get('__contains__', None) 784 785 if contains is None: 786 raise TypeError('unsubscriptable object') 787 788 if self._replay_mode: 789 return MockMethod('__contains__', self._expected_calls_queue, 790 self._replay_mode)(key) 791 792 return self._CreateMockMethod('__contains__')(key) 793 794 def __call__(self, *params, **named_params): 795 """Provide custom logic for mocking classes that are callable.""" 796 797 # Verify the class we are mocking is callable. 798 is_callable = hasattr(self._class_to_mock, '__call__') 799 if not is_callable: 800 raise TypeError('Not callable') 801 802 # Because the call is happening directly on this object instead of 803 # a method, the call on the mock method is made right here 804 805 # If we are mocking a Function, then use the function, and not the 806 # __call__ method 807 method = None 808 if type(self._class_to_mock) in (types.FunctionType, types.MethodType): 809 method = self._class_to_mock 810 else: 811 method = getattr(self._class_to_mock, '__call__') 812 mock_method = self._CreateMockMethod('__call__', method_to_mock=method) 813 814 return mock_method(*params, **named_params) 815 816 @property 817 def __name__(self): 818 """Return the name that is being mocked.""" 819 return self._description 820 821 # TODO(dejw): this property stopped to work after I introduced changes with 822 # binding classes. Fortunately I found a solution in the form of 823 # __getattribute__ method below, but this issue should be investigated 824 @property 825 def __class__(self): 826 return self._class_to_mock 827 828 def __dir__(self): 829 """Return only attributes of a class to mock.""" 830 return dir(self._class_to_mock) 831 832 def __getattribute__(self, name): 833 """Return _class_to_mock on __class__ attribute.""" 834 if name == "__class__": 835 return super(MockObject, self).__getattribute__("_class_to_mock") 836 837 return super(MockObject, self).__getattribute__(name) 838 839 840class _MockObjectFactory(MockObject): 841 """A MockObjectFactory creates mocks and verifies __init__ params. 842 843 A MockObjectFactory removes the boiler plate code that was previously 844 necessary to stub out direction instantiation of a class. 845 846 The MockObjectFactory creates new MockObjects when called and verifies the 847 __init__ params are correct when in record mode. When replaying, 848 existing mocks are returned, and the __init__ params are verified. 849 850 See StubOutWithMock vs StubOutClassWithMocks for more detail. 851 """ 852 853 def __init__(self, class_to_mock, mox_instance): 854 MockObject.__init__(self, class_to_mock) 855 self._mox = mox_instance 856 self._instance_queue = collections.deque() 857 858 def __call__(self, *params, **named_params): 859 """Instantiate and record that a new mock has been created.""" 860 861 method = getattr(self._class_to_mock, '__init__') 862 mock_method = self._CreateMockMethod('__init__', method_to_mock=method) 863 # Note: calling mock_method() is deferred in order to catch the 864 # empty instance_queue first. 865 866 if self._replay_mode: 867 if not self._instance_queue: 868 raise UnexpectedMockCreationError(self._class_to_mock, *params, 869 **named_params) 870 871 mock_method(*params, **named_params) 872 873 return self._instance_queue.pop() 874 else: 875 mock_method(*params, **named_params) 876 877 instance = self._mox.CreateMock(self._class_to_mock) 878 self._instance_queue.appendleft(instance) 879 return instance 880 881 def _Verify(self): 882 """Verify that all mocks have been created.""" 883 if self._instance_queue: 884 raise ExpectedMockCreationError(self._instance_queue) 885 super(_MockObjectFactory, self)._Verify() 886 887 888class MethodSignatureChecker(object): 889 """Ensures that methods are called correctly.""" 890 891 _NEEDED, _DEFAULT, _GIVEN = range(3) 892 893 def __init__(self, method, class_to_bind=None): 894 """Creates a checker. 895 896 Args: 897 # method: A method to check. 898 # class_to_bind: optionally, a class used to type check first 899 # method parameter, only used with unbound methods 900 method: function 901 class_to_bind: type or None 902 903 Raises: 904 ValueError: method could not be inspected, so checks aren't 905 possible. Some methods and functions like built-ins 906 can't be inspected. 907 """ 908 try: 909 self._args, varargs, varkw, defaults = inspect.getargspec(method) 910 except TypeError: 911 raise ValueError('Could not get argument specification for %r' 912 % (method,)) 913 if (inspect.ismethod(method) or class_to_bind or ( 914 hasattr(self, '_args') and len(self._args) > 0 and 915 self._args[0] == 'self')): 916 self._args = self._args[1:] # Skip 'self'. 917 self._method = method 918 self._instance = None # May contain the instance this is bound to. 919 self._instance = getattr(method, "__self__", None) 920 921 # _bounded_to determines whether the method is bound or not 922 if self._instance: 923 self._bounded_to = self._instance.__class__ 924 else: 925 self._bounded_to = class_to_bind or getattr(method, "im_class", 926 None) 927 928 self._has_varargs = varargs is not None 929 self._has_varkw = varkw is not None 930 if defaults is None: 931 self._required_args = self._args 932 self._default_args = [] 933 else: 934 self._required_args = self._args[:-len(defaults)] 935 self._default_args = self._args[-len(defaults):] 936 937 def _RecordArgumentGiven(self, arg_name, arg_status): 938 """Mark an argument as being given. 939 940 Args: 941 # arg_name: The name of the argument to mark in arg_status. 942 # arg_status: Maps argument names to one of 943 # _NEEDED, _DEFAULT, _GIVEN. 944 arg_name: string 945 arg_status: dict 946 947 Raises: 948 AttributeError: arg_name is already marked as _GIVEN. 949 """ 950 if arg_status.get(arg_name, None) == MethodSignatureChecker._GIVEN: 951 raise AttributeError('%s provided more than once' % (arg_name,)) 952 arg_status[arg_name] = MethodSignatureChecker._GIVEN 953 954 def Check(self, params, named_params): 955 """Ensures that the parameters used while recording a call are valid. 956 957 Args: 958 # params: A list of positional parameters. 959 # named_params: A dict of named parameters. 960 params: list 961 named_params: dict 962 963 Raises: 964 AttributeError: the given parameters don't work with the given 965 method. 966 """ 967 arg_status = dict((a, MethodSignatureChecker._NEEDED) 968 for a in self._required_args) 969 for arg in self._default_args: 970 arg_status[arg] = MethodSignatureChecker._DEFAULT 971 972 # WARNING: Suspect hack ahead. 973 # 974 # Check to see if this is an unbound method, where the instance 975 # should be bound as the first argument. We try to determine if 976 # the first argument (param[0]) is an instance of the class, or it 977 # is equivalent to the class (used to account for Comparators). 978 # 979 # NOTE: If a Func() comparator is used, and the signature is not 980 # correct, this will cause extra executions of the function. 981 if inspect.ismethod(self._method) or self._bounded_to: 982 # The extra param accounts for the bound instance. 983 if len(params) > len(self._required_args): 984 expected = self._bounded_to 985 986 # Check if the param is an instance of the expected class, 987 # or check equality (useful for checking Comparators). 988 989 # This is a hack to work around the fact that the first 990 # parameter can be a Comparator, and the comparison may raise 991 # an exception during this comparison, which is OK. 992 try: 993 param_equality = (params[0] == expected) 994 except Exception: 995 param_equality = False 996 997 if isinstance(params[0], expected) or param_equality: 998 params = params[1:] 999 # If the IsA() comparator is being used, we need to check the 1000 # inverse of the usual case - that the given instance is a 1001 # subclass of the expected class. For example, the code under 1002 # test does late binding to a subclass. 1003 elif (isinstance(params[0], IsA) and 1004 params[0]._IsSubClass(expected)): 1005 params = params[1:] 1006 1007 # Check that each positional param is valid. 1008 for i in range(len(params)): 1009 try: 1010 arg_name = self._args[i] 1011 except IndexError: 1012 if not self._has_varargs: 1013 raise AttributeError( 1014 '%s does not take %d or more positional ' 1015 'arguments' % (self._method.__name__, i)) 1016 else: 1017 self._RecordArgumentGiven(arg_name, arg_status) 1018 1019 # Check each keyword argument. 1020 for arg_name in named_params: 1021 if arg_name not in arg_status and not self._has_varkw: 1022 raise AttributeError('%s is not expecting keyword argument %s' 1023 % (self._method.__name__, arg_name)) 1024 self._RecordArgumentGiven(arg_name, arg_status) 1025 1026 # Ensure all the required arguments have been given. 1027 still_needed = [k for k, v in arg_status.items() 1028 if v == MethodSignatureChecker._NEEDED] 1029 if still_needed: 1030 raise AttributeError('No values given for arguments: %s' 1031 % (' '.join(sorted(still_needed)))) 1032 1033 1034class MockMethod(object): 1035 """Callable mock method. 1036 1037 A MockMethod should act exactly like the method it mocks, accepting 1038 parameters and returning a value, or throwing an exception (as specified). 1039 When this method is called, it can optionally verify whether the called 1040 method (name and signature) matches the expected method. 1041 """ 1042 1043 def __init__(self, method_name, call_queue, replay_mode, 1044 method_to_mock=None, description=None, class_to_bind=None): 1045 """Construct a new mock method. 1046 1047 Args: 1048 # method_name: the name of the method 1049 # call_queue: deque of calls, verify this call against the head, 1050 # or add this call to the queue. 1051 # replay_mode: False if we are recording, True if we are verifying 1052 # calls against the call queue. 1053 # method_to_mock: The actual method being mocked, used for 1054 # introspection. 1055 # description: optionally, a descriptive name for this method. 1056 # Typically this is equal to the descriptive name of 1057 # the method's class. 1058 # class_to_bind: optionally, a class that is used for unbound 1059 # methods (or functions in Python3) to which method 1060 # is bound, in order not to loose binding 1061 # information. If given, it will be used for 1062 # checking the type of first method parameter 1063 method_name: str 1064 call_queue: list or deque 1065 replay_mode: bool 1066 method_to_mock: a method object 1067 description: str or None 1068 class_to_bind: type or None 1069 """ 1070 1071 self._name = method_name 1072 self.__name__ = method_name 1073 self._call_queue = call_queue 1074 if not isinstance(call_queue, collections.deque): 1075 self._call_queue = collections.deque(self._call_queue) 1076 self._replay_mode = replay_mode 1077 self._description = description 1078 1079 self._params = None 1080 self._named_params = None 1081 self._return_value = None 1082 self._exception = None 1083 self._side_effects = None 1084 1085 try: 1086 self._checker = MethodSignatureChecker(method_to_mock, 1087 class_to_bind=class_to_bind) 1088 except ValueError: 1089 self._checker = None 1090 1091 def __call__(self, *params, **named_params): 1092 """Log parameters and return the specified return value. 1093 1094 If the Mock(Anything/Object) associated with this call is in record 1095 mode, this MockMethod will be pushed onto the expected call queue. 1096 If the mock is in replay mode, this will pop a MockMethod off the 1097 top of the queue and verify this call is equal to the expected call. 1098 1099 Raises: 1100 UnexpectedMethodCall if this call is supposed to match an expected 1101 method call and it does not. 1102 """ 1103 1104 self._params = params 1105 self._named_params = named_params 1106 1107 if not self._replay_mode: 1108 if self._checker is not None: 1109 self._checker.Check(params, named_params) 1110 self._call_queue.append(self) 1111 return self 1112 1113 expected_method = self._VerifyMethodCall() 1114 1115 if expected_method._side_effects: 1116 result = expected_method._side_effects(*params, **named_params) 1117 if expected_method._return_value is None: 1118 expected_method._return_value = result 1119 1120 if expected_method._exception: 1121 raise expected_method._exception 1122 1123 return expected_method._return_value 1124 1125 def __getattr__(self, name): 1126 """Raise an AttributeError with a helpful message.""" 1127 1128 raise AttributeError( 1129 'MockMethod has no attribute "%s". ' 1130 'Did you remember to put your mocks in replay mode?' % name) 1131 1132 def __iter__(self): 1133 """Raise a TypeError with a helpful message.""" 1134 raise TypeError( 1135 'MockMethod cannot be iterated. ' 1136 'Did you remember to put your mocks in replay mode?') 1137 1138 def next(self): 1139 """Raise a TypeError with a helpful message.""" 1140 raise TypeError( 1141 'MockMethod cannot be iterated. ' 1142 'Did you remember to put your mocks in replay mode?') 1143 1144 def __next__(self): 1145 """Raise a TypeError with a helpful message.""" 1146 raise TypeError( 1147 'MockMethod cannot be iterated. ' 1148 'Did you remember to put your mocks in replay mode?') 1149 1150 def _PopNextMethod(self): 1151 """Pop the next method from our call queue.""" 1152 try: 1153 return self._call_queue.popleft() 1154 except IndexError: 1155 raise UnexpectedMethodCallError(self, None) 1156 1157 def _VerifyMethodCall(self): 1158 """Verify the called method is expected. 1159 1160 This can be an ordered method, or part of an unordered set. 1161 1162 Returns: 1163 The expected mock method. 1164 1165 Raises: 1166 UnexpectedMethodCall if the method called was not expected. 1167 """ 1168 1169 expected = self._PopNextMethod() 1170 1171 # Loop here, because we might have a MethodGroup followed by another 1172 # group. 1173 while isinstance(expected, MethodGroup): 1174 expected, method = expected.MethodCalled(self) 1175 if method is not None: 1176 return method 1177 1178 # This is a mock method, so just check equality. 1179 if expected != self: 1180 raise UnexpectedMethodCallError(self, expected) 1181 1182 return expected 1183 1184 def __str__(self): 1185 params = ', '.join( 1186 [repr(p) for p in self._params or []] + 1187 ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) 1188 full_desc = "%s(%s) -> %r" % (self._name, params, self._return_value) 1189 if self._description: 1190 full_desc = "%s.%s" % (self._description, full_desc) 1191 return full_desc 1192 1193 def __hash__(self): 1194 return id(self) 1195 1196 def __eq__(self, rhs): 1197 """Test whether this MockMethod is equivalent to another MockMethod. 1198 1199 Args: 1200 # rhs: the right hand side of the test 1201 rhs: MockMethod 1202 """ 1203 1204 return (isinstance(rhs, MockMethod) and 1205 self._name == rhs._name and 1206 self._params == rhs._params and 1207 self._named_params == rhs._named_params) 1208 1209 def __ne__(self, rhs): 1210 """Test if this MockMethod is not equivalent to another MockMethod. 1211 1212 Args: 1213 # rhs: the right hand side of the test 1214 rhs: MockMethod 1215 """ 1216 1217 return not self == rhs 1218 1219 def GetPossibleGroup(self): 1220 """Returns a possible group from the end of the call queue. 1221 1222 Return None if no other methods are on the stack. 1223 """ 1224 1225 # Remove this method from the tail of the queue so we can add it 1226 # to a group. 1227 this_method = self._call_queue.pop() 1228 assert this_method == self 1229 1230 # Determine if the tail of the queue is a group, or just a regular 1231 # ordered mock method. 1232 group = None 1233 try: 1234 group = self._call_queue[-1] 1235 except IndexError: 1236 pass 1237 1238 return group 1239 1240 def _CheckAndCreateNewGroup(self, group_name, group_class): 1241 """Checks if the last method (a possible group) is an instance of our 1242 group_class. Adds the current method to this group or creates a 1243 new one. 1244 1245 Args: 1246 1247 group_name: the name of the group. 1248 group_class: the class used to create instance of this new group 1249 """ 1250 group = self.GetPossibleGroup() 1251 1252 # If this is a group, and it is the correct group, add the method. 1253 if isinstance(group, group_class) and group.group_name() == group_name: 1254 group.AddMethod(self) 1255 return self 1256 1257 # Create a new group and add the method. 1258 new_group = group_class(group_name) 1259 new_group.AddMethod(self) 1260 self._call_queue.append(new_group) 1261 return self 1262 1263 def InAnyOrder(self, group_name="default"): 1264 """Move this method into a group of unordered calls. 1265 1266 A group of unordered calls must be defined together, and must be 1267 executed in full before the next expected method can be called. 1268 There can be multiple groups that are expected serially, if they are 1269 given different group names. The same group name can be reused if there 1270 is a standard method call, or a group with a different name, spliced 1271 between usages. 1272 1273 Args: 1274 group_name: the name of the unordered group. 1275 1276 Returns: 1277 self 1278 """ 1279 return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) 1280 1281 def MultipleTimes(self, group_name="default"): 1282 """Move method into group of calls which may be called multiple times. 1283 1284 A group of repeating calls must be defined together, and must be 1285 executed in full before the next expected method can be called. 1286 1287 Args: 1288 group_name: the name of the unordered group. 1289 1290 Returns: 1291 self 1292 """ 1293 return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) 1294 1295 def AndReturn(self, return_value): 1296 """Set the value to return when this method is called. 1297 1298 Args: 1299 # return_value can be anything. 1300 """ 1301 1302 self._return_value = return_value 1303 return return_value 1304 1305 def AndRaise(self, exception): 1306 """Set the exception to raise when this method is called. 1307 1308 Args: 1309 # exception: the exception to raise when this method is called. 1310 exception: Exception 1311 """ 1312 1313 self._exception = exception 1314 1315 def WithSideEffects(self, side_effects): 1316 """Set the side effects that are simulated when this method is called. 1317 1318 Args: 1319 side_effects: A callable which modifies the parameters or other 1320 relevant state which a given test case depends on. 1321 1322 Returns: 1323 Self for chaining with AndReturn and AndRaise. 1324 """ 1325 self._side_effects = side_effects 1326 return self 1327 1328 1329class Comparator: 1330 """Base class for all Mox comparators. 1331 1332 A Comparator can be used as a parameter to a mocked method when the exact 1333 value is not known. For example, the code you are testing might build up 1334 a long SQL string that is passed to your mock DAO. You're only interested 1335 that the IN clause contains the proper primary keys, so you can set your 1336 mock up as follows: 1337 1338 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 1339 1340 Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. 1341 1342 A Comparator may replace one or more parameters, for example: 1343 # return at most 10 rows 1344 mock_dao.RunQuery(StrContains('SELECT'), 10) 1345 1346 or 1347 1348 # Return some non-deterministic number of rows 1349 mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) 1350 """ 1351 1352 def equals(self, rhs): 1353 """Special equals method that all comparators must implement. 1354 1355 Args: 1356 rhs: any python object 1357 """ 1358 1359 raise NotImplementedError('method must be implemented by a subclass.') 1360 1361 def __eq__(self, rhs): 1362 return self.equals(rhs) 1363 1364 def __ne__(self, rhs): 1365 return not self.equals(rhs) 1366 1367 1368class Is(Comparator): 1369 """Comparison class used to check identity, instead of equality.""" 1370 1371 def __init__(self, obj): 1372 self._obj = obj 1373 1374 def equals(self, rhs): 1375 return rhs is self._obj 1376 1377 def __repr__(self): 1378 return "<is %r (%s)>" % (self._obj, id(self._obj)) 1379 1380 1381class IsA(Comparator): 1382 """This class wraps a basic Python type or class. It is used to verify 1383 that a parameter is of the given type or class. 1384 1385 Example: 1386 mock_dao.Connect(IsA(DbConnectInfo)) 1387 """ 1388 1389 def __init__(self, class_name): 1390 """Initialize IsA 1391 1392 Args: 1393 class_name: basic python type or a class 1394 """ 1395 1396 self._class_name = class_name 1397 1398 def equals(self, rhs): 1399 """Check to see if the RHS is an instance of class_name. 1400 1401 Args: 1402 # rhs: the right hand side of the test 1403 rhs: object 1404 1405 Returns: 1406 bool 1407 """ 1408 1409 try: 1410 return isinstance(rhs, self._class_name) 1411 except TypeError: 1412 # Check raw types if there was a type error. This is helpful for 1413 # things like cStringIO.StringIO. 1414 return type(rhs) == type(self._class_name) 1415 1416 def _IsSubClass(self, clazz): 1417 """Check to see if the IsA comparators class is a subclass of clazz. 1418 1419 Args: 1420 # clazz: a class object 1421 1422 Returns: 1423 bool 1424 """ 1425 1426 try: 1427 return issubclass(self._class_name, clazz) 1428 except TypeError: 1429 # Check raw types if there was a type error. This is helpful for 1430 # things like cStringIO.StringIO. 1431 return type(clazz) == type(self._class_name) 1432 1433 def __repr__(self): 1434 return 'mox.IsA(%s) ' % str(self._class_name) 1435 1436 1437class IsAlmost(Comparator): 1438 """Comparison class used to check whether a parameter is nearly equal 1439 to a given value. Generally useful for floating point numbers. 1440 1441 Example mock_dao.SetTimeout((IsAlmost(3.9))) 1442 """ 1443 1444 def __init__(self, float_value, places=7): 1445 """Initialize IsAlmost. 1446 1447 Args: 1448 float_value: The value for making the comparison. 1449 places: The number of decimal places to round to. 1450 """ 1451 1452 self._float_value = float_value 1453 self._places = places 1454 1455 def equals(self, rhs): 1456 """Check to see if RHS is almost equal to float_value 1457 1458 Args: 1459 rhs: the value to compare to float_value 1460 1461 Returns: 1462 bool 1463 """ 1464 1465 try: 1466 return round(rhs - self._float_value, self._places) == 0 1467 except Exception: 1468 # Probably because either float_value or rhs is not a number. 1469 return False 1470 1471 def __repr__(self): 1472 return str(self._float_value) 1473 1474 1475class StrContains(Comparator): 1476 """Comparison class used to check whether a substring exists in a 1477 string parameter. This can be useful in mocking a database with SQL 1478 passed in as a string parameter, for example. 1479 1480 Example: 1481 mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) 1482 """ 1483 1484 def __init__(self, search_string): 1485 """Initialize. 1486 1487 Args: 1488 # search_string: the string you are searching for 1489 search_string: str 1490 """ 1491 1492 self._search_string = search_string 1493 1494 def equals(self, rhs): 1495 """Check to see if the search_string is contained in the rhs string. 1496 1497 Args: 1498 # rhs: the right hand side of the test 1499 rhs: object 1500 1501 Returns: 1502 bool 1503 """ 1504 1505 try: 1506 return rhs.find(self._search_string) > -1 1507 except Exception: 1508 return False 1509 1510 def __repr__(self): 1511 return '<str containing \'%s\'>' % self._search_string 1512 1513 1514class Regex(Comparator): 1515 """Checks if a string matches a regular expression. 1516 1517 This uses a given regular expression to determine equality. 1518 """ 1519 1520 def __init__(self, pattern, flags=0): 1521 """Initialize. 1522 1523 Args: 1524 # pattern is the regular expression to search for 1525 pattern: str 1526 # flags passed to re.compile function as the second argument 1527 flags: int 1528 """ 1529 self.flags = flags 1530 self.regex = re.compile(pattern, flags=flags) 1531 1532 def equals(self, rhs): 1533 """Check to see if rhs matches regular expression pattern. 1534 1535 Returns: 1536 bool 1537 """ 1538 1539 try: 1540 return self.regex.search(rhs) is not None 1541 except Exception: 1542 return False 1543 1544 def __repr__(self): 1545 s = '<regular expression \'%s\'' % self.regex.pattern 1546 if self.flags: 1547 s += ', flags=%d' % self.flags 1548 s += '>' 1549 return s 1550 1551 1552class In(Comparator): 1553 """Checks whether an item (or key) is in a list (or dict) parameter. 1554 1555 Example: 1556 mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) 1557 """ 1558 1559 def __init__(self, key): 1560 """Initialize. 1561 1562 Args: 1563 # key is any thing that could be in a list or a key in a dict 1564 """ 1565 1566 self._key = key 1567 1568 def equals(self, rhs): 1569 """Check to see whether key is in rhs. 1570 1571 Args: 1572 rhs: dict 1573 1574 Returns: 1575 bool 1576 """ 1577 1578 try: 1579 return self._key in rhs 1580 except Exception: 1581 return False 1582 1583 def __repr__(self): 1584 return '<sequence or map containing \'%s\'>' % str(self._key) 1585 1586 1587class Not(Comparator): 1588 """Checks whether a predicates is False. 1589 1590 Example: 1591 mock_dao.UpdateUsers(Not(ContainsKeyValue('stevepm', 1592 stevepm_user_info))) 1593 """ 1594 1595 def __init__(self, predicate): 1596 """Initialize. 1597 1598 Args: 1599 # predicate: a Comparator instance. 1600 """ 1601 1602 assert isinstance(predicate, Comparator), ("predicate %r must be a" 1603 " Comparator." % predicate) 1604 self._predicate = predicate 1605 1606 def equals(self, rhs): 1607 """Check to see whether the predicate is False. 1608 1609 Args: 1610 rhs: A value that will be given in argument of the predicate. 1611 1612 Returns: 1613 bool 1614 """ 1615 1616 try: 1617 return not self._predicate.equals(rhs) 1618 except Exception: 1619 return False 1620 1621 def __repr__(self): 1622 return '<not \'%s\'>' % self._predicate 1623 1624 1625class ContainsKeyValue(Comparator): 1626 """Checks whether a key/value pair is in a dict parameter. 1627 1628 Example: 1629 mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) 1630 """ 1631 1632 def __init__(self, key, value): 1633 """Initialize. 1634 1635 Args: 1636 # key: a key in a dict 1637 # value: the corresponding value 1638 """ 1639 1640 self._key = key 1641 self._value = value 1642 1643 def equals(self, rhs): 1644 """Check whether the given key/value pair is in the rhs dict. 1645 1646 Returns: 1647 bool 1648 """ 1649 1650 try: 1651 return rhs[self._key] == self._value 1652 except Exception: 1653 return False 1654 1655 def __repr__(self): 1656 return '<map containing the entry \'%s: %s\'>' % (str(self._key), 1657 str(self._value)) 1658 1659 1660class ContainsAttributeValue(Comparator): 1661 """Checks whether passed parameter contains attributes with a given value. 1662 1663 Example: 1664 mock_dao.UpdateSomething(ContainsAttribute('stevepm', stevepm_user_info)) 1665 """ 1666 1667 def __init__(self, key, value): 1668 """Initialize. 1669 1670 Args: 1671 # key: an attribute name of an object 1672 # value: the corresponding value 1673 """ 1674 1675 self._key = key 1676 self._value = value 1677 1678 def equals(self, rhs): 1679 """Check if the given attribute has a matching value in the rhs object. 1680 1681 Returns: 1682 bool 1683 """ 1684 1685 try: 1686 return getattr(rhs, self._key) == self._value 1687 except Exception: 1688 return False 1689 1690 1691class SameElementsAs(Comparator): 1692 """Checks whether sequences contain the same elements (ignoring order). 1693 1694 Example: 1695 mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) 1696 """ 1697 1698 def __init__(self, expected_seq): 1699 """Initialize. 1700 1701 Args: 1702 expected_seq: a sequence 1703 """ 1704 # Store in case expected_seq is an iterator. 1705 self._expected_list = list(expected_seq) 1706 1707 def equals(self, actual_seq): 1708 """Check to see whether actual_seq has same elements as expected_seq. 1709 1710 Args: 1711 actual_seq: sequence 1712 1713 Returns: 1714 bool 1715 """ 1716 try: 1717 # Store in case actual_seq is an iterator. We potentially iterate 1718 # twice: once to make the dict, once in the list fallback. 1719 actual_list = list(actual_seq) 1720 except TypeError: 1721 # actual_seq cannot be read as a sequence. 1722 # 1723 # This happens because Mox uses __eq__ both to check object 1724 # equality (in MethodSignatureChecker) and to invoke Comparators. 1725 return False 1726 1727 try: 1728 return set(self._expected_list) == set(actual_list) 1729 except TypeError: 1730 # Fall back to slower list-compare if any of the objects 1731 # are unhashable. 1732 if len(self._expected_list) != len(actual_list): 1733 return False 1734 for el in actual_list: 1735 if el not in self._expected_list: 1736 return False 1737 return True 1738 1739 def __repr__(self): 1740 return '<sequence with same elements as \'%s\'>' % self._expected_list 1741 1742 1743class And(Comparator): 1744 """Evaluates one or more Comparators on RHS, returns an AND of the results. 1745 """ 1746 1747 def __init__(self, *args): 1748 """Initialize. 1749 1750 Args: 1751 *args: One or more Comparator 1752 """ 1753 1754 self._comparators = args 1755 1756 def equals(self, rhs): 1757 """Checks whether all Comparators are equal to rhs. 1758 1759 Args: 1760 # rhs: can be anything 1761 1762 Returns: 1763 bool 1764 """ 1765 1766 for comparator in self._comparators: 1767 if not comparator.equals(rhs): 1768 return False 1769 1770 return True 1771 1772 def __repr__(self): 1773 return '<AND %s>' % str(self._comparators) 1774 1775 1776class Or(Comparator): 1777 """Evaluates one or more Comparators on RHS; returns OR of the results.""" 1778 1779 def __init__(self, *args): 1780 """Initialize. 1781 1782 Args: 1783 *args: One or more Mox comparators 1784 """ 1785 1786 self._comparators = args 1787 1788 def equals(self, rhs): 1789 """Checks whether any Comparator is equal to rhs. 1790 1791 Args: 1792 # rhs: can be anything 1793 1794 Returns: 1795 bool 1796 """ 1797 1798 for comparator in self._comparators: 1799 if comparator.equals(rhs): 1800 return True 1801 1802 return False 1803 1804 def __repr__(self): 1805 return '<OR %s>' % str(self._comparators) 1806 1807 1808class Func(Comparator): 1809 """Call a function that should verify the parameter passed in is correct. 1810 1811 You may need the ability to perform more advanced operations on the 1812 parameter in order to validate it. You can use this to have a callable 1813 validate any parameter. The callable should return either True or False. 1814 1815 1816 Example: 1817 1818 def myParamValidator(param): 1819 # Advanced logic here 1820 return True 1821 1822 mock_dao.DoSomething(Func(myParamValidator), true) 1823 """ 1824 1825 def __init__(self, func): 1826 """Initialize. 1827 1828 Args: 1829 func: callable that takes one parameter and returns a bool 1830 """ 1831 1832 self._func = func 1833 1834 def equals(self, rhs): 1835 """Test whether rhs passes the function test. 1836 1837 rhs is passed into func. 1838 1839 Args: 1840 rhs: any python object 1841 1842 Returns: 1843 the result of func(rhs) 1844 """ 1845 1846 return self._func(rhs) 1847 1848 def __repr__(self): 1849 return str(self._func) 1850 1851 1852class IgnoreArg(Comparator): 1853 """Ignore an argument. 1854 1855 This can be used when we don't care about an argument of a method call. 1856 1857 Example: 1858 # Check if CastMagic is called with 3 as first arg and 1859 # 'disappear' as third. 1860 mymock.CastMagic(3, IgnoreArg(), 'disappear') 1861 """ 1862 1863 def equals(self, unused_rhs): 1864 """Ignores arguments and returns True. 1865 1866 Args: 1867 unused_rhs: any python object 1868 1869 Returns: 1870 always returns True 1871 """ 1872 1873 return True 1874 1875 def __repr__(self): 1876 return '<IgnoreArg>' 1877 1878 1879class Value(Comparator): 1880 """Compares argument against a remembered value. 1881 1882 To be used in conjunction with Remember comparator. See Remember() 1883 for example. 1884 """ 1885 1886 def __init__(self): 1887 self._value = None 1888 self._has_value = False 1889 1890 def store_value(self, rhs): 1891 self._value = rhs 1892 self._has_value = True 1893 1894 def equals(self, rhs): 1895 if not self._has_value: 1896 return False 1897 else: 1898 return rhs == self._value 1899 1900 def __repr__(self): 1901 if self._has_value: 1902 return "<Value %r>" % self._value 1903 else: 1904 return "<Value>" 1905 1906 1907class Remember(Comparator): 1908 """Remembers the argument to a value store. 1909 1910 To be used in conjunction with Value comparator. 1911 1912 Example: 1913 # Remember the argument for one method call. 1914 users_list = Value() 1915 mock_dao.ProcessUsers(Remember(users_list)) 1916 1917 # Check argument against remembered value. 1918 mock_dao.ReportUsers(users_list) 1919 """ 1920 1921 def __init__(self, value_store): 1922 if not isinstance(value_store, Value): 1923 raise TypeError( 1924 "value_store is not an instance of the Value class") 1925 self._value_store = value_store 1926 1927 def equals(self, rhs): 1928 self._value_store.store_value(rhs) 1929 return True 1930 1931 def __repr__(self): 1932 return "<Remember %d>" % id(self._value_store) 1933 1934 1935class MethodGroup(object): 1936 """Base class containing common behaviour for MethodGroups.""" 1937 1938 def __init__(self, group_name): 1939 self._group_name = group_name 1940 1941 def group_name(self): 1942 return self._group_name 1943 1944 def __str__(self): 1945 return '<%s "%s">' % (self.__class__.__name__, self._group_name) 1946 1947 def AddMethod(self, mock_method): 1948 raise NotImplementedError 1949 1950 def MethodCalled(self, mock_method): 1951 raise NotImplementedError 1952 1953 def IsSatisfied(self): 1954 raise NotImplementedError 1955 1956 1957class UnorderedGroup(MethodGroup): 1958 """UnorderedGroup holds a set of method calls that may occur in any order. 1959 1960 This construct is helpful for non-deterministic events, such as iterating 1961 over the keys of a dict. 1962 """ 1963 1964 def __init__(self, group_name): 1965 super(UnorderedGroup, self).__init__(group_name) 1966 self._methods = [] 1967 1968 def __str__(self): 1969 return '%s "%s" pending calls:\n%s' % ( 1970 self.__class__.__name__, 1971 self._group_name, 1972 "\n".join(str(method) for method in self._methods)) 1973 1974 def AddMethod(self, mock_method): 1975 """Add a method to this group. 1976 1977 Args: 1978 mock_method: A mock method to be added to this group. 1979 """ 1980 1981 self._methods.append(mock_method) 1982 1983 def MethodCalled(self, mock_method): 1984 """Remove a method call from the group. 1985 1986 If the method is not in the set, an UnexpectedMethodCallError will be 1987 raised. 1988 1989 Args: 1990 mock_method: a mock method that should be equal to a method in the 1991 group. 1992 1993 Returns: 1994 The mock method from the group 1995 1996 Raises: 1997 UnexpectedMethodCallError if the mock_method was not in the group. 1998 """ 1999 2000 # Check to see if this method exists, and if so, remove it from the set 2001 # and return it. 2002 for method in self._methods: 2003 if method == mock_method: 2004 # Remove the called mock_method instead of the method in the 2005 # group. The called method will match any comparators when 2006 # equality is checked during removal. The method in the group 2007 # could pass a comparator to another comparator during the 2008 # equality check. 2009 self._methods.remove(mock_method) 2010 2011 # If group is not empty, put it back at the head of the queue. 2012 if not self.IsSatisfied(): 2013 mock_method._call_queue.appendleft(self) 2014 2015 return self, method 2016 2017 raise UnexpectedMethodCallError(mock_method, self) 2018 2019 def IsSatisfied(self): 2020 """Return True if there are not any methods in this group.""" 2021 2022 return len(self._methods) == 0 2023 2024 2025class MultipleTimesGroup(MethodGroup): 2026 """MultipleTimesGroup holds methods that may be called any number of times. 2027 2028 Note: Each method must be called at least once. 2029 2030 This is helpful, if you don't know or care how many times a method is 2031 called. 2032 """ 2033 2034 def __init__(self, group_name): 2035 super(MultipleTimesGroup, self).__init__(group_name) 2036 self._methods = set() 2037 self._methods_left = set() 2038 2039 def AddMethod(self, mock_method): 2040 """Add a method to this group. 2041 2042 Args: 2043 mock_method: A mock method to be added to this group. 2044 """ 2045 2046 self._methods.add(mock_method) 2047 self._methods_left.add(mock_method) 2048 2049 def MethodCalled(self, mock_method): 2050 """Remove a method call from the group. 2051 2052 If the method is not in the set, an UnexpectedMethodCallError will be 2053 raised. 2054 2055 Args: 2056 mock_method: a mock method that should be equal to a method in the 2057 group. 2058 2059 Returns: 2060 The mock method from the group 2061 2062 Raises: 2063 UnexpectedMethodCallError if the mock_method was not in the group. 2064 """ 2065 2066 # Check to see if this method exists, and if so add it to the set of 2067 # called methods. 2068 for method in self._methods: 2069 if method == mock_method: 2070 self._methods_left.discard(method) 2071 # Always put this group back on top of the queue, 2072 # because we don't know when we are done. 2073 mock_method._call_queue.appendleft(self) 2074 return self, method 2075 2076 if self.IsSatisfied(): 2077 next_method = mock_method._PopNextMethod() 2078 return next_method, None 2079 else: 2080 raise UnexpectedMethodCallError(mock_method, self) 2081 2082 def IsSatisfied(self): 2083 """Return True if all methods in group are called at least once.""" 2084 return len(self._methods_left) == 0 2085 2086 2087class MoxMetaTestBase(type): 2088 """Metaclass to add mox cleanup and verification to every test. 2089 2090 As the mox unit testing class is being constructed (MoxTestBase or a 2091 subclass), this metaclass will modify all test functions to call the 2092 CleanUpMox method of the test class after they finish. This means that 2093 unstubbing and verifying will happen for every test with no additional 2094 code, and any failures will result in test failures as opposed to errors. 2095 """ 2096 2097 def __init__(cls, name, bases, d): 2098 type.__init__(cls, name, bases, d) 2099 2100 # also get all the attributes from the base classes to account 2101 # for a case when test class is not the immediate child of MoxTestBase 2102 for base in bases: 2103 for attr_name in dir(base): 2104 if attr_name not in d: 2105 d[attr_name] = getattr(base, attr_name) 2106 2107 for func_name, func in d.items(): 2108 if func_name.startswith('test') and callable(func): 2109 2110 setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) 2111 2112 @staticmethod 2113 def CleanUpTest(cls, func): 2114 """Adds Mox cleanup code to any MoxTestBase method. 2115 2116 Always unsets stubs after a test. Will verify all mocks for tests that 2117 otherwise pass. 2118 2119 Args: 2120 cls: MoxTestBase or subclass; the class whose method we are 2121 altering. 2122 func: method; the method of the MoxTestBase test class we wish to 2123 alter. 2124 2125 Returns: 2126 The modified method. 2127 """ 2128 def new_method(self, *args, **kwargs): 2129 mox_obj = getattr(self, 'mox', None) 2130 stubout_obj = getattr(self, 'stubs', None) 2131 cleanup_mox = False 2132 cleanup_stubout = False 2133 if mox_obj and isinstance(mox_obj, Mox): 2134 cleanup_mox = True 2135 if stubout_obj and isinstance(stubout_obj, 2136 stubout.StubOutForTesting): 2137 cleanup_stubout = True 2138 try: 2139 func(self, *args, **kwargs) 2140 finally: 2141 if cleanup_mox: 2142 mox_obj.UnsetStubs() 2143 if cleanup_stubout: 2144 stubout_obj.UnsetAll() 2145 stubout_obj.SmartUnsetAll() 2146 if cleanup_mox: 2147 mox_obj.VerifyAll() 2148 new_method.__name__ = func.__name__ 2149 new_method.__doc__ = func.__doc__ 2150 new_method.__module__ = func.__module__ 2151 return new_method 2152 2153 2154_MoxTestBase = MoxMetaTestBase('_MoxTestBase', (unittest.TestCase, ), {}) 2155 2156 2157class MoxTestBase(_MoxTestBase): 2158 """Convenience test class to make stubbing easier. 2159 2160 Sets up a "mox" attribute which is an instance of Mox (any mox tests will 2161 want this), and a "stubs" attribute that is an instance of 2162 StubOutForTesting (needed at times). Also automatically unsets any stubs 2163 and verifies that all mock methods have been called at the end of each 2164 test, eliminating boilerplate code. 2165 """ 2166 2167 def setUp(self): 2168 super(MoxTestBase, self).setUp() 2169 self.mox = Mox() 2170 self.stubs = stubout.StubOutForTesting() 2171