1# Copyright (c) 2005 Divmod, Inc. 2# Copyright (c) Twisted Matrix Laboratories. 3# See LICENSE for details. 4 5""" 6Tests for L{twisted.protocols.amp}. 7""" 8 9import datetime 10import decimal 11 12from zope.interface import implements 13from zope.interface.verify import verifyClass, verifyObject 14 15from twisted.python import filepath 16from twisted.python.failure import Failure 17from twisted.protocols import amp 18from twisted.trial import unittest 19from twisted.internet import protocol, defer, error, reactor, interfaces 20from twisted.test import iosim 21from twisted.test.proto_helpers import StringTransport 22 23ssl = None 24try: 25 from twisted.internet import ssl 26except ImportError: 27 pass 28 29if ssl and not ssl.supported: 30 ssl = None 31 32if ssl is None: 33 skipSSL = "SSL not available" 34else: 35 skipSSL = None 36 37 38class TestProto(protocol.Protocol): 39 """ 40 A trivial protocol for use in testing where a L{Protocol} is expected. 41 42 @ivar instanceId: the id of this instance 43 @ivar onConnLost: deferred that will fired when the connection is lost 44 @ivar dataToSend: data to send on the protocol 45 """ 46 47 instanceCount = 0 48 49 def __init__(self, onConnLost, dataToSend): 50 self.onConnLost = onConnLost 51 self.dataToSend = dataToSend 52 self.instanceId = TestProto.instanceCount 53 TestProto.instanceCount = TestProto.instanceCount + 1 54 55 56 def connectionMade(self): 57 self.data = [] 58 self.transport.write(self.dataToSend) 59 60 61 def dataReceived(self, bytes): 62 self.data.append(bytes) 63 64 65 def connectionLost(self, reason): 66 self.onConnLost.callback(self.data) 67 68 69 def __repr__(self): 70 """ 71 Custom repr for testing to avoid coupling amp tests with repr from 72 L{Protocol} 73 74 Returns a string which contains a unique identifier that can be looked 75 up using the instanceId property:: 76 77 <TestProto #3> 78 """ 79 return "<TestProto #%d>" % (self.instanceId,) 80 81 82 83class SimpleSymmetricProtocol(amp.AMP): 84 85 
def sendHello(self, text): 86 return self.callRemoteString( 87 "hello", 88 hello=text) 89 90 def amp_HELLO(self, box): 91 return amp.Box(hello=box['hello']) 92 93 def amp_HOWDOYOUDO(self, box): 94 return amp.QuitBox(howdoyoudo='world') 95 96 97 98class UnfriendlyGreeting(Exception): 99 """Greeting was insufficiently kind. 100 """ 101 102class DeathThreat(Exception): 103 """Greeting was insufficiently kind. 104 """ 105 106class UnknownProtocol(Exception): 107 """Asked to switch to the wrong protocol. 108 """ 109 110 111class TransportPeer(amp.Argument): 112 # this serves as some informal documentation for how to get variables from 113 # the protocol or your environment and pass them to methods as arguments. 114 def retrieve(self, d, name, proto): 115 return '' 116 117 def fromStringProto(self, notAString, proto): 118 return proto.transport.getPeer() 119 120 def toBox(self, name, strings, objects, proto): 121 return 122 123 124 125class Hello(amp.Command): 126 127 commandName = 'hello' 128 129 arguments = [('hello', amp.String()), 130 ('optional', amp.Boolean(optional=True)), 131 ('print', amp.Unicode(optional=True)), 132 ('from', TransportPeer(optional=True)), 133 ('mixedCase', amp.String(optional=True)), 134 ('dash-arg', amp.String(optional=True)), 135 ('underscore_arg', amp.String(optional=True))] 136 137 response = [('hello', amp.String()), 138 ('print', amp.Unicode(optional=True))] 139 140 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} 141 142 fatalErrors = {DeathThreat: 'DEAD'} 143 144class NoAnswerHello(Hello): 145 commandName = Hello.commandName 146 requiresAnswer = False 147 148class FutureHello(amp.Command): 149 commandName = 'hello' 150 151 arguments = [('hello', amp.String()), 152 ('optional', amp.Boolean(optional=True)), 153 ('print', amp.Unicode(optional=True)), 154 ('from', TransportPeer(optional=True)), 155 ('bonus', amp.String(optional=True)), # addt'l arguments 156 # should generally be 157 # added at the end, and 158 # be optional... 
159 ] 160 161 response = [('hello', amp.String()), 162 ('print', amp.Unicode(optional=True))] 163 164 errors = {UnfriendlyGreeting: 'UNFRIENDLY'} 165 166class WTF(amp.Command): 167 """ 168 An example of an invalid command. 169 """ 170 171 172class BrokenReturn(amp.Command): 173 """ An example of a perfectly good command, but the handler is going to return 174 None... 175 """ 176 177 commandName = 'broken_return' 178 179class Goodbye(amp.Command): 180 # commandName left blank on purpose: this tests implicit command names. 181 response = [('goodbye', amp.String())] 182 responseType = amp.QuitBox 183 184class Howdoyoudo(amp.Command): 185 commandName = 'howdoyoudo' 186 # responseType = amp.QuitBox 187 188class WaitForever(amp.Command): 189 commandName = 'wait_forever' 190 191class GetList(amp.Command): 192 commandName = 'getlist' 193 arguments = [('length', amp.Integer())] 194 response = [('body', amp.AmpList([('x', amp.Integer())]))] 195 196class DontRejectMe(amp.Command): 197 commandName = 'dontrejectme' 198 arguments = [ 199 ('magicWord', amp.Unicode()), 200 ('list', amp.AmpList([('name', amp.Unicode())], optional=True)), 201 ] 202 response = [('response', amp.Unicode())] 203 204class SecuredPing(amp.Command): 205 # XXX TODO: actually make this refuse to send over an insecure connection 206 response = [('pinged', amp.Boolean())] 207 208class TestSwitchProto(amp.ProtocolSwitchCommand): 209 commandName = 'Switch-Proto' 210 211 arguments = [ 212 ('name', amp.String()), 213 ] 214 errors = {UnknownProtocol: 'UNKNOWN'} 215 216class SingleUseFactory(protocol.ClientFactory): 217 def __init__(self, proto): 218 self.proto = proto 219 self.proto.factory = self 220 221 def buildProtocol(self, addr): 222 p, self.proto = self.proto, None 223 return p 224 225 reasonFailed = None 226 227 def clientConnectionFailed(self, connector, reason): 228 self.reasonFailed = reason 229 return 230 231THING_I_DONT_UNDERSTAND = 'gwebol nargo' 232class ThingIDontUnderstandError(Exception): 233 
pass 234 235class FactoryNotifier(amp.AMP): 236 factory = None 237 def connectionMade(self): 238 if self.factory is not None: 239 self.factory.theProto = self 240 if hasattr(self.factory, 'onMade'): 241 self.factory.onMade.callback(None) 242 243 def emitpong(self): 244 from twisted.internet.interfaces import ISSLTransport 245 if not ISSLTransport.providedBy(self.transport): 246 raise DeathThreat("only send secure pings over secure channels") 247 return {'pinged': True} 248 SecuredPing.responder(emitpong) 249 250 251class SimpleSymmetricCommandProtocol(FactoryNotifier): 252 maybeLater = None 253 def __init__(self, onConnLost=None): 254 amp.AMP.__init__(self) 255 self.onConnLost = onConnLost 256 257 def sendHello(self, text): 258 return self.callRemote(Hello, hello=text) 259 260 def sendUnicodeHello(self, text, translation): 261 return self.callRemote(Hello, hello=text, Print=translation) 262 263 greeted = False 264 265 def cmdHello(self, hello, From, optional=None, Print=None, 266 mixedCase=None, dash_arg=None, underscore_arg=None): 267 assert From == self.transport.getPeer() 268 if hello == THING_I_DONT_UNDERSTAND: 269 raise ThingIDontUnderstandError() 270 if hello.startswith('fuck'): 271 raise UnfriendlyGreeting("Don't be a dick.") 272 if hello == 'die': 273 raise DeathThreat("aieeeeeeeee") 274 result = dict(hello=hello) 275 if Print is not None: 276 result.update(dict(Print=Print)) 277 self.greeted = True 278 return result 279 Hello.responder(cmdHello) 280 281 def cmdGetlist(self, length): 282 return {'body': [dict(x=1)] * length} 283 GetList.responder(cmdGetlist) 284 285 def okiwont(self, magicWord, list=None): 286 if list is None: 287 response = u'list omitted' 288 else: 289 response = u'%s accepted' % (list[0]['name']) 290 return dict(response=response) 291 DontRejectMe.responder(okiwont) 292 293 def waitforit(self): 294 self.waiting = defer.Deferred() 295 return self.waiting 296 WaitForever.responder(waitforit) 297 298 def howdo(self): 299 return 
dict(howdoyoudo='world') 300 Howdoyoudo.responder(howdo) 301 302 def saybye(self): 303 return dict(goodbye="everyone") 304 Goodbye.responder(saybye) 305 306 def switchToTestProtocol(self, fail=False): 307 if fail: 308 name = 'no-proto' 309 else: 310 name = 'test-proto' 311 p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA) 312 return self.callRemote( 313 TestSwitchProto, 314 SingleUseFactory(p), name=name).addCallback(lambda ign: p) 315 316 def switchit(self, name): 317 if name == 'test-proto': 318 return TestProto(self.onConnLost, SWITCH_SERVER_DATA) 319 raise UnknownProtocol(name) 320 TestSwitchProto.responder(switchit) 321 322 def donothing(self): 323 return None 324 BrokenReturn.responder(donothing) 325 326 327class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol): 328 def switchit(self, name): 329 if name == 'test-proto': 330 self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA) 331 self.maybeLater = defer.Deferred() 332 return self.maybeLater 333 raise UnknownProtocol(name) 334 TestSwitchProto.responder(switchit) 335 336class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): 337 def badResponder(self, hello, From, optional=None, Print=None, 338 mixedCase=None, dash_arg=None, underscore_arg=None): 339 """ 340 This responder does nothing and forgets to return a dictionary. 
341 """ 342 NoAnswerHello.responder(badResponder) 343 344class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): 345 def goodNoAnswerResponder(self, hello, From, optional=None, Print=None, 346 mixedCase=None, dash_arg=None, underscore_arg=None): 347 return dict(hello=hello+"-noanswer") 348 NoAnswerHello.responder(goodNoAnswerResponder) 349 350def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol, 351 ClientClass=SimpleSymmetricProtocol, 352 *a, **kw): 353 """Returns a 3-tuple: (client, server, pump) 354 """ 355 return iosim.connectedServerAndClient( 356 ServerClass, ClientClass, 357 *a, **kw) 358 359class TotallyDumbProtocol(protocol.Protocol): 360 buf = '' 361 def dataReceived(self, data): 362 self.buf += data 363 364class LiteralAmp(amp.AMP): 365 def __init__(self): 366 self.boxes = [] 367 368 def ampBoxReceived(self, box): 369 self.boxes.append(box) 370 return 371 372 373 374class AmpBoxTests(unittest.TestCase): 375 """ 376 Test a few essential properties of AMP boxes, mostly with respect to 377 serialization correctness. 378 """ 379 380 def test_serializeStr(self): 381 """ 382 Make sure that strs serialize to strs. 383 """ 384 a = amp.AmpBox(key='value') 385 self.assertEqual(type(a.serialize()), str) 386 387 def test_serializeUnicodeKeyRaises(self): 388 """ 389 Verify that TypeError is raised when trying to serialize Unicode keys. 390 """ 391 a = amp.AmpBox(**{u'key': 'value'}) 392 self.assertRaises(TypeError, a.serialize) 393 394 def test_serializeUnicodeValueRaises(self): 395 """ 396 Verify that TypeError is raised when trying to serialize Unicode 397 values. 398 """ 399 a = amp.AmpBox(key=u'value') 400 self.assertRaises(TypeError, a.serialize) 401 402 403 404class ParsingTest(unittest.TestCase): 405 406 def test_booleanValues(self): 407 """ 408 Verify that the Boolean parser parses 'True' and 'False', but nothing 409 else. 
410 """ 411 b = amp.Boolean() 412 self.assertEqual(b.fromString("True"), True) 413 self.assertEqual(b.fromString("False"), False) 414 self.assertRaises(TypeError, b.fromString, "ninja") 415 self.assertRaises(TypeError, b.fromString, "true") 416 self.assertRaises(TypeError, b.fromString, "TRUE") 417 self.assertEqual(b.toString(True), 'True') 418 self.assertEqual(b.toString(False), 'False') 419 420 def test_pathValueRoundTrip(self): 421 """ 422 Verify the 'Path' argument can parse and emit a file path. 423 """ 424 fp = filepath.FilePath(self.mktemp()) 425 p = amp.Path() 426 s = p.toString(fp) 427 v = p.fromString(s) 428 self.assertNotIdentical(fp, v) # sanity check 429 self.assertEqual(fp, v) 430 431 432 def test_sillyEmptyThing(self): 433 """ 434 Test that empty boxes raise an error; they aren't supposed to be sent 435 on purpose. 436 """ 437 a = amp.AMP() 438 return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box()) 439 440 441 def test_ParsingRoundTrip(self): 442 """ 443 Verify that various kinds of data make it through the encode/parse 444 round-trip unharmed. 445 """ 446 c, s, p = connectedServerAndClient(ClientClass=LiteralAmp, 447 ServerClass=LiteralAmp) 448 449 SIMPLE = ('simple', 'test') 450 CE = ('ceq', ': ') 451 CR = ('crtest', 'test\r') 452 LF = ('lftest', 'hello\n') 453 NEWLINE = ('newline', 'test\r\none\r\ntwo') 454 NEWLINE2 = ('newline2', 'test\r\none\r\n two') 455 BODYTEST = ('body', 'blah\r\n\r\ntesttest') 456 457 testData = [ 458 [SIMPLE], 459 [SIMPLE, BODYTEST], 460 [SIMPLE, CE], 461 [SIMPLE, CR], 462 [SIMPLE, CE, CR, LF], 463 [CE, CR, LF], 464 [SIMPLE, NEWLINE, CE, NEWLINE2], 465 [BODYTEST, SIMPLE, NEWLINE] 466 ] 467 468 for test in testData: 469 jb = amp.Box() 470 jb.update(dict(test)) 471 jb._sendTo(c) 472 p.flush() 473 self.assertEqual(s.boxes[-1], jb) 474 475 476 477class FakeLocator(object): 478 """ 479 This is a fake implementation of the interface implied by 480 L{CommandLocator}. 
481 """ 482 def __init__(self): 483 """ 484 Remember the given keyword arguments as a set of responders. 485 """ 486 self.commands = {} 487 488 489 def locateResponder(self, commandName): 490 """ 491 Look up and return a function passed as a keyword argument of the given 492 name to the constructor. 493 """ 494 return self.commands[commandName] 495 496 497class FakeSender: 498 """ 499 This is a fake implementation of the 'box sender' interface implied by 500 L{AMP}. 501 """ 502 def __init__(self): 503 """ 504 Create a fake sender and initialize the list of received boxes and 505 unhandled errors. 506 """ 507 self.sentBoxes = [] 508 self.unhandledErrors = [] 509 self.expectedErrors = 0 510 511 512 def expectError(self): 513 """ 514 Expect one error, so that the test doesn't fail. 515 """ 516 self.expectedErrors += 1 517 518 519 def sendBox(self, box): 520 """ 521 Accept a box, but don't do anything. 522 """ 523 self.sentBoxes.append(box) 524 525 526 def unhandledError(self, failure): 527 """ 528 Deal with failures by instantly re-raising them for easier debugging. 529 """ 530 self.expectedErrors -= 1 531 if self.expectedErrors < 0: 532 failure.raiseException() 533 else: 534 self.unhandledErrors.append(failure) 535 536 537 538class CommandDispatchTests(unittest.TestCase): 539 """ 540 The AMP CommandDispatcher class dispatches converts AMP boxes into commands 541 and responses using Command.responder decorator. 542 543 Note: Originally, AMP's factoring was such that many tests for this 544 functionality are now implemented as full round-trip tests in L{AMPTest}. 545 Future tests should be written at this level instead, to ensure API 546 compatibility and to provide more granular, readable units of test 547 coverage. 548 """ 549 550 def setUp(self): 551 """ 552 Create a dispatcher to use. 
553 """ 554 self.locator = FakeLocator() 555 self.sender = FakeSender() 556 self.dispatcher = amp.BoxDispatcher(self.locator) 557 self.dispatcher.startReceivingBoxes(self.sender) 558 559 560 def test_receivedAsk(self): 561 """ 562 L{CommandDispatcher.ampBoxReceived} should locate the appropriate 563 command in its responder lookup, based on the '_ask' key. 564 """ 565 received = [] 566 def thunk(box): 567 received.append(box) 568 return amp.Box({"hello": "goodbye"}) 569 input = amp.Box(_command="hello", 570 _ask="test-command-id", 571 hello="world") 572 self.locator.commands['hello'] = thunk 573 self.dispatcher.ampBoxReceived(input) 574 self.assertEqual(received, [input]) 575 576 577 def test_sendUnhandledError(self): 578 """ 579 L{CommandDispatcher} should relay its unhandled errors in responding to 580 boxes to its boxSender. 581 """ 582 err = RuntimeError("something went wrong, oh no") 583 self.sender.expectError() 584 self.dispatcher.unhandledError(Failure(err)) 585 self.assertEqual(len(self.sender.unhandledErrors), 1) 586 self.assertEqual(self.sender.unhandledErrors[0].value, err) 587 588 589 def test_unhandledSerializationError(self): 590 """ 591 Errors during serialization ought to be relayed to the sender's 592 unhandledError method. 593 """ 594 err = RuntimeError("something undefined went wrong") 595 def thunk(result): 596 class BrokenBox(amp.Box): 597 def _sendTo(self, proto): 598 raise err 599 return BrokenBox() 600 self.locator.commands['hello'] = thunk 601 input = amp.Box(_command="hello", 602 _ask="test-command-id", 603 hello="world") 604 self.sender.expectError() 605 self.dispatcher.ampBoxReceived(input) 606 self.assertEqual(len(self.sender.unhandledErrors), 1) 607 self.assertEqual(self.sender.unhandledErrors[0].value, err) 608 609 610 def test_callRemote(self): 611 """ 612 L{CommandDispatcher.callRemote} should emit a properly formatted '_ask' 613 box to its boxSender and record an outstanding L{Deferred}. 
When a 614 corresponding '_answer' packet is received, the L{Deferred} should be 615 fired, and the results translated via the given L{Command}'s response 616 de-serialization. 617 """ 618 D = self.dispatcher.callRemote(Hello, hello='world') 619 self.assertEqual(self.sender.sentBoxes, 620 [amp.AmpBox(_command="hello", 621 _ask="1", 622 hello="world")]) 623 answers = [] 624 D.addCallback(answers.append) 625 self.assertEqual(answers, []) 626 self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay", 627 'print': "ignored", 628 '_answer': "1"})) 629 self.assertEqual(answers, [dict(hello="yay", 630 Print=u"ignored")]) 631 632 633 def _localCallbackErrorLoggingTest(self, callResult): 634 """ 635 Verify that C{callResult} completes with a C{None} result and that an 636 unhandled error has been logged. 637 """ 638 finalResult = [] 639 callResult.addBoth(finalResult.append) 640 641 self.assertEqual(1, len(self.sender.unhandledErrors)) 642 self.assertIsInstance( 643 self.sender.unhandledErrors[0].value, ZeroDivisionError) 644 645 self.assertEqual([None], finalResult) 646 647 648 def test_callRemoteSuccessLocalCallbackErrorLogging(self): 649 """ 650 If the last callback on the L{Deferred} returned by C{callRemote} (added 651 by application code calling C{callRemote}) fails, the failure is passed 652 to the sender's C{unhandledError} method. 653 """ 654 self.sender.expectError() 655 656 callResult = self.dispatcher.callRemote(Hello, hello='world') 657 callResult.addCallback(lambda result: 1 // 0) 658 659 self.dispatcher.ampBoxReceived(amp.AmpBox({ 660 'hello': "yay", 'print': "ignored", '_answer': "1"})) 661 662 self._localCallbackErrorLoggingTest(callResult) 663 664 665 def test_callRemoteErrorLocalCallbackErrorLogging(self): 666 """ 667 Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the 668 case where the L{Deferred} returned by C{callRemote} fails. 
669 """ 670 self.sender.expectError() 671 672 callResult = self.dispatcher.callRemote(Hello, hello='world') 673 callResult.addErrback(lambda result: 1 // 0) 674 675 self.dispatcher.ampBoxReceived(amp.AmpBox({ 676 '_error': '1', '_error_code': 'bugs', 677 '_error_description': 'stuff'})) 678 679 self._localCallbackErrorLoggingTest(callResult) 680 681 682 683class SimpleGreeting(amp.Command): 684 """ 685 A very simple greeting command that uses a few basic argument types. 686 """ 687 commandName = 'simple' 688 arguments = [('greeting', amp.Unicode()), 689 ('cookie', amp.Integer())] 690 response = [('cookieplus', amp.Integer())] 691 692 693 694class TestLocator(amp.CommandLocator): 695 """ 696 A locator which implements a responder to the 'simple' command. 697 """ 698 def __init__(self): 699 self.greetings = [] 700 701 702 def greetingResponder(self, greeting, cookie): 703 self.greetings.append((greeting, cookie)) 704 return dict(cookieplus=cookie + 3) 705 greetingResponder = SimpleGreeting.responder(greetingResponder) 706 707 708 709class OverridingLocator(TestLocator): 710 """ 711 A locator which overrides the responder to the 'simple' command. 712 """ 713 714 def greetingResponder(self, greeting, cookie): 715 """ 716 Return a different cookieplus than L{TestLocator.greetingResponder}. 717 """ 718 self.greetings.append((greeting, cookie)) 719 return dict(cookieplus=cookie + 4) 720 greetingResponder = SimpleGreeting.responder(greetingResponder) 721 722 723 724class InheritingLocator(OverridingLocator): 725 """ 726 This locator should inherit the responder from L{OverridingLocator}. 727 """ 728 729 730 731class OverrideLocatorAMP(amp.AMP): 732 def __init__(self): 733 amp.AMP.__init__(self) 734 self.customResponder = object() 735 self.expectations = {"custom": self.customResponder} 736 self.greetings = [] 737 738 739 def lookupFunction(self, name): 740 """ 741 Override the deprecated lookupFunction function. 
742 """ 743 if name in self.expectations: 744 result = self.expectations[name] 745 return result 746 else: 747 return super(OverrideLocatorAMP, self).lookupFunction(name) 748 749 750 def greetingResponder(self, greeting, cookie): 751 self.greetings.append((greeting, cookie)) 752 return dict(cookieplus=cookie + 3) 753 greetingResponder = SimpleGreeting.responder(greetingResponder) 754 755 756 757 758class CommandLocatorTests(unittest.TestCase): 759 """ 760 The CommandLocator should enable users to specify responders to commands as 761 functions that take structured objects, annotated with metadata. 762 """ 763 764 def _checkSimpleGreeting(self, locatorClass, expected): 765 """ 766 Check that a locator of type C{locatorClass} finds a responder 767 for command named I{simple} and that the found responder answers 768 with the C{expected} result to a C{SimpleGreeting<"ni hao", 5>} 769 command. 770 """ 771 locator = locatorClass() 772 responderCallable = locator.locateResponder("simple") 773 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) 774 def done(values): 775 self.assertEqual(values, amp.AmpBox(cookieplus=str(expected))) 776 return result.addCallback(done) 777 778 779 def test_responderDecorator(self): 780 """ 781 A method on a L{CommandLocator} subclass decorated with a L{Command} 782 subclass's L{responder} decorator should be returned from 783 locateResponder, wrapped in logic to serialize and deserialize its 784 arguments. 785 """ 786 return self._checkSimpleGreeting(TestLocator, 8) 787 788 789 def test_responderOverriding(self): 790 """ 791 L{CommandLocator} subclasses can override a responder inherited from 792 a base class by using the L{Command.responder} decorator to register 793 a new responder method. 
794 """ 795 return self._checkSimpleGreeting(OverridingLocator, 9) 796 797 798 def test_responderInheritance(self): 799 """ 800 Responder lookup follows the same rules as normal method lookup 801 rules, particularly with respect to inheritance. 802 """ 803 return self._checkSimpleGreeting(InheritingLocator, 9) 804 805 806 def test_lookupFunctionDeprecatedOverride(self): 807 """ 808 Subclasses which override locateResponder under its old name, 809 lookupFunction, should have the override invoked instead. (This tests 810 an AMP subclass, because in the version of the code that could invoke 811 this deprecated code path, there was no L{CommandLocator}.) 812 """ 813 locator = OverrideLocatorAMP() 814 customResponderObject = self.assertWarns( 815 PendingDeprecationWarning, 816 "Override locateResponder, not lookupFunction.", 817 __file__, lambda : locator.locateResponder("custom")) 818 self.assertEqual(locator.customResponder, customResponderObject) 819 # Make sure upcalling works too 820 normalResponderObject = self.assertWarns( 821 PendingDeprecationWarning, 822 "Override locateResponder, not lookupFunction.", 823 __file__, lambda : locator.locateResponder("simple")) 824 result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5")) 825 def done(values): 826 self.assertEqual(values, amp.AmpBox(cookieplus='8')) 827 return result.addCallback(done) 828 829 830 def test_lookupFunctionDeprecatedInvoke(self): 831 """ 832 Invoking locateResponder under its old name, lookupFunction, should 833 emit a deprecation warning, but do the same thing. 
834 """ 835 locator = TestLocator() 836 responderCallable = self.assertWarns( 837 PendingDeprecationWarning, 838 "Call locateResponder, not lookupFunction.", __file__, 839 lambda : locator.lookupFunction("simple")) 840 result = responderCallable(amp.Box(greeting="ni hao", cookie="5")) 841 def done(values): 842 self.assertEqual(values, amp.AmpBox(cookieplus='8')) 843 return result.addCallback(done) 844 845 846 847SWITCH_CLIENT_DATA = 'Success!' 848SWITCH_SERVER_DATA = 'No, really. Success.' 849 850 851class BinaryProtocolTests(unittest.TestCase): 852 """ 853 Tests for L{amp.BinaryBoxProtocol}. 854 855 @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender} 856 which was passed to it. 857 """ 858 859 def setUp(self): 860 """ 861 Keep track of all boxes received by this test in its capacity as an 862 L{IBoxReceiver} implementor. 863 """ 864 self.boxes = [] 865 self.data = [] 866 867 868 def startReceivingBoxes(self, sender): 869 """ 870 Implement L{IBoxReceiver.startReceivingBoxes} to just remember the 871 value passed in. 872 """ 873 self._boxSender = sender 874 875 876 def ampBoxReceived(self, box): 877 """ 878 A box was received by the protocol. 879 """ 880 self.boxes.append(box) 881 882 stopReason = None 883 def stopReceivingBoxes(self, reason): 884 """ 885 Record the reason that we stopped receiving boxes. 886 """ 887 self.stopReason = reason 888 889 890 # fake ITransport 891 def getPeer(self): 892 return 'no peer' 893 894 895 def getHost(self): 896 return 'no host' 897 898 899 def write(self, data): 900 self.data.append(data) 901 902 903 def test_startReceivingBoxes(self): 904 """ 905 When L{amp.BinaryBoxProtocol} is connected to a transport, it calls 906 C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the 907 L{IBoxSender} parameter. 
908 """ 909 protocol = amp.BinaryBoxProtocol(self) 910 protocol.makeConnection(None) 911 self.assertIdentical(self._boxSender, protocol) 912 913 914 def test_sendBoxInStartReceivingBoxes(self): 915 """ 916 The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is 917 connected to a transport can call C{sendBox} on the L{IBoxSender} 918 passed to it before C{startReceivingBoxes} returns and have that box 919 sent. 920 """ 921 class SynchronouslySendingReceiver: 922 def startReceivingBoxes(self, sender): 923 sender.sendBox(amp.Box({'foo': 'bar'})) 924 925 transport = StringTransport() 926 protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver()) 927 protocol.makeConnection(transport) 928 self.assertEqual( 929 transport.value(), 930 '\x00\x03foo\x00\x03bar\x00\x00') 931 932 933 def test_receiveBoxStateMachine(self): 934 """ 935 When a binary box protocol receives: 936 * a key 937 * a value 938 * an empty string 939 it should emit a box and send it to its boxReceiver. 940 """ 941 a = amp.BinaryBoxProtocol(self) 942 a.stringReceived("hello") 943 a.stringReceived("world") 944 a.stringReceived("") 945 self.assertEqual(self.boxes, [amp.AmpBox(hello="world")]) 946 947 948 def test_firstBoxFirstKeyExcessiveLength(self): 949 """ 950 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 951 the first a key it receives is larger than 255. 952 """ 953 transport = StringTransport() 954 protocol = amp.BinaryBoxProtocol(self) 955 protocol.makeConnection(transport) 956 protocol.dataReceived('\x01\x00') 957 self.assertTrue(transport.disconnecting) 958 959 960 def test_firstBoxSubsequentKeyExcessiveLength(self): 961 """ 962 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 963 a subsequent key in the first box it receives is larger than 255. 
964 """ 965 transport = StringTransport() 966 protocol = amp.BinaryBoxProtocol(self) 967 protocol.makeConnection(transport) 968 protocol.dataReceived('\x00\x01k\x00\x01v') 969 self.assertFalse(transport.disconnecting) 970 protocol.dataReceived('\x01\x00') 971 self.assertTrue(transport.disconnecting) 972 973 974 def test_subsequentBoxFirstKeyExcessiveLength(self): 975 """ 976 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 977 the first key in a subsequent box it receives is larger than 255. 978 """ 979 transport = StringTransport() 980 protocol = amp.BinaryBoxProtocol(self) 981 protocol.makeConnection(transport) 982 protocol.dataReceived('\x00\x01k\x00\x01v\x00\x00') 983 self.assertFalse(transport.disconnecting) 984 protocol.dataReceived('\x01\x00') 985 self.assertTrue(transport.disconnecting) 986 987 988 def test_excessiveKeyFailure(self): 989 """ 990 If L{amp.BinaryBoxProtocol} disconnects because it received a key 991 length prefix which was too large, the L{IBoxReceiver}'s 992 C{stopReceivingBoxes} method is called with a L{TooLong} failure. 993 """ 994 protocol = amp.BinaryBoxProtocol(self) 995 protocol.makeConnection(StringTransport()) 996 protocol.dataReceived('\x01\x00') 997 protocol.connectionLost( 998 Failure(error.ConnectionDone("simulated connection done"))) 999 self.stopReason.trap(amp.TooLong) 1000 self.assertTrue(self.stopReason.value.isKey) 1001 self.assertFalse(self.stopReason.value.isLocal) 1002 self.assertIdentical(self.stopReason.value.value, None) 1003 self.assertIdentical(self.stopReason.value.keyName, None) 1004 1005 1006 def test_unhandledErrorWithTransport(self): 1007 """ 1008 L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it 1009 and disconnects its transport. 
1010 """ 1011 transport = StringTransport() 1012 protocol = amp.BinaryBoxProtocol(self) 1013 protocol.makeConnection(transport) 1014 protocol.unhandledError(Failure(RuntimeError("Fake error"))) 1015 self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError))) 1016 self.assertTrue(transport.disconnecting) 1017 1018 1019 def test_unhandledErrorWithoutTransport(self): 1020 """ 1021 L{amp.BinaryBoxProtocol.unhandledError} completes without error when 1022 there is no associated transport. 1023 """ 1024 protocol = amp.BinaryBoxProtocol(self) 1025 protocol.makeConnection(StringTransport()) 1026 protocol.connectionLost(Failure(Exception("Simulated"))) 1027 protocol.unhandledError(Failure(RuntimeError("Fake error"))) 1028 self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError))) 1029 1030 1031 def test_receiveBoxData(self): 1032 """ 1033 When a binary box protocol receives the serialized form of an AMP box, 1034 it should emit a similar box to its boxReceiver. 1035 """ 1036 a = amp.BinaryBoxProtocol(self) 1037 a.dataReceived(amp.Box({"testKey": "valueTest", 1038 "anotherKey": "anotherValue"}).serialize()) 1039 self.assertEqual(self.boxes, 1040 [amp.Box({"testKey": "valueTest", 1041 "anotherKey": "anotherValue"})]) 1042 1043 1044 def test_receiveLongerBoxData(self): 1045 """ 1046 An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with 1047 values of up to (2 ** 16 - 1) bytes. 1048 """ 1049 length = (2 ** 16 - 1) 1050 value = 'x' * length 1051 transport = StringTransport() 1052 protocol = amp.BinaryBoxProtocol(self) 1053 protocol.makeConnection(transport) 1054 protocol.dataReceived(amp.Box({'k': value}).serialize()) 1055 self.assertEqual(self.boxes, [amp.Box({'k': value})]) 1056 self.assertFalse(transport.disconnecting) 1057 1058 1059 def test_sendBox(self): 1060 """ 1061 When a binary box protocol sends a box, it should emit the serialized 1062 bytes of that box to its transport. 
1063 """ 1064 a = amp.BinaryBoxProtocol(self) 1065 a.makeConnection(self) 1066 aBox = amp.Box({"testKey": "valueTest", 1067 "someData": "hello"}) 1068 a.makeConnection(self) 1069 a.sendBox(aBox) 1070 self.assertEqual(''.join(self.data), aBox.serialize()) 1071 1072 1073 def test_connectionLostStopSendingBoxes(self): 1074 """ 1075 When a binary box protocol loses its connection, it should notify its 1076 box receiver that it has stopped receiving boxes. 1077 """ 1078 a = amp.BinaryBoxProtocol(self) 1079 a.makeConnection(self) 1080 connectionFailure = Failure(RuntimeError()) 1081 a.connectionLost(connectionFailure) 1082 self.assertIdentical(self.stopReason, connectionFailure) 1083 1084 1085 def test_protocolSwitch(self): 1086 """ 1087 L{BinaryBoxProtocol} has the capacity to switch to a different protocol 1088 on a box boundary. When a protocol is in the process of switching, it 1089 cannot receive traffic. 1090 """ 1091 otherProto = TestProto(None, "outgoing data") 1092 test = self 1093 class SwitchyReceiver: 1094 switched = False 1095 def startReceivingBoxes(self, sender): 1096 pass 1097 def ampBoxReceived(self, box): 1098 test.assertFalse(self.switched, 1099 "Should only receive one box!") 1100 self.switched = True 1101 a._lockForSwitch() 1102 a._switchTo(otherProto) 1103 a = amp.BinaryBoxProtocol(SwitchyReceiver()) 1104 anyOldBox = amp.Box({"include": "lots", 1105 "of": "data"}) 1106 a.makeConnection(self) 1107 # Include a 0-length box at the beginning of the next protocol's data, 1108 # to make sure that AMP doesn't eat the data or try to deliver extra 1109 # boxes either... 1110 moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!" 
def test_protocolSwitchInvalidStates(self):
    """
    A protocol has to be locked for switching before it is actually
    switched, so that no stray data can land in the middle of a box.
    While locked, sending a box raises L{amp.ProtocolSwitched}; the lock
    may be released as long as the switch has not happened, but not
    afterwards.
    """
    proto = amp.BinaryBoxProtocol(self)
    proto.makeConnection(self)
    box = amp.Box({"some": "data"})

    # Locked: sending is refused.
    proto._lockForSwitch()
    self.assertRaises(amp.ProtocolSwitched, proto.sendBox, box)

    # Unlocked again: sending works and the bytes reach the transport.
    proto._unlockFromSwitch()
    proto.sendBox(box)
    self.assertEqual(''.join(self.data), box.serialize())

    # Once the switch has really happened, unlocking is an error.
    proto._lockForSwitch()
    proto._switchTo(TestProto(None, "outgoing data"))
    self.assertRaises(amp.ProtocolSwitched, proto._unlockFromSwitch)
def test_interfaceDeclarations(self):
    """
    The classes in the amp module ought to implement the interfaces that
    are declared for their benefit.
    """
    # (interface, class expected to declare that it implements it)
    declarations = [
        (amp.IBoxSender, amp.BinaryBoxProtocol),
        (amp.IBoxReceiver, amp.BoxDispatcher),
        (amp.IResponderLocator, amp.CommandLocator),
        (amp.IResponderLocator, amp.SimpleStringLocator),
        (amp.IBoxSender, amp.AMP),
        (amp.IBoxReceiver, amp.AMP),
        (amp.IResponderLocator, amp.AMP),
    ]
    for interface, implementation in declarations:
        # assertTrue replaces the deprecated failUnless alias.
        self.assertTrue(
            interface.implementedBy(implementation),
            "%s does not implements(%s)" % (implementation, interface))
def test_helloWorldUnicode(self):
    """
    Verify that unicode arguments can be encoded and decoded.
    """
    c, s, p = connectedServerAndClient(
        ServerClass=SimpleSymmetricCommandProtocol,
        ClientClass=SimpleSymmetricCommandProtocol)
    L = []
    HELLO = 'world'
    # This has to be a unicode literal.  Without the ``u`` prefix Python 2
    # treats ``\u1234`` as six plain ASCII characters, so the test would
    # never push a non-ASCII code point through the wire encoding.
    HELLO_UNICODE = u'wor\u1234ld'
    c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
    p.flush()
    self.assertEqual(L[0]['hello'], HELLO)
    self.assertEqual(L[0]['Print'], HELLO_UNICODE)
def test_unknownCommandHigh(self):
    """
    An unknown command sent through the high-level API is rejected with
    an error, but the connection is NOT terminated and stays usable.
    """
    client, server, pump = connectedServerAndClient()
    results = []

    def swallowUnhandled(failure):
        """
        Accept only L{amp.UnhandledCommand} and replace it with a marker
        result; any other failure keeps propagating.
        """
        failure.trap(amp.UnhandledCommand)
        return "OK"

    d = client.callRemote(WTF)
    d.addErrback(swallowUnhandled)
    d.addCallback(results.append)
    pump.flush()
    self.assertEqual(results.pop(), "OK")

    # A subsequent, known command must still get through.
    greeting = 'world'
    client.sendHello(greeting).addCallback(results.append)
    pump.flush()
    self.assertEqual(results[0]['hello'], greeting)
def test_simpleReprs(self):
    """
    Verify that the various Box objects repr properly, for debugging.
    """
    boxes = [amp._SwitchBox('a'), amp.QuitBox(), amp.AmpBox()]
    for box in boxes:
        # isinstance-based check instead of comparing type() objects.
        self.assertIsInstance(repr(box), str)
    # assertIn reports both operands on failure, unlike a bare boolean.
    self.assertIn("AmpBox", repr(amp.AmpBox()))
def test_valueTooLong(self):
    """
    Verify that attempting to send value longer than 64k will immediately
    raise an exception.

    The raised L{amp.TooLong} must describe a local, value-side (not key)
    overflow for the C{'hello'} argument, and its repr must mention the
    offending length, the word "value" and the key name.
    """
    c, s, p = connectedServerAndClient()
    # One byte more than the largest length a 16-bit prefix can express.
    x = "H" * (0xffff+1)
    tl = self.assertRaises(amp.TooLong, c.sendHello, x)
    p.flush()
    # Same assertion vocabulary as test_keyTooLong, instead of the
    # deprecated failIf/failUnless aliases.
    self.assertFalse(tl.isKey)
    self.assertTrue(tl.isLocal)
    self.assertEqual(tl.keyName, 'hello')
    self.assertIdentical(tl.value, x)
    self.assertIn(str(len(x)), repr(tl))
    self.assertIn("value", repr(tl))
    self.assertIn('hello', repr(tl))
def test_helloFatalErrorHandling(self):
    """
    A known, fatal error type raised by the responder is relayed to the
    other end of the connection and translated into an exception, no
    error is logged, and the connection is terminated.
    """
    client, server, pump = connectedServerAndClient(
        ServerClass=SimpleSymmetricCommandProtocol,
        ClientClass=SimpleSymmetricCommandProtocol)
    failures = []
    greeting = 'die'

    # First call: the fatal error itself comes back to the caller.
    client.sendHello(greeting).addErrback(failures.append)
    pump.flush()
    failures.pop().trap(DeathThreat)

    # Second call: the connection is already gone.
    client.sendHello(greeting).addErrback(failures.append)
    pump.flush()
    failures.pop().trap(error.ConnectionDone)
def test_requiresNoAnswer(self):
    """
    A command that requires no answer is still run by the server.
    """
    client, server, pump = connectedServerAndClient(
        ServerClass=SimpleSymmetricCommandProtocol,
        ClientClass=SimpleSymmetricCommandProtocol)
    client.callRemote(NoAnswerHello, hello='world')
    pump.flush()
    # The server side records that the greeting responder ran.
    self.failUnless(server.greeted)
def test_ampListCommand(self):
    """
    An argument that uses the AmpList encoding survives the round trip.
    """
    client, server, pump = connectedServerAndClient(
        ServerClass=SimpleSymmetricCommandProtocol,
        ClientClass=SimpleSymmetricCommandProtocol)
    responses = []
    d = client.callRemote(GetList, length=10)
    d.addCallback(responses.append)
    pump.flush()
    body = responses.pop().get('body')
    expected = [{'x': 1} for _ in range(10)]
    self.assertEqual(body, expected)
def test_optionalAmpListPresent(self):
    """
    Sanity check: an optional AmpList argument that IS supplied is
    processed normally.
    """
    client, server, pump = connectedServerAndClient(
        ServerClass=SimpleSymmetricCommandProtocol,
        ClientClass=SimpleSymmetricCommandProtocol)
    answers = []
    d = client.callRemote(DontRejectMe, magicWord=u'please',
                          list=[{'name': 'foo'}])
    d.addCallback(answers.append)
    pump.flush()
    answer = answers.pop()
    self.assertEqual(answer.get('response'), 'foo accepted')
def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
                        spuriousTraffic=False,
                        spuriousError=False):
    """
    Verify that it is possible to switch to another protocol mid-connection and
    send data to it successfully.

    @param switcher: callable invoked with a L{defer.Deferred} to build
        the server and the client protocol.
    @param spuriousTraffic: if true, issue a C{WaitForever} call before
        switching, check that sending after the switch raises
        L{amp.ProtocolSwitched}, and fire the server's pending answer once
        the switch is done.
    @param spuriousError: with C{spuriousTraffic}, fire that pending
        answer as a L{amp.RemoteAmpError} errback instead of a callback.
    """
    self.testSucceeded = False

    serverDeferred = defer.Deferred()
    serverProto = switcher(serverDeferred)
    clientDeferred = defer.Deferred()
    clientProto = switcher(clientDeferred)
    c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
                                       ClientClass=lambda: clientProto)

    if spuriousTraffic:
        wfdr = []           # remote
        c.callRemote(WaitForever).addErrback(wfdr.append)
    switchDeferred = c.switchToTestProtocol()
    if spuriousTraffic:
        self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')

    def cbConnsLost(results):
        # DeferredList fires with one (success, value) pair per Deferred,
        # in the order given.  Unpack in the body: tuple parameters in a
        # signature are Python-2-only syntax (removed by PEP 3113).
        (serverSuccess, serverData), (clientSuccess, clientData) = results
        self.failUnless(serverSuccess)
        self.failUnless(clientSuccess)
        self.assertEqual(''.join(serverData), SWITCH_CLIENT_DATA)
        self.assertEqual(''.join(clientData), SWITCH_SERVER_DATA)
        self.testSucceeded = True

    def cbSwitch(proto):
        return defer.DeferredList(
            [serverDeferred, clientDeferred]).addCallback(cbConnsLost)

    switchDeferred.addCallback(cbSwitch)
    p.flush()
    if serverProto.maybeLater is not None:
        serverProto.maybeLater.callback(serverProto.maybeLaterProto)
        p.flush()
    if spuriousTraffic:
        # switch is done here; do this here to make sure that if we're
        # going to corrupt the connection, we do it before it's closed.
        if spuriousError:
            s.waiting.errback(amp.RemoteAmpError(
                    "SPURIOUS",
                    "Here's some traffic in the form of an error."))
        else:
            s.waiting.callback({})
        p.flush()
    c.transport.loseConnection() # close it
    p.flush()
    self.failUnless(self.testSucceeded)
def test_basicLiteralEmit(self):
    """
    The command box produced by a C{callRemote} looks correct after it
    has been serialized and parsed again on the receiving side.
    """
    client, server, pump = connectedServerAndClient()
    received = []
    # Capture raw boxes on the server instead of dispatching them.
    server.ampBoxReceived = received.append
    client.callRemote(
        Hello, hello='hello test', mixedCase='mixed case arg test',
        dash_arg='x', underscore_arg='y')
    pump.flush()
    self.assertEqual(len(received), 1)

    box = received[-1]
    expectedItems = (
        ('_command', Hello.commandName),
        ('hello', 'hello test'),
        ('mixedCase', 'mixed case arg test'),
        ('dash-arg', 'x'),
        ('underscore_arg', 'y'),
    )
    for key, value in expectedItems:
        self.assertEqual(box.pop(key), value)
    # The ask tag must be present; its value is not checked.
    box.pop('_ask')
    # Nothing else may have been sent.
    self.assertEqual(box, {})
class IOSimCert:
    """
    A stand-in certificate object for iosim-based TLS tests.

    @ivar verifyCount: incremented by every L{iosimVerify} call that
        passes the identity check.
    """
    verifyCount = 0

    def options(self, *ignored):
        """
        Return this object itself, whatever arguments are given.
        """
        return self

    def iosimVerify(self, otherCert):
        """
        This is not a real certificate and would not work on a real
        socket; iosim specifies a different API so that no crypto math is
        needed to demonstrate that the right functions get called in the
        right places.

        @param otherCert: the peer's certificate; must be this very object.
        @return: C{True}
        """
        assert otherCert is self
        self.verifyCount = self.verifyCount + 1
        return True
def test_negotiationFailed(self):
    """
    Verify that starting TLS and failing on both sides at handshaking sends
    notifications to all the right places and terminates the connection.

    @return: the L{defer.Deferred} from C{assertFailure}, so that trial
        waits for it and reports a wrong outcome as a test failure.
    """

    badCert = GrumpyCert()

    cli, svr, p = connectedServerAndClient(
        ServerClass=SecurableProto,
        ClientClass=SecurableProto)
    svr.certFactory = lambda : badCert

    cli.callRemote(amp.StartTLS,
                   tls_localCertificate=badCert)

    p.flush()
    # once for client once for server - but both fail
    self.assertEqual(badCert.verifyCount, 2)
    d = cli.callRemote(SecuredPing)
    p.flush()
    # Return the Deferred: if it is dropped, a call that unexpectedly
    # succeeds would not reliably fail this test.
    return self.assertFailure(d, iosim.NativeOpenSSLError)
class TLSNotAvailableTest(unittest.TestCase):
    """
    Tests for what happens when ssl is not available in the current
    installation.
    """

    def setUp(self):
        """
        Remember the real C{amp.ssl} and replace it with C{None}.
        """
        self.ssl = amp.ssl
        amp.ssl = None


    def tearDown(self):
        """
        Put the remembered ssl module back into L{amp}.
        """
        amp.ssl = self.ssl


    def test_callRemoteError(self):
        """
        Calling C{callRemote} with L{amp.StartTLS} gives a Deferred that
        fails with C{RuntimeError}.
        """
        client, server, pump = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)

        cert = OKCert()
        server.certFactory = lambda : cert

        d = client.callRemote(
            amp.StartTLS, tls_localCertificate=cert,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
        return self.assertFailure(d, RuntimeError)


    def test_messageReceivedError(self):
        """
        When a client with SSL enabled talks to a server without SSL, the
        server answers with a meaningful error box.
        """
        server = SecurableProto()
        cert = OKCert()
        server.certFactory = lambda : cert

        # Collect outgoing boxes instead of serializing them.
        sent = []
        server.sendBox = sent.append
        server.makeConnection(StringTransport())

        request = amp.Box()
        request['_command'] = 'StartTLS'
        request['_ask'] = '1'
        server.ampBoxReceived(request)

        expected = {'_error_code': 'TLS_ERROR',
                    '_error': '1',
                    '_error_description': 'TLS not available'}
        self.assertEqual(sent, [expected])
class AddedCommandProtocol(amp.AMP):
    """
    A protocol that responds to L{AddErrorsCommand}.  It is used to test
    that an inherited command can add its own new error types while still
    responding in the same way to its parent's error types.
    """
    def resp(self, other):
        # A false flag selects the error declared on the base command;
        # a true flag selects the one the subclass adds.
        if not other:
            raise InheritedError()
        raise OtherInheritedError()
    AddErrorsCommand.responder(resp)
2149 """ 2150 c, s, p = connectedServerAndClient(ServerClass=proto, 2151 ClientClass=proto) 2152 d = c.callRemote(cmd, **kw) 2153 d2 = self.failUnlessFailure(d, err) 2154 p.flush() 2155 return d2 2156 2157 2158 def test_basicErrorPropagation(self): 2159 """ 2160 Verify that errors specified in a superclass are respected normally 2161 even if it has subclasses. 2162 """ 2163 return self.errorCheck( 2164 InheritedError, NormalCommandProtocol, BaseCommand) 2165 2166 2167 def test_inheritedErrorPropagation(self): 2168 """ 2169 Verify that errors specified in a superclass command are propagated to 2170 its subclasses. 2171 """ 2172 return self.errorCheck( 2173 InheritedError, InheritedCommandProtocol, InheritedCommand) 2174 2175 2176 def test_inheritedErrorAddition(self): 2177 """ 2178 Verify that new errors specified in a subclass of an existing command 2179 are honored even if the superclass defines some errors. 2180 """ 2181 return self.errorCheck( 2182 OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True) 2183 2184 2185 def test_additionWithOriginalError(self): 2186 """ 2187 Verify that errors specified in a command's superclass are respected 2188 even if that command defines new errors itself. 2189 """ 2190 return self.errorCheck( 2191 InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False) 2192 2193 2194def _loseAndPass(err, proto): 2195 # be specific, pass on the error to the client. 2196 err.trap(error.ConnectionLost, error.ConnectionDone) 2197 del proto.connectionLost 2198 proto.connectionLost(err) 2199 2200 2201class LiveFireBase: 2202 """ 2203 Utility for connected reactor-using tests. 2204 """ 2205 2206 def setUp(self): 2207 """ 2208 Create an amp server and connect a client to it. 
2209 """ 2210 from twisted.internet import reactor 2211 self.serverFactory = protocol.ServerFactory() 2212 self.serverFactory.protocol = self.serverProto 2213 self.clientFactory = protocol.ClientFactory() 2214 self.clientFactory.protocol = self.clientProto 2215 self.clientFactory.onMade = defer.Deferred() 2216 self.serverFactory.onMade = defer.Deferred() 2217 self.serverPort = reactor.listenTCP(0, self.serverFactory) 2218 self.addCleanup(self.serverPort.stopListening) 2219 self.clientConn = reactor.connectTCP( 2220 '127.0.0.1', self.serverPort.getHost().port, 2221 self.clientFactory) 2222 self.addCleanup(self.clientConn.disconnect) 2223 def getProtos(rlst): 2224 self.cli = self.clientFactory.theProto 2225 self.svr = self.serverFactory.theProto 2226 dl = defer.DeferredList([self.clientFactory.onMade, 2227 self.serverFactory.onMade]) 2228 return dl.addCallback(getProtos) 2229 2230 def tearDown(self): 2231 """ 2232 Cleanup client and server connections, and check the error got at 2233 C{connectionLost}. 
2234 """ 2235 L = [] 2236 for conn in self.cli, self.svr: 2237 if conn.transport is not None: 2238 # depend on amp's function connection-dropping behavior 2239 d = defer.Deferred().addErrback(_loseAndPass, conn) 2240 conn.connectionLost = d.errback 2241 conn.transport.loseConnection() 2242 L.append(d) 2243 return defer.gatherResults(L 2244 ).addErrback(lambda first: first.value.subFailure) 2245 2246 2247def show(x): 2248 import sys 2249 sys.stdout.write(x+'\n') 2250 sys.stdout.flush() 2251 2252 2253def tempSelfSigned(): 2254 from twisted.internet import ssl 2255 2256 sharedDN = ssl.DN(CN='shared') 2257 key = ssl.KeyPair.generate() 2258 cr = key.certificateRequest(sharedDN) 2259 sscrd = key.signCertificateRequest( 2260 sharedDN, cr, lambda dn: True, 1234567) 2261 cert = key.newCertificate(sscrd) 2262 return cert 2263 2264if ssl is not None: 2265 tempcert = tempSelfSigned() 2266 2267 2268class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase): 2269 clientProto = SecurableProto 2270 serverProto = SecurableProto 2271 def test_liveFireCustomTLS(self): 2272 """ 2273 Using real, live TLS, actually negotiate a connection. 2274 2275 This also looks at the 'peerCertificate' attribute's correctness, since 2276 that's actually loaded using OpenSSL calls, but the main purpose is to 2277 make sure that we didn't miss anything obvious in iosim about TLS 2278 negotiations. 2279 """ 2280 2281 cert = tempcert 2282 2283 self.svr.verifyFactory = lambda : [cert] 2284 self.svr.certFactory = lambda : cert 2285 # only needed on the server, we specify the client below. 2286 2287 def secured(rslt): 2288 x = cert.digest() 2289 def pinged(rslt2): 2290 # Interesting. OpenSSL won't even _tell_ us about the peer 2291 # cert until we negotiate. we should be able to do this in 2292 # 'secured' instead, but it looks like we can't. I think this 2293 # is a bug somewhere far deeper than here. 
2294 self.assertEqual(x, self.cli.hostCertificate.digest()) 2295 self.assertEqual(x, self.cli.peerCertificate.digest()) 2296 self.assertEqual(x, self.svr.hostCertificate.digest()) 2297 self.assertEqual(x, self.svr.peerCertificate.digest()) 2298 return self.cli.callRemote(SecuredPing).addCallback(pinged) 2299 return self.cli.callRemote(amp.StartTLS, 2300 tls_localCertificate=cert, 2301 tls_verifyAuthorities=[cert]).addCallback(secured) 2302 2303 skip = skipSSL 2304 2305 2306 2307class SlightlySmartTLS(SimpleSymmetricCommandProtocol): 2308 """ 2309 Specific implementation of server side protocol with different 2310 management of TLS. 2311 """ 2312 def getTLSVars(self): 2313 """ 2314 @return: the global C{tempcert} certificate as local certificate. 2315 """ 2316 return dict(tls_localCertificate=tempcert) 2317 amp.StartTLS.responder(getTLSVars) 2318 2319 2320class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase): 2321 2322 clientProto = SimpleSymmetricCommandProtocol 2323 serverProto = SimpleSymmetricCommandProtocol 2324 2325 def test_liveFireDefaultTLS(self): 2326 """ 2327 Verify that out of the box, we can start TLS to at least encrypt the 2328 connection, even if we don't have any certificates to use. 2329 """ 2330 def secured(result): 2331 return self.cli.callRemote(SecuredPing) 2332 return self.cli.callRemote(amp.StartTLS).addCallback(secured) 2333 2334 skip = skipSSL 2335 2336 2337 2338class WithServerTLSVerification(LiveFireBase, unittest.TestCase): 2339 clientProto = SimpleSymmetricCommandProtocol 2340 serverProto = SlightlySmartTLS 2341 2342 def test_anonymousVerifyingClient(self): 2343 """ 2344 Verify that anonymous clients can verify server certificates. 
2345 """ 2346 def secured(result): 2347 return self.cli.callRemote(SecuredPing) 2348 return self.cli.callRemote(amp.StartTLS, 2349 tls_verifyAuthorities=[tempcert] 2350 ).addCallback(secured) 2351 2352 skip = skipSSL 2353 2354 2355 2356class ProtocolIncludingArgument(amp.Argument): 2357 """ 2358 An L{amp.Argument} which encodes its parser and serializer 2359 arguments *including the protocol* into its parsed and serialized 2360 forms. 2361 """ 2362 2363 def fromStringProto(self, string, protocol): 2364 """ 2365 Don't decode anything; just return all possible information. 2366 2367 @return: A two-tuple of the input string and the protocol. 2368 """ 2369 return (string, protocol) 2370 2371 def toStringProto(self, obj, protocol): 2372 """ 2373 Encode identifying information about L{object} and protocol 2374 into a string for later verification. 2375 2376 @type obj: L{object} 2377 @type protocol: L{amp.AMP} 2378 """ 2379 return "%s:%s" % (id(obj), id(protocol)) 2380 2381 2382 2383class ProtocolIncludingCommand(amp.Command): 2384 """ 2385 A command that has argument and response schemas which use 2386 L{ProtocolIncludingArgument}. 2387 """ 2388 arguments = [('weird', ProtocolIncludingArgument())] 2389 response = [('weird', ProtocolIncludingArgument())] 2390 2391 2392 2393class MagicSchemaCommand(amp.Command): 2394 """ 2395 A command which overrides L{parseResponse}, L{parseArguments}, and 2396 L{makeResponse}. 2397 """ 2398 def parseResponse(self, strings, protocol): 2399 """ 2400 Don't do any parsing, just jam the input strings and protocol 2401 onto the C{protocol.parseResponseArguments} attribute as a 2402 two-tuple. Return the original strings. 
2403 """ 2404 protocol.parseResponseArguments = (strings, protocol) 2405 return strings 2406 parseResponse = classmethod(parseResponse) 2407 2408 2409 def parseArguments(cls, strings, protocol): 2410 """ 2411 Don't do any parsing, just jam the input strings and protocol 2412 onto the C{protocol.parseArgumentsArguments} attribute as a 2413 two-tuple. Return the original strings. 2414 """ 2415 protocol.parseArgumentsArguments = (strings, protocol) 2416 return strings 2417 parseArguments = classmethod(parseArguments) 2418 2419 2420 def makeArguments(cls, objects, protocol): 2421 """ 2422 Don't do any serializing, just jam the input strings and protocol 2423 onto the C{protocol.makeArgumentsArguments} attribute as a 2424 two-tuple. Return the original strings. 2425 """ 2426 protocol.makeArgumentsArguments = (objects, protocol) 2427 return objects 2428 makeArguments = classmethod(makeArguments) 2429 2430 2431 2432class NoNetworkProtocol(amp.AMP): 2433 """ 2434 An L{amp.AMP} subclass which overrides private methods to avoid 2435 testing the network. It also provides a responder for 2436 L{MagicSchemaCommand} that does nothing, so that tests can test 2437 aspects of the interaction of L{amp.Command}s and L{amp.AMP}. 2438 2439 @ivar parseArgumentsArguments: Arguments that have been passed to any 2440 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2441 this protocol. 2442 2443 @ivar parseResponseArguments: Responses that have been returned from a 2444 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2445 this protocol. 2446 2447 @ivar makeArgumentsArguments: Arguments that have been serialized by any 2448 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2449 this protocol. 2450 """ 2451 def _sendBoxCommand(self, commandName, strings, requiresAnswer): 2452 """ 2453 Return a Deferred which fires with the original strings. 
2454 """ 2455 return defer.succeed(strings) 2456 2457 MagicSchemaCommand.responder(lambda s, weird: {}) 2458 2459 2460 2461class MyBox(dict): 2462 """ 2463 A unique dict subclass. 2464 """ 2465 2466 2467 2468class ProtocolIncludingCommandWithDifferentCommandType( 2469 ProtocolIncludingCommand): 2470 """ 2471 A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox} 2472 """ 2473 commandType = MyBox 2474 2475 2476 2477class CommandTestCase(unittest.TestCase): 2478 """ 2479 Tests for L{amp.Argument} and L{amp.Command}. 2480 """ 2481 def test_argumentInterface(self): 2482 """ 2483 L{Argument} instances provide L{amp.IArgumentType}. 2484 """ 2485 self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument())) 2486 2487 2488 def test_parseResponse(self): 2489 """ 2490 There should be a class method of Command which accepts a 2491 mapping of argument names to serialized forms and returns a 2492 similar mapping whose values have been parsed via the 2493 Command's response schema. 2494 """ 2495 protocol = object() 2496 result = 'whatever' 2497 strings = {'weird': result} 2498 self.assertEqual( 2499 ProtocolIncludingCommand.parseResponse(strings, protocol), 2500 {'weird': (result, protocol)}) 2501 2502 2503 def test_callRemoteCallsParseResponse(self): 2504 """ 2505 Making a remote call on a L{amp.Command} subclass which 2506 overrides the C{parseResponse} method should call that 2507 C{parseResponse} method to get the response. 
2508 """ 2509 client = NoNetworkProtocol() 2510 thingy = "weeoo" 2511 response = client.callRemote(MagicSchemaCommand, weird=thingy) 2512 def gotResponse(ign): 2513 self.assertEqual(client.parseResponseArguments, 2514 ({"weird": thingy}, client)) 2515 response.addCallback(gotResponse) 2516 return response 2517 2518 2519 def test_parseArguments(self): 2520 """ 2521 There should be a class method of L{amp.Command} which accepts 2522 a mapping of argument names to serialized forms and returns a 2523 similar mapping whose values have been parsed via the 2524 command's argument schema. 2525 """ 2526 protocol = object() 2527 result = 'whatever' 2528 strings = {'weird': result} 2529 self.assertEqual( 2530 ProtocolIncludingCommand.parseArguments(strings, protocol), 2531 {'weird': (result, protocol)}) 2532 2533 2534 def test_responderCallsParseArguments(self): 2535 """ 2536 Making a remote call on a L{amp.Command} subclass which 2537 overrides the C{parseArguments} method should call that 2538 C{parseArguments} method to get the arguments. 2539 """ 2540 protocol = NoNetworkProtocol() 2541 responder = protocol.locateResponder(MagicSchemaCommand.commandName) 2542 argument = object() 2543 response = responder(dict(weird=argument)) 2544 response.addCallback( 2545 lambda ign: self.assertEqual(protocol.parseArgumentsArguments, 2546 ({"weird": argument}, protocol))) 2547 return response 2548 2549 2550 def test_makeArguments(self): 2551 """ 2552 There should be a class method of L{amp.Command} which accepts 2553 a mapping of argument names to objects and returns a similar 2554 mapping whose values have been serialized via the command's 2555 argument schema. 
2556 """ 2557 protocol = object() 2558 argument = object() 2559 objects = {'weird': argument} 2560 self.assertEqual( 2561 ProtocolIncludingCommand.makeArguments(objects, protocol), 2562 {'weird': "%d:%d" % (id(argument), id(protocol))}) 2563 2564 2565 def test_makeArgumentsUsesCommandType(self): 2566 """ 2567 L{amp.Command.makeArguments}'s return type should be the type 2568 of the result of L{amp.Command.commandType}. 2569 """ 2570 protocol = object() 2571 objects = {"weird": "whatever"} 2572 2573 result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments( 2574 objects, protocol) 2575 self.assertIdentical(type(result), MyBox) 2576 2577 2578 def test_callRemoteCallsMakeArguments(self): 2579 """ 2580 Making a remote call on a L{amp.Command} subclass which 2581 overrides the C{makeArguments} method should call that 2582 C{makeArguments} method to get the response. 2583 """ 2584 client = NoNetworkProtocol() 2585 argument = object() 2586 response = client.callRemote(MagicSchemaCommand, weird=argument) 2587 def gotResponse(ign): 2588 self.assertEqual(client.makeArgumentsArguments, 2589 ({"weird": argument}, client)) 2590 response.addCallback(gotResponse) 2591 return response 2592 2593 2594 def test_extraArgumentsDisallowed(self): 2595 """ 2596 L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects 2597 dictionary passed to it includes a key which does not correspond to the 2598 Python identifier for a defined argument. 2599 """ 2600 self.assertRaises( 2601 amp.InvalidSignature, 2602 Hello.makeArguments, 2603 dict(hello="hello", bogusArgument=object()), None) 2604 2605 2606 def test_wireSpellingDisallowed(self): 2607 """ 2608 If a command argument conflicts with a Python keyword, the 2609 untransformed argument name is not allowed as a key in the dictionary 2610 passed to L{Command.makeArguments}. If it is supplied, 2611 L{amp.InvalidSignature} is raised. 2612 2613 This may be a pointless implementation restriction which may be lifted. 
2614 The current behavior is tested to verify that such arguments are not 2615 silently dropped on the floor (the previous behavior). 2616 """ 2617 self.assertRaises( 2618 amp.InvalidSignature, 2619 Hello.makeArguments, 2620 dict(hello="required", **{"print": "print value"}), 2621 None) 2622 2623 2624class ListOfTestsMixin: 2625 """ 2626 Base class for testing L{ListOf}, a parameterized zero-or-more argument 2627 type. 2628 2629 @ivar elementType: Subclasses should set this to an L{Argument} 2630 instance. The tests will make a L{ListOf} using this. 2631 2632 @ivar strings: Subclasses should set this to a dictionary mapping some 2633 number of keys to the correct serialized form for some example 2634 values. These should agree with what L{elementType} 2635 produces/accepts. 2636 2637 @ivar objects: Subclasses should set this to a dictionary with the same 2638 keys as C{strings} and with values which are the lists which should 2639 serialize to the values in the C{strings} dictionary. 2640 """ 2641 def test_toBox(self): 2642 """ 2643 L{ListOf.toBox} extracts the list of objects from the C{objects} 2644 dictionary passed to it, using the C{name} key also passed to it, 2645 serializes each of the elements in that list using the L{Argument} 2646 instance previously passed to its initializer, combines the serialized 2647 results, and inserts the result into the C{strings} dictionary using 2648 the same C{name} key. 2649 """ 2650 stringList = amp.ListOf(self.elementType) 2651 strings = amp.AmpBox() 2652 for key in self.objects: 2653 stringList.toBox(key, strings, self.objects.copy(), None) 2654 self.assertEqual(strings, self.strings) 2655 2656 2657 def test_fromBox(self): 2658 """ 2659 L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}. 
2660 """ 2661 stringList = amp.ListOf(self.elementType) 2662 objects = {} 2663 for key in self.strings: 2664 stringList.fromBox(key, self.strings.copy(), objects, None) 2665 self.assertEqual(objects, self.objects) 2666 2667 2668 2669class ListOfStringsTests(unittest.TestCase, ListOfTestsMixin): 2670 """ 2671 Tests for L{ListOf} combined with L{amp.String}. 2672 """ 2673 elementType = amp.String() 2674 2675 strings = { 2676 "empty": "", 2677 "single": "\x00\x03foo", 2678 "multiple": "\x00\x03bar\x00\x03baz\x00\x04quux"} 2679 2680 objects = { 2681 "empty": [], 2682 "single": ["foo"], 2683 "multiple": ["bar", "baz", "quux"]} 2684 2685 2686class ListOfIntegersTests(unittest.TestCase, ListOfTestsMixin): 2687 """ 2688 Tests for L{ListOf} combined with L{amp.Integer}. 2689 """ 2690 elementType = amp.Integer() 2691 2692 huge = ( 2693 9999999999999999999999999999999999999999999999999999999999 * 2694 9999999999999999999999999999999999999999999999999999999999) 2695 2696 strings = { 2697 "empty": "", 2698 "single": "\x00\x0210", 2699 "multiple": "\x00\x011\x00\x0220\x00\x03500", 2700 "huge": "\x00\x74%d" % (huge,), 2701 "negative": "\x00\x02-1"} 2702 2703 objects = { 2704 "empty": [], 2705 "single": [10], 2706 "multiple": [1, 20, 500], 2707 "huge": [huge], 2708 "negative": [-1]} 2709 2710 2711 2712class ListOfUnicodeTests(unittest.TestCase, ListOfTestsMixin): 2713 """ 2714 Tests for L{ListOf} combined with L{amp.Unicode}. 2715 """ 2716 elementType = amp.Unicode() 2717 2718 strings = { 2719 "empty": "", 2720 "single": "\x00\x03foo", 2721 "multiple": "\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world"} 2722 2723 objects = { 2724 "empty": [], 2725 "single": [u"foo"], 2726 "multiple": [u"\N{SNOWMAN}", u"Hello", u"world"]} 2727 2728 2729 2730class ListOfDecimalTests(unittest.TestCase, ListOfTestsMixin): 2731 """ 2732 Tests for L{ListOf} combined with L{amp.Decimal}. 
2733 """ 2734 elementType = amp.Decimal() 2735 2736 strings = { 2737 "empty": "", 2738 "single": "\x00\x031.1", 2739 "extreme": "\x00\x08Infinity\x00\x09-Infinity", 2740 "scientist": "\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7" 2741 "\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7", 2742 "engineer": "\x00\x04%s\x00\x06%s" % ( 2743 decimal.Decimal("0e6").to_eng_string(), 2744 decimal.Decimal("1.5E-9").to_eng_string()), 2745 } 2746 2747 objects = { 2748 "empty": [], 2749 "single": [decimal.Decimal("1.1")], 2750 "extreme": [ 2751 decimal.Decimal("Infinity"), 2752 decimal.Decimal("-Infinity"), 2753 ], 2754 # exarkun objected to AMP supporting engineering notation because 2755 # it was redundant, until we realised that 1E6 has less precision 2756 # than 1000000 and is represented differently. But they compare 2757 # and even hash equally. There were tears. 2758 "scientist": [ 2759 decimal.Decimal("3.141E5"), 2760 decimal.Decimal("3.141e-5"), 2761 decimal.Decimal("3.141E-7"), 2762 decimal.Decimal("-3.141e5"), 2763 decimal.Decimal("-3.141E-5"), 2764 decimal.Decimal("-3.141e-7"), 2765 ], 2766 "engineer": [ 2767 decimal.Decimal("0e6"), 2768 decimal.Decimal("1.5E-9"), 2769 ], 2770 } 2771 2772 2773 2774class ListOfDecimalNanTests(unittest.TestCase, ListOfTestsMixin): 2775 """ 2776 Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values. 2777 """ 2778 elementType = amp.Decimal() 2779 2780 strings = { 2781 "nan": "\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN", 2782 } 2783 2784 objects = { 2785 "nan": [ 2786 decimal.Decimal("NaN"), 2787 decimal.Decimal("-NaN"), 2788 decimal.Decimal("sNaN"), 2789 decimal.Decimal("-sNaN"), 2790 ] 2791 } 2792 2793 def test_fromBox(self): 2794 """ 2795 L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}. 2796 """ 2797 # Helpers. Decimal.is_{qnan,snan,signed}() are new in 2.6 (or 2.5.2, 2798 # but who's counting). 
2799 def is_qnan(decimal): 2800 return 'NaN' in str(decimal) and 'sNaN' not in str(decimal) 2801 2802 def is_snan(decimal): 2803 return 'sNaN' in str(decimal) 2804 2805 def is_signed(decimal): 2806 return '-' in str(decimal) 2807 2808 # NaN values have unusual equality semantics, so this method is 2809 # overridden to compare the resulting objects in a way which works with 2810 # NaNs. 2811 stringList = amp.ListOf(self.elementType) 2812 objects = {} 2813 for key in self.strings: 2814 stringList.fromBox(key, self.strings.copy(), objects, None) 2815 n = objects["nan"] 2816 self.assertTrue(is_qnan(n[0]) and not is_signed(n[0])) 2817 self.assertTrue(is_qnan(n[1]) and is_signed(n[1])) 2818 self.assertTrue(is_snan(n[2]) and not is_signed(n[2])) 2819 self.assertTrue(is_snan(n[3]) and is_signed(n[3])) 2820 2821 2822 2823class DecimalTests(unittest.TestCase): 2824 """ 2825 Tests for L{amp.Decimal}. 2826 """ 2827 def test_nonDecimal(self): 2828 """ 2829 L{amp.Decimal.toString} raises L{ValueError} if passed an object which 2830 is not an instance of C{decimal.Decimal}. 2831 """ 2832 argument = amp.Decimal() 2833 self.assertRaises(ValueError, argument.toString, "1.234") 2834 self.assertRaises(ValueError, argument.toString, 1.234) 2835 self.assertRaises(ValueError, argument.toString, 1234) 2836 2837 2838 2839class ListOfDateTimeTests(unittest.TestCase, ListOfTestsMixin): 2840 """ 2841 Tests for L{ListOf} combined with L{amp.DateTime}. 
2842 """ 2843 elementType = amp.DateTime() 2844 2845 strings = { 2846 "christmas": 2847 "\x00\x202010-12-25T00:00:00.000000-00:00" 2848 "\x00\x202010-12-25T00:00:00.000000-00:00", 2849 "christmas in eu": "\x00\x202010-12-25T00:00:00.000000+01:00", 2850 "christmas in iran": "\x00\x202010-12-25T00:00:00.000000+03:30", 2851 "christmas in nyc": "\x00\x202010-12-25T00:00:00.000000-05:00", 2852 "previous tests": "\x00\x202010-12-25T00:00:00.000000+03:19" 2853 "\x00\x202010-12-25T00:00:00.000000-06:59", 2854 } 2855 2856 objects = { 2857 "christmas": [ 2858 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=amp.utc), 2859 datetime.datetime(2010, 12, 25, 0, 0, 0, 2860 tzinfo=amp._FixedOffsetTZInfo('+', 0, 0)), 2861 ], 2862 "christmas in eu": [ 2863 datetime.datetime(2010, 12, 25, 0, 0, 0, 2864 tzinfo=amp._FixedOffsetTZInfo('+', 1, 0)), 2865 ], 2866 "christmas in iran": [ 2867 datetime.datetime(2010, 12, 25, 0, 0, 0, 2868 tzinfo=amp._FixedOffsetTZInfo('+', 3, 30)), 2869 ], 2870 "christmas in nyc": [ 2871 datetime.datetime(2010, 12, 25, 0, 0, 0, 2872 tzinfo=amp._FixedOffsetTZInfo('-', 5, 0)), 2873 ], 2874 "previous tests": [ 2875 datetime.datetime(2010, 12, 25, 0, 0, 0, 2876 tzinfo=amp._FixedOffsetTZInfo('+', 3, 19)), 2877 datetime.datetime(2010, 12, 25, 0, 0, 0, 2878 tzinfo=amp._FixedOffsetTZInfo('-', 6, 59)), 2879 ], 2880 } 2881 2882 2883 2884class ListOfOptionalTests(unittest.TestCase): 2885 """ 2886 Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands 2887 via the 'optional' flag. 2888 """ 2889 def test_requiredArgumentWithNoneValueRaisesTypeError(self): 2890 """ 2891 L{ListOf.toBox} raises C{TypeError} when passed a value of C{None} 2892 for the argument. 
2893 """ 2894 stringList = amp.ListOf(amp.Integer()) 2895 self.assertRaises( 2896 TypeError, stringList.toBox, 'omitted', amp.AmpBox(), 2897 {'omitted': None}, None) 2898 2899 2900 def test_optionalArgumentWithNoneValueOmitted(self): 2901 """ 2902 L{ListOf.toBox} silently omits serializing any argument with a 2903 value of C{None} that is designated as optional for the protocol. 2904 """ 2905 stringList = amp.ListOf(amp.Integer(), optional=True) 2906 strings = amp.AmpBox() 2907 stringList.toBox('omitted', strings, {'omitted': None}, None) 2908 self.assertEqual(strings, {}) 2909 2910 2911 def test_requiredArgumentWithKeyMissingRaisesKeyError(self): 2912 """ 2913 L{ListOf.toBox} raises C{KeyError} if the argument's key is not 2914 present in the objects dictionary. 2915 """ 2916 stringList = amp.ListOf(amp.Integer()) 2917 self.assertRaises( 2918 KeyError, stringList.toBox, 'ommited', amp.AmpBox(), 2919 {'someOtherKey': 0}, None) 2920 2921 2922 def test_optionalArgumentWithKeyMissingOmitted(self): 2923 """ 2924 L{ListOf.toBox} silently omits serializing any argument designated 2925 as optional whose key is not present in the objects dictionary. 2926 """ 2927 stringList = amp.ListOf(amp.Integer(), optional=True) 2928 stringList.toBox('ommited', amp.AmpBox(), {'someOtherKey': 0}, None) 2929 2930 2931 def test_omittedOptionalArgumentDeserializesAsNone(self): 2932 """ 2933 L{ListOf.fromBox} correctly reverses the operation performed by 2934 L{ListOf.toBox} for optional arguments. 2935 """ 2936 stringList = amp.ListOf(amp.Integer(), optional=True) 2937 objects = {} 2938 stringList.fromBox('omitted', {}, objects, None) 2939 self.assertEqual(objects, {'omitted': None}) 2940 2941 2942 2943class UNIXStringTransport(object): 2944 """ 2945 An in-memory implementation of L{interfaces.IUNIXTransport} which collects 2946 all data given to it for later inspection. 
2947 2948 @ivar _queue: A C{list} of the data which has been given to this transport, 2949 eg via C{write} or C{sendFileDescriptor}. Elements are two-tuples of a 2950 string (identifying the destination of the data) and the data itself. 2951 """ 2952 implements(interfaces.IUNIXTransport) 2953 2954 def __init__(self, descriptorFuzz): 2955 """ 2956 @param descriptorFuzz: An offset to apply to descriptors. 2957 @type descriptorFuzz: C{int} 2958 """ 2959 self._fuzz = descriptorFuzz 2960 self._queue = [] 2961 2962 2963 def sendFileDescriptor(self, descriptor): 2964 self._queue.append(( 2965 'fileDescriptorReceived', descriptor + self._fuzz)) 2966 2967 2968 def write(self, data): 2969 self._queue.append(('dataReceived', data)) 2970 2971 2972 def writeSequence(self, seq): 2973 for data in seq: 2974 self.write(data) 2975 2976 2977 def loseConnection(self): 2978 self._queue.append(('connectionLost', Failure(ConnectionLost()))) 2979 2980 2981 def getHost(self): 2982 return UNIXAddress('/tmp/some-path') 2983 2984 2985 def getPeer(self): 2986 return UNIXAddress('/tmp/another-path') 2987 2988# Minimal evidence that we got the signatures right 2989verifyClass(interfaces.ITransport, UNIXStringTransport) 2990verifyClass(interfaces.IUNIXTransport, UNIXStringTransport) 2991 2992 2993class DescriptorTests(unittest.TestCase): 2994 """ 2995 Tests for L{amp.Descriptor}, an argument type for passing a file descriptor 2996 over an AMP connection over a UNIX domain socket. 
2997 """ 2998 def setUp(self): 2999 self.fuzz = 3 3000 self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz) 3001 self.protocol = amp.BinaryBoxProtocol( 3002 amp.BoxDispatcher(amp.CommandLocator())) 3003 self.protocol.makeConnection(self.transport) 3004 3005 3006 def test_fromStringProto(self): 3007 """ 3008 L{Descriptor.fromStringProto} constructs a file descriptor value by 3009 extracting a previously received file descriptor corresponding to the 3010 wire value of the argument from the L{_DescriptorExchanger} state of the 3011 protocol passed to it. 3012 3013 This is a whitebox test which involves direct L{_DescriptorExchanger} 3014 state inspection. 3015 """ 3016 argument = amp.Descriptor() 3017 self.protocol.fileDescriptorReceived(5) 3018 self.protocol.fileDescriptorReceived(3) 3019 self.protocol.fileDescriptorReceived(1) 3020 self.assertEqual( 3021 5, argument.fromStringProto("0", self.protocol)) 3022 self.assertEqual( 3023 3, argument.fromStringProto("1", self.protocol)) 3024 self.assertEqual( 3025 1, argument.fromStringProto("2", self.protocol)) 3026 self.assertEqual({}, self.protocol._descriptors) 3027 3028 3029 def test_toStringProto(self): 3030 """ 3031 To send a file descriptor, L{Descriptor.toStringProto} uses the 3032 L{IUNIXTransport.sendFileDescriptor} implementation of the transport of 3033 the protocol passed to it to copy the file descriptor. Each subsequent 3034 descriptor sent over a particular AMP connection is assigned the next 3035 integer value, starting from 0. The base ten string representation of 3036 this value is the byte encoding of the argument. 3037 3038 This is a whitebox test which involves direct L{_DescriptorExchanger} 3039 state inspection and mutation. 
3040 """ 3041 argument = amp.Descriptor() 3042 self.assertEqual("0", argument.toStringProto(2, self.protocol)) 3043 self.assertEqual( 3044 ("fileDescriptorReceived", 2 + self.fuzz), self.transport._queue.pop(0)) 3045 self.assertEqual("1", argument.toStringProto(4, self.protocol)) 3046 self.assertEqual( 3047 ("fileDescriptorReceived", 4 + self.fuzz), self.transport._queue.pop(0)) 3048 self.assertEqual("2", argument.toStringProto(6, self.protocol)) 3049 self.assertEqual( 3050 ("fileDescriptorReceived", 6 + self.fuzz), self.transport._queue.pop(0)) 3051 self.assertEqual({}, self.protocol._descriptors) 3052 3053 3054 def test_roundTrip(self): 3055 """ 3056 L{amp.Descriptor.fromBox} can interpret an L{amp.AmpBox} constructed by 3057 L{amp.Descriptor.toBox} to reconstruct a file descriptor value. 3058 """ 3059 name = "alpha" 3060 strings = {} 3061 descriptor = 17 3062 sendObjects = {name: descriptor} 3063 3064 argument = amp.Descriptor() 3065 argument.toBox(name, strings, sendObjects.copy(), self.protocol) 3066 3067 receiver = amp.BinaryBoxProtocol( 3068 amp.BoxDispatcher(amp.CommandLocator())) 3069 for event in self.transport._queue: 3070 getattr(receiver, event[0])(*event[1:]) 3071 3072 receiveObjects = {} 3073 argument.fromBox(name, strings.copy(), receiveObjects, receiver) 3074 3075 # Make sure we got the descriptor. Adjust by fuzz to be more convincing 3076 # of having gone through L{IUNIXTransport.sendFileDescriptor}, not just 3077 # converted to a string and then parsed back into an integer. 3078 self.assertEqual(descriptor + self.fuzz, receiveObjects[name]) 3079 3080 3081 3082class DateTimeTests(unittest.TestCase): 3083 """ 3084 Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}. 
3085 """ 3086 string = '9876-01-23T12:34:56.054321-01:23' 3087 tzinfo = amp._FixedOffsetTZInfo('-', 1, 23) 3088 object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo) 3089 3090 def test_invalidString(self): 3091 """ 3092 L{amp.DateTime.fromString} raises L{ValueError} when passed a string 3093 which does not represent a timestamp in the proper format. 3094 """ 3095 d = amp.DateTime() 3096 self.assertRaises(ValueError, d.fromString, 'abc') 3097 3098 3099 def test_invalidDatetime(self): 3100 """ 3101 L{amp.DateTime.toString} raises L{ValueError} when passed a naive 3102 datetime (a datetime with no timezone information). 3103 """ 3104 d = amp.DateTime() 3105 self.assertRaises(ValueError, d.toString, 3106 datetime.datetime(2010, 12, 25, 0, 0, 0)) 3107 3108 3109 def test_fromString(self): 3110 """ 3111 L{amp.DateTime.fromString} returns a C{datetime.datetime} with all of 3112 its fields populated from the string passed to it. 3113 """ 3114 argument = amp.DateTime() 3115 value = argument.fromString(self.string) 3116 self.assertEqual(value, self.object) 3117 3118 3119 def test_toString(self): 3120 """ 3121 L{amp.DateTime.toString} returns a C{str} in the wire format including 3122 all of the information from the C{datetime.datetime} passed into it, 3123 including the timezone offset. 3124 """ 3125 argument = amp.DateTime() 3126 value = argument.toString(self.object) 3127 self.assertEqual(value, self.string) 3128 3129 3130 3131class FixedOffsetTZInfoTests(unittest.TestCase): 3132 """ 3133 Tests for L{amp._FixedOffsetTZInfo} and L{amp.utc}. 3134 """ 3135 3136 def test_tzname(self): 3137 """ 3138 L{amp.utc.tzname} returns C{"+00:00"}. 3139 """ 3140 self.assertEqual(amp.utc.tzname(None), '+00:00') 3141 3142 3143 def test_dst(self): 3144 """ 3145 L{amp.utc.dst} returns a zero timedelta. 3146 """ 3147 self.assertEqual(amp.utc.dst(None), datetime.timedelta(0)) 3148 3149 3150 def test_utcoffset(self): 3151 """ 3152 L{amp.utc.utcoffset} returns a zero timedelta. 
3153 """ 3154 self.assertEqual(amp.utc.utcoffset(None), datetime.timedelta(0)) 3155 3156 3157 def test_badSign(self): 3158 """ 3159 L{amp._FixedOffsetTZInfo} raises L{ValueError} if passed an offset sign 3160 other than C{'+'} or C{'-'}. 3161 """ 3162 self.assertRaises(ValueError, amp._FixedOffsetTZInfo, '?', 0, 0) 3163 3164 3165 3166if not interfaces.IReactorSSL.providedBy(reactor): 3167 skipMsg = 'This test case requires SSL support in the reactor' 3168 TLSTest.skip = skipMsg 3169 LiveFireTLSTestCase.skip = skipMsg 3170 PlainVanillaLiveFire.skip = skipMsg 3171 WithServerTLSVerification.skip = skipMsg 3172