1# Copyright (c) 2005 Divmod, Inc. 2# Copyright (c) Twisted Matrix Laboratories. 3# See LICENSE for details. 4 5""" 6Tests for L{twisted.protocols.amp}. 7""" 8 9 10import datetime 11import decimal 12from typing import Dict, Type 13from unittest import skipIf 14 15from zope.interface import implementer 16from zope.interface.verify import verifyClass, verifyObject 17 18from twisted.internet import address, defer, error, interfaces, protocol, reactor 19from twisted.protocols import amp 20from twisted.python import filepath 21from twisted.python.failure import Failure 22from twisted.test import iosim 23from twisted.test.proto_helpers import StringTransport 24from twisted.trial.unittest import TestCase 25 26try: 27 from twisted.internet import ssl as _ssl 28except ImportError: 29 ssl = None 30else: 31 if not _ssl.supported: 32 ssl = None 33 else: 34 ssl = _ssl 35 36if ssl is None: 37 skipSSL = True 38else: 39 skipSSL = False 40 41if not interfaces.IReactorSSL.providedBy(reactor): 42 reactorLacksSSL = True 43else: 44 reactorLacksSSL = False 45 46 47tz = amp._FixedOffsetTZInfo.fromSignHoursMinutes 48 49 50class TestProto(protocol.Protocol): 51 """ 52 A trivial protocol for use in testing where a L{Protocol} is expected. 53 54 @ivar instanceId: the id of this instance 55 @ivar onConnLost: deferred that will fired when the connection is lost 56 @ivar dataToSend: data to send on the protocol 57 """ 58 59 instanceCount = 0 60 61 def __init__(self, onConnLost, dataToSend): 62 assert isinstance(dataToSend, bytes), repr(dataToSend) 63 self.onConnLost = onConnLost 64 self.dataToSend = dataToSend 65 self.instanceId = TestProto.instanceCount 66 TestProto.instanceCount = TestProto.instanceCount + 1 67 68 def connectionMade(self): 69 self.data = [] 70 self.transport.write(self.dataToSend) 71 72 def dataReceived(self, bytes): 73 self.data.append(bytes) 74 75 def connectionLost(self, reason): 76 self.onConnLost.callback(self.data) 77 78 def __repr__(self) -> str: 79 """ 80 Custom repr for testing to avoid coupling amp tests with repr from 81 L{Protocol} 82 83 Returns a string which contains a unique identifier that can be looked 84 up using the instanceId property:: 85 86 <TestProto #3> 87 """ 88 return "<TestProto #%d>" % (self.instanceId,) 89 90 91class SimpleSymmetricProtocol(amp.AMP): 92 def sendHello(self, text): 93 return self.callRemoteString(b"hello", hello=text) 94 95 def amp_HELLO(self, box): 96 return amp.Box(hello=box[b"hello"]) 97 98 99class UnfriendlyGreeting(Exception): 100 """Greeting was insufficiently kind.""" 101 102 103class DeathThreat(Exception): 104 """Greeting was insufficiently kind.""" 105 106 107class UnknownProtocol(Exception): 108 """Asked to switch to the wrong protocol.""" 109 110 111class TransportPeer(amp.Argument): 112 # this serves as some informal documentation for how to get variables from 113 # the protocol or your environment and pass them to methods as arguments. 114 def retrieve(self, d, name, proto): 115 return b"" 116 117 def fromStringProto(self, notAString, proto): 118 return proto.transport.getPeer() 119 120 def toBox(self, name, strings, objects, proto): 121 return 122 123 124class Hello(amp.Command): 125 126 commandName = b"hello" 127 128 arguments = [ 129 (b"hello", amp.String()), 130 (b"optional", amp.Boolean(optional=True)), 131 (b"print", amp.Unicode(optional=True)), 132 (b"from", TransportPeer(optional=True)), 133 (b"mixedCase", amp.String(optional=True)), 134 (b"dash-arg", amp.String(optional=True)), 135 (b"underscore_arg", amp.String(optional=True)), 136 ] 137 138 response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))] 139 140 errors: Dict[Type[Exception], bytes] = {UnfriendlyGreeting: b"UNFRIENDLY"} 141 142 fatalErrors: Dict[Type[Exception], bytes] = {DeathThreat: b"DEAD"} 143 144 145class NoAnswerHello(Hello): 146 commandName = Hello.commandName 147 requiresAnswer = False 148 149 150class FutureHello(amp.Command): 151 commandName = b"hello" 152 153 arguments = [ 154 (b"hello", amp.String()), 155 (b"optional", amp.Boolean(optional=True)), 156 (b"print", amp.Unicode(optional=True)), 157 (b"from", TransportPeer(optional=True)), 158 (b"bonus", amp.String(optional=True)), # addt'l arguments 159 # should generally be 160 # added at the end, and 161 # be optional... 162 ] 163 164 response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))] 165 166 errors = {UnfriendlyGreeting: b"UNFRIENDLY"} 167 168 169class WTF(amp.Command): 170 """ 171 An example of an invalid command. 172 """ 173 174 175class BrokenReturn(amp.Command): 176 """An example of a perfectly good command, but the handler is going to return 177 None... 178 """ 179 180 commandName = b"broken_return" 181 182 183class Goodbye(amp.Command): 184 # commandName left blank on purpose: this tests implicit command names. 185 response = [(b"goodbye", amp.String())] 186 responseType = amp.QuitBox 187 188 189class WaitForever(amp.Command): 190 commandName = b"wait_forever" 191 192 193class GetList(amp.Command): 194 commandName = b"getlist" 195 arguments = [(b"length", amp.Integer())] 196 response = [(b"body", amp.AmpList([(b"x", amp.Integer())]))] 197 198 199class DontRejectMe(amp.Command): 200 commandName = b"dontrejectme" 201 arguments = [ 202 (b"magicWord", amp.Unicode()), 203 (b"list", amp.AmpList([(b"name", amp.Unicode())], optional=True)), 204 ] 205 response = [(b"response", amp.Unicode())] 206 207 208class SecuredPing(amp.Command): 209 # XXX TODO: actually make this refuse to send over an insecure connection 210 response = [(b"pinged", amp.Boolean())] 211 212 213class TestSwitchProto(amp.ProtocolSwitchCommand): 214 commandName = b"Switch-Proto" 215 216 arguments = [ 217 (b"name", amp.String()), 218 ] 219 errors = {UnknownProtocol: b"UNKNOWN"} 220 221 222class SingleUseFactory(protocol.ClientFactory): 223 def __init__(self, proto): 224 self.proto = proto 225 self.proto.factory = self 226 227 def buildProtocol(self, addr): 228 p, self.proto = self.proto, None 229 return p 230 231 reasonFailed = None 232 233 def clientConnectionFailed(self, connector, reason): 234 self.reasonFailed = reason 235 return 236 237 238THING_I_DONT_UNDERSTAND = b"gwebol nargo" 239 240 241class ThingIDontUnderstandError(Exception): 242 pass 243 244 245class FactoryNotifier(amp.AMP): 246 factory = None 247 248 def connectionMade(self): 249 if self.factory is not None: 250 self.factory.theProto = self 251 if hasattr(self.factory, "onMade"): 252 self.factory.onMade.callback(None) 253 254 def emitpong(self): 255 from twisted.internet.interfaces import ISSLTransport 256 257 if not ISSLTransport.providedBy(self.transport): 258 raise DeathThreat("only send secure pings over secure channels") 259 return {"pinged": True} 260 261 SecuredPing.responder(emitpong) 262 263 264class SimpleSymmetricCommandProtocol(FactoryNotifier): 265 maybeLater = None 266 267 def __init__(self, onConnLost=None): 268 amp.AMP.__init__(self) 269 self.onConnLost = onConnLost 270 271 def sendHello(self, text): 272 return self.callRemote(Hello, hello=text) 273 274 def sendUnicodeHello(self, text, translation): 275 return self.callRemote(Hello, hello=text, Print=translation) 276 277 greeted = False 278 279 def cmdHello( 280 self, 281 hello, 282 From, 283 optional=None, 284 Print=None, 285 mixedCase=None, 286 dash_arg=None, 287 underscore_arg=None, 288 ): 289 assert From == self.transport.getPeer() 290 if hello == THING_I_DONT_UNDERSTAND: 291 raise ThingIDontUnderstandError() 292 if hello.startswith(b"fuck"): 293 raise UnfriendlyGreeting("Don't be a dick.") 294 if hello == b"die": 295 raise DeathThreat("aieeeeeeeee") 296 result = dict(hello=hello) 297 if Print is not None: 298 result.update(dict(Print=Print)) 299 self.greeted = True 300 return result 301 302 Hello.responder(cmdHello) 303 304 def cmdGetlist(self, length): 305 return {"body": [dict(x=1)] * length} 306 307 GetList.responder(cmdGetlist) 308 309 def okiwont(self, magicWord, list=None): 310 if list is None: 311 response = "list omitted" 312 else: 313 response = "%s accepted" % (list[0]["name"]) 314 return dict(response=response) 315 316 DontRejectMe.responder(okiwont) 317 318 def waitforit(self): 319 self.waiting = defer.Deferred() 320 return self.waiting 321 322 WaitForever.responder(waitforit) 323 324 def saybye(self): 325 return dict(goodbye=b"everyone") 326 327 Goodbye.responder(saybye) 328 329 def switchToTestProtocol(self, fail=False): 330 if fail: 331 name = b"no-proto" 332 else: 333 name = b"test-proto" 334 p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA) 335 return self.callRemote( 336 TestSwitchProto, SingleUseFactory(p), name=name 337 ).addCallback(lambda ign: p) 338 339 def switchit(self, name): 340 if name == b"test-proto": 341 return TestProto(self.onConnLost, SWITCH_SERVER_DATA) 342 raise UnknownProtocol(name) 343 344 TestSwitchProto.responder(switchit) 345 346 def donothing(self): 347 return None 348 349 BrokenReturn.responder(donothing) 350 351 352class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol): 353 def switchit(self, name): 354 if name == b"test-proto": 355 self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA) 356 self.maybeLater = defer.Deferred() 357 return self.maybeLater 358 359 TestSwitchProto.responder(switchit) 360 361 362class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): 363 def badResponder( 364 self, 365 hello, 366 From, 367 optional=None, 368 Print=None, 369 mixedCase=None, 370 dash_arg=None, 371 underscore_arg=None, 372 ): 373 """ 374 This responder does nothing and forgets to return a dictionary. 375 """ 376 377 NoAnswerHello.responder(badResponder) 378 379 380class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol): 381 def goodNoAnswerResponder( 382 self, 383 hello, 384 From, 385 optional=None, 386 Print=None, 387 mixedCase=None, 388 dash_arg=None, 389 underscore_arg=None, 390 ): 391 return dict(hello=hello + b"-noanswer") 392 393 NoAnswerHello.responder(goodNoAnswerResponder) 394 395 396def connectedServerAndClient( 397 ServerClass=SimpleSymmetricProtocol, ClientClass=SimpleSymmetricProtocol, *a, **kw 398): 399 """Returns a 3-tuple: (client, server, pump)""" 400 return iosim.connectedServerAndClient(ServerClass, ClientClass, *a, **kw) 401 402 403class TotallyDumbProtocol(protocol.Protocol): 404 buf = b"" 405 406 def dataReceived(self, data): 407 self.buf += data 408 409 410class LiteralAmp(amp.AMP): 411 def __init__(self): 412 self.boxes = [] 413 414 def ampBoxReceived(self, box): 415 self.boxes.append(box) 416 return 417 418 419class AmpBoxTests(TestCase): 420 """ 421 Test a few essential properties of AMP boxes, mostly with respect to 422 serialization correctness. 423 """ 424 425 def test_serializeStr(self): 426 """ 427 Make sure that strs serialize to strs. 428 """ 429 a = amp.AmpBox(key=b"value") 430 self.assertEqual(type(a.serialize()), bytes) 431 432 def test_serializeUnicodeKeyRaises(self): 433 """ 434 Verify that TypeError is raised when trying to serialize Unicode keys. 435 """ 436 a = amp.AmpBox(**{"key": "value"}) 437 self.assertRaises(TypeError, a.serialize) 438 439 def test_serializeUnicodeValueRaises(self): 440 """ 441 Verify that TypeError is raised when trying to serialize Unicode 442 values. 443 """ 444 a = amp.AmpBox(key="value") 445 self.assertRaises(TypeError, a.serialize) 446 447 448class ParsingTests(TestCase): 449 def test_booleanValues(self): 450 """ 451 Verify that the Boolean parser parses 'True' and 'False', but nothing 452 else. 453 """ 454 b = amp.Boolean() 455 self.assertTrue(b.fromString(b"True")) 456 self.assertFalse(b.fromString(b"False")) 457 self.assertRaises(TypeError, b.fromString, b"ninja") 458 self.assertRaises(TypeError, b.fromString, b"true") 459 self.assertRaises(TypeError, b.fromString, b"TRUE") 460 self.assertEqual(b.toString(True), b"True") 461 self.assertEqual(b.toString(False), b"False") 462 463 def test_pathValueRoundTrip(self): 464 """ 465 Verify the 'Path' argument can parse and emit a file path. 466 """ 467 fp = filepath.FilePath(self.mktemp()) 468 p = amp.Path() 469 s = p.toString(fp) 470 v = p.fromString(s) 471 self.assertIsNot(fp, v) # sanity check 472 self.assertEqual(fp, v) 473 474 def test_sillyEmptyThing(self): 475 """ 476 Test that empty boxes raise an error; they aren't supposed to be sent 477 on purpose. 478 """ 479 a = amp.AMP() 480 return self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box()) 481 482 def test_ParsingRoundTrip(self): 483 """ 484 Verify that various kinds of data make it through the encode/parse 485 round-trip unharmed. 486 """ 487 c, s, p = connectedServerAndClient( 488 ClientClass=LiteralAmp, ServerClass=LiteralAmp 489 ) 490 491 SIMPLE = (b"simple", b"test") 492 CE = (b"ceq", b": ") 493 CR = (b"crtest", b"test\r") 494 LF = (b"lftest", b"hello\n") 495 NEWLINE = (b"newline", b"test\r\none\r\ntwo") 496 NEWLINE2 = (b"newline2", b"test\r\none\r\n two") 497 BODYTEST = (b"body", b"blah\r\n\r\ntesttest") 498 499 testData = [ 500 [SIMPLE], 501 [SIMPLE, BODYTEST], 502 [SIMPLE, CE], 503 [SIMPLE, CR], 504 [SIMPLE, CE, CR, LF], 505 [CE, CR, LF], 506 [SIMPLE, NEWLINE, CE, NEWLINE2], 507 [BODYTEST, SIMPLE, NEWLINE], 508 ] 509 510 for test in testData: 511 jb = amp.Box() 512 jb.update(dict(test)) 513 jb._sendTo(c) 514 p.flush() 515 self.assertEqual(s.boxes[-1], jb) 516 517 518class FakeLocator: 519 """ 520 This is a fake implementation of the interface implied by 521 L{CommandLocator}. 522 """ 523 524 def __init__(self): 525 """ 526 Remember the given keyword arguments as a set of responders. 527 """ 528 self.commands = {} 529 530 def locateResponder(self, commandName): 531 """ 532 Look up and return a function passed as a keyword argument of the given 533 name to the constructor. 534 """ 535 return self.commands[commandName] 536 537 538class FakeSender: 539 """ 540 This is a fake implementation of the 'box sender' interface implied by 541 L{AMP}. 542 """ 543 544 def __init__(self): 545 """ 546 Create a fake sender and initialize the list of received boxes and 547 unhandled errors. 548 """ 549 self.sentBoxes = [] 550 self.unhandledErrors = [] 551 self.expectedErrors = 0 552 553 def expectError(self): 554 """ 555 Expect one error, so that the test doesn't fail. 556 """ 557 self.expectedErrors += 1 558 559 def sendBox(self, box): 560 """ 561 Accept a box, but don't do anything. 562 """ 563 self.sentBoxes.append(box) 564 565 def unhandledError(self, failure): 566 """ 567 Deal with failures by instantly re-raising them for easier debugging. 568 """ 569 self.expectedErrors -= 1 570 if self.expectedErrors < 0: 571 failure.raiseException() 572 else: 573 self.unhandledErrors.append(failure) 574 575 576class CommandDispatchTests(TestCase): 577 """ 578 The AMP CommandDispatcher class dispatches converts AMP boxes into commands 579 and responses using Command.responder decorator. 580 581 Note: Originally, AMP's factoring was such that many tests for this 582 functionality are now implemented as full round-trip tests in L{AMPTests}. 583 Future tests should be written at this level instead, to ensure API 584 compatibility and to provide more granular, readable units of test 585 coverage. 586 """ 587 588 def setUp(self): 589 """ 590 Create a dispatcher to use. 591 """ 592 self.locator = FakeLocator() 593 self.sender = FakeSender() 594 self.dispatcher = amp.BoxDispatcher(self.locator) 595 self.dispatcher.startReceivingBoxes(self.sender) 596 597 def test_receivedAsk(self): 598 """ 599 L{CommandDispatcher.ampBoxReceived} should locate the appropriate 600 command in its responder lookup, based on the '_ask' key. 601 """ 602 received = [] 603 604 def thunk(box): 605 received.append(box) 606 return amp.Box({"hello": "goodbye"}) 607 608 input = amp.Box(_command="hello", _ask="test-command-id", hello="world") 609 self.locator.commands["hello"] = thunk 610 self.dispatcher.ampBoxReceived(input) 611 self.assertEqual(received, [input]) 612 613 def test_sendUnhandledError(self): 614 """ 615 L{CommandDispatcher} should relay its unhandled errors in responding to 616 boxes to its boxSender. 617 """ 618 err = RuntimeError("something went wrong, oh no") 619 self.sender.expectError() 620 self.dispatcher.unhandledError(Failure(err)) 621 self.assertEqual(len(self.sender.unhandledErrors), 1) 622 self.assertEqual(self.sender.unhandledErrors[0].value, err) 623 624 def test_unhandledSerializationError(self): 625 """ 626 Errors during serialization ought to be relayed to the sender's 627 unhandledError method. 628 """ 629 err = RuntimeError("something undefined went wrong") 630 631 def thunk(result): 632 class BrokenBox(amp.Box): 633 def _sendTo(self, proto): 634 raise err 635 636 return BrokenBox() 637 638 self.locator.commands["hello"] = thunk 639 input = amp.Box(_command="hello", _ask="test-command-id", hello="world") 640 self.sender.expectError() 641 self.dispatcher.ampBoxReceived(input) 642 self.assertEqual(len(self.sender.unhandledErrors), 1) 643 self.assertEqual(self.sender.unhandledErrors[0].value, err) 644 645 def test_callRemote(self): 646 """ 647 L{CommandDispatcher.callRemote} should emit a properly formatted '_ask' 648 box to its boxSender and record an outstanding L{Deferred}. When a 649 corresponding '_answer' packet is received, the L{Deferred} should be 650 fired, and the results translated via the given L{Command}'s response 651 de-serialization. 652 """ 653 D = self.dispatcher.callRemote(Hello, hello=b"world") 654 self.assertEqual( 655 self.sender.sentBoxes, 656 [amp.AmpBox(_command=b"hello", _ask=b"1", hello=b"world")], 657 ) 658 answers = [] 659 D.addCallback(answers.append) 660 self.assertEqual(answers, []) 661 self.dispatcher.ampBoxReceived( 662 amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"}) 663 ) 664 self.assertEqual(answers, [dict(hello=b"yay", Print="ignored")]) 665 666 def _localCallbackErrorLoggingTest(self, callResult): 667 """ 668 Verify that C{callResult} completes with a L{None} result and that an 669 unhandled error has been logged. 670 """ 671 finalResult = [] 672 callResult.addBoth(finalResult.append) 673 674 self.assertEqual(1, len(self.sender.unhandledErrors)) 675 self.assertIsInstance(self.sender.unhandledErrors[0].value, ZeroDivisionError) 676 677 self.assertEqual([None], finalResult) 678 679 def test_callRemoteSuccessLocalCallbackErrorLogging(self): 680 """ 681 If the last callback on the L{Deferred} returned by C{callRemote} (added 682 by application code calling C{callRemote}) fails, the failure is passed 683 to the sender's C{unhandledError} method. 684 """ 685 self.sender.expectError() 686 687 callResult = self.dispatcher.callRemote(Hello, hello=b"world") 688 callResult.addCallback(lambda result: 1 // 0) 689 690 self.dispatcher.ampBoxReceived( 691 amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"}) 692 ) 693 694 self._localCallbackErrorLoggingTest(callResult) 695 696 def test_callRemoteErrorLocalCallbackErrorLogging(self): 697 """ 698 Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the 699 case where the L{Deferred} returned by C{callRemote} fails. 700 """ 701 self.sender.expectError() 702 703 callResult = self.dispatcher.callRemote(Hello, hello=b"world") 704 callResult.addErrback(lambda result: 1 // 0) 705 706 self.dispatcher.ampBoxReceived( 707 amp.AmpBox( 708 { 709 b"_error": b"1", 710 b"_error_code": b"bugs", 711 b"_error_description": b"stuff", 712 } 713 ) 714 ) 715 716 self._localCallbackErrorLoggingTest(callResult) 717 718 719class SimpleGreeting(amp.Command): 720 """ 721 A very simple greeting command that uses a few basic argument types. 722 """ 723 724 commandName = b"simple" 725 arguments = [(b"greeting", amp.Unicode()), (b"cookie", amp.Integer())] 726 response = [(b"cookieplus", amp.Integer())] 727 728 729class TestLocator(amp.CommandLocator): 730 """ 731 A locator which implements a responder to the 'simple' command. 732 """ 733 734 def __init__(self): 735 self.greetings = [] 736 737 def greetingResponder(self, greeting, cookie): 738 self.greetings.append((greeting, cookie)) 739 return dict(cookieplus=cookie + 3) 740 741 greetingResponder = SimpleGreeting.responder(greetingResponder) 742 743 744class OverridingLocator(TestLocator): 745 """ 746 A locator which overrides the responder to the 'simple' command. 747 """ 748 749 def greetingResponder(self, greeting, cookie): 750 """ 751 Return a different cookieplus than L{TestLocator.greetingResponder}. 752 """ 753 self.greetings.append((greeting, cookie)) 754 return dict(cookieplus=cookie + 4) 755 756 greetingResponder = SimpleGreeting.responder(greetingResponder) 757 758 759class InheritingLocator(OverridingLocator): 760 """ 761 This locator should inherit the responder from L{OverridingLocator}. 762 """ 763 764 765class OverrideLocatorAMP(amp.AMP): 766 def __init__(self): 767 amp.AMP.__init__(self) 768 self.customResponder = object() 769 self.expectations = {b"custom": self.customResponder} 770 self.greetings = [] 771 772 def lookupFunction(self, name): 773 """ 774 Override the deprecated lookupFunction function. 775 """ 776 if name in self.expectations: 777 result = self.expectations[name] 778 return result 779 else: 780 return super().lookupFunction(name) 781 782 def greetingResponder(self, greeting, cookie): 783 self.greetings.append((greeting, cookie)) 784 return dict(cookieplus=cookie + 3) 785 786 greetingResponder = SimpleGreeting.responder(greetingResponder) 787 788 789class CommandLocatorTests(TestCase): 790 """ 791 The CommandLocator should enable users to specify responders to commands as 792 functions that take structured objects, annotated with metadata. 793 """ 794 795 def _checkSimpleGreeting(self, locatorClass, expected): 796 """ 797 Check that a locator of type C{locatorClass} finds a responder 798 for command named I{simple} and that the found responder answers 799 with the C{expected} result to a C{SimpleGreeting<"ni hao", 5>} 800 command. 801 """ 802 locator = locatorClass() 803 responderCallable = locator.locateResponder(b"simple") 804 result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5")) 805 806 def done(values): 807 self.assertEqual(values, amp.AmpBox(cookieplus=b"%d" % (expected,))) 808 809 return result.addCallback(done) 810 811 def test_responderDecorator(self): 812 """ 813 A method on a L{CommandLocator} subclass decorated with a L{Command} 814 subclass's L{responder} decorator should be returned from 815 locateResponder, wrapped in logic to serialize and deserialize its 816 arguments. 817 """ 818 return self._checkSimpleGreeting(TestLocator, 8) 819 820 def test_responderOverriding(self): 821 """ 822 L{CommandLocator} subclasses can override a responder inherited from 823 a base class by using the L{Command.responder} decorator to register 824 a new responder method. 825 """ 826 return self._checkSimpleGreeting(OverridingLocator, 9) 827 828 def test_responderInheritance(self): 829 """ 830 Responder lookup follows the same rules as normal method lookup 831 rules, particularly with respect to inheritance. 832 """ 833 return self._checkSimpleGreeting(InheritingLocator, 9) 834 835 def test_lookupFunctionDeprecatedOverride(self): 836 """ 837 Subclasses which override locateResponder under its old name, 838 lookupFunction, should have the override invoked instead. (This tests 839 an AMP subclass, because in the version of the code that could invoke 840 this deprecated code path, there was no L{CommandLocator}.) 841 """ 842 locator = OverrideLocatorAMP() 843 customResponderObject = self.assertWarns( 844 PendingDeprecationWarning, 845 "Override locateResponder, not lookupFunction.", 846 __file__, 847 lambda: locator.locateResponder(b"custom"), 848 ) 849 self.assertEqual(locator.customResponder, customResponderObject) 850 # Make sure upcalling works too 851 normalResponderObject = self.assertWarns( 852 PendingDeprecationWarning, 853 "Override locateResponder, not lookupFunction.", 854 __file__, 855 lambda: locator.locateResponder(b"simple"), 856 ) 857 result = normalResponderObject(amp.Box(greeting=b"ni hao", cookie=b"5")) 858 859 def done(values): 860 self.assertEqual(values, amp.AmpBox(cookieplus=b"8")) 861 862 return result.addCallback(done) 863 864 def test_lookupFunctionDeprecatedInvoke(self): 865 """ 866 Invoking locateResponder under its old name, lookupFunction, should 867 emit a deprecation warning, but do the same thing. 868 """ 869 locator = TestLocator() 870 responderCallable = self.assertWarns( 871 PendingDeprecationWarning, 872 "Call locateResponder, not lookupFunction.", 873 __file__, 874 lambda: locator.lookupFunction(b"simple"), 875 ) 876 result = responderCallable(amp.Box(greeting=b"ni hao", cookie=b"5")) 877 878 def done(values): 879 self.assertEqual(values, amp.AmpBox(cookieplus=b"8")) 880 881 return result.addCallback(done) 882 883 884SWITCH_CLIENT_DATA = b"Success!" 885SWITCH_SERVER_DATA = b"No, really. Success." 886 887 888class BinaryProtocolTests(TestCase): 889 """ 890 Tests for L{amp.BinaryBoxProtocol}. 891 892 @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender} 893 which was passed to it. 894 """ 895 896 def setUp(self): 897 """ 898 Keep track of all boxes received by this test in its capacity as an 899 L{IBoxReceiver} implementor. 900 """ 901 self.boxes = [] 902 self.data = [] 903 904 def startReceivingBoxes(self, sender): 905 """ 906 Implement L{IBoxReceiver.startReceivingBoxes} to just remember the 907 value passed in. 908 """ 909 self._boxSender = sender 910 911 def ampBoxReceived(self, box): 912 """ 913 A box was received by the protocol. 914 """ 915 self.boxes.append(box) 916 917 stopReason = None 918 919 def stopReceivingBoxes(self, reason): 920 """ 921 Record the reason that we stopped receiving boxes. 922 """ 923 self.stopReason = reason 924 925 # fake ITransport 926 def getPeer(self): 927 return "no peer" 928 929 def getHost(self): 930 return "no host" 931 932 def write(self, data): 933 self.assertIsInstance(data, bytes) 934 self.data.append(data) 935 936 def test_startReceivingBoxes(self): 937 """ 938 When L{amp.BinaryBoxProtocol} is connected to a transport, it calls 939 C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the 940 L{IBoxSender} parameter. 941 """ 942 protocol = amp.BinaryBoxProtocol(self) 943 protocol.makeConnection(None) 944 self.assertIs(self._boxSender, protocol) 945 946 def test_sendBoxInStartReceivingBoxes(self): 947 """ 948 The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is 949 connected to a transport can call C{sendBox} on the L{IBoxSender} 950 passed to it before C{startReceivingBoxes} returns and have that box 951 sent. 952 """ 953 954 class SynchronouslySendingReceiver: 955 def startReceivingBoxes(self, sender): 956 sender.sendBox(amp.Box({b"foo": b"bar"})) 957 958 transport = StringTransport() 959 protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver()) 960 protocol.makeConnection(transport) 961 self.assertEqual(transport.value(), b"\x00\x03foo\x00\x03bar\x00\x00") 962 963 def test_receiveBoxStateMachine(self): 964 """ 965 When a binary box protocol receives: 966 * a key 967 * a value 968 * an empty string 969 it should emit a box and send it to its boxReceiver. 970 """ 971 a = amp.BinaryBoxProtocol(self) 972 a.stringReceived(b"hello") 973 a.stringReceived(b"world") 974 a.stringReceived(b"") 975 self.assertEqual(self.boxes, [amp.AmpBox(hello=b"world")]) 976 977 def test_firstBoxFirstKeyExcessiveLength(self): 978 """ 979 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 980 the first a key it receives is larger than 255. 981 """ 982 transport = StringTransport() 983 protocol = amp.BinaryBoxProtocol(self) 984 protocol.makeConnection(transport) 985 protocol.dataReceived(b"\x01\x00") 986 self.assertTrue(transport.disconnecting) 987 988 def test_firstBoxSubsequentKeyExcessiveLength(self): 989 """ 990 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 991 a subsequent key in the first box it receives is larger than 255. 992 """ 993 transport = StringTransport() 994 protocol = amp.BinaryBoxProtocol(self) 995 protocol.makeConnection(transport) 996 protocol.dataReceived(b"\x00\x01k\x00\x01v") 997 self.assertFalse(transport.disconnecting) 998 protocol.dataReceived(b"\x01\x00") 999 self.assertTrue(transport.disconnecting) 1000 1001 def test_subsequentBoxFirstKeyExcessiveLength(self): 1002 """ 1003 L{amp.BinaryBoxProtocol} drops its connection if the length prefix for 1004 the first key in a subsequent box it receives is larger than 255. 1005 """ 1006 transport = StringTransport() 1007 protocol = amp.BinaryBoxProtocol(self) 1008 protocol.makeConnection(transport) 1009 protocol.dataReceived(b"\x00\x01k\x00\x01v\x00\x00") 1010 self.assertFalse(transport.disconnecting) 1011 protocol.dataReceived(b"\x01\x00") 1012 self.assertTrue(transport.disconnecting) 1013 1014 def test_excessiveKeyFailure(self): 1015 """ 1016 If L{amp.BinaryBoxProtocol} disconnects because it received a key 1017 length prefix which was too large, the L{IBoxReceiver}'s 1018 C{stopReceivingBoxes} method is called with a L{TooLong} failure. 1019 """ 1020 protocol = amp.BinaryBoxProtocol(self) 1021 protocol.makeConnection(StringTransport()) 1022 protocol.dataReceived(b"\x01\x00") 1023 protocol.connectionLost( 1024 Failure(error.ConnectionDone("simulated connection done")) 1025 ) 1026 self.stopReason.trap(amp.TooLong) 1027 self.assertTrue(self.stopReason.value.isKey) 1028 self.assertFalse(self.stopReason.value.isLocal) 1029 self.assertIsNone(self.stopReason.value.value) 1030 self.assertIsNone(self.stopReason.value.keyName) 1031 1032 def test_unhandledErrorWithTransport(self): 1033 """ 1034 L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it 1035 and disconnects its transport. 1036 """ 1037 transport = StringTransport() 1038 protocol = amp.BinaryBoxProtocol(self) 1039 protocol.makeConnection(transport) 1040 protocol.unhandledError(Failure(RuntimeError("Fake error"))) 1041 self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError))) 1042 self.assertTrue(transport.disconnecting) 1043 1044 def test_unhandledErrorWithoutTransport(self): 1045 """ 1046 L{amp.BinaryBoxProtocol.unhandledError} completes without error when 1047 there is no associated transport. 1048 """ 1049 protocol = amp.BinaryBoxProtocol(self) 1050 protocol.makeConnection(StringTransport()) 1051 protocol.connectionLost(Failure(Exception("Simulated"))) 1052 protocol.unhandledError(Failure(RuntimeError("Fake error"))) 1053 self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError))) 1054 1055 def test_receiveBoxData(self): 1056 """ 1057 When a binary box protocol receives the serialized form of an AMP box, 1058 it should emit a similar box to its boxReceiver. 1059 """ 1060 a = amp.BinaryBoxProtocol(self) 1061 a.dataReceived( 1062 amp.Box( 1063 {b"testKey": b"valueTest", b"anotherKey": b"anotherValue"} 1064 ).serialize() 1065 ) 1066 self.assertEqual( 1067 self.boxes, 1068 [amp.Box({b"testKey": b"valueTest", b"anotherKey": b"anotherValue"})], 1069 ) 1070 1071 def test_receiveLongerBoxData(self): 1072 """ 1073 An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with 1074 values of up to (2 ** 16 - 1) bytes. 1075 """ 1076 length = 2 ** 16 - 1 1077 value = b"x" * length 1078 transport = StringTransport() 1079 protocol = amp.BinaryBoxProtocol(self) 1080 protocol.makeConnection(transport) 1081 protocol.dataReceived(amp.Box({"k": value}).serialize()) 1082 self.assertEqual(self.boxes, [amp.Box({"k": value})]) 1083 self.assertFalse(transport.disconnecting) 1084 1085 def test_sendBox(self): 1086 """ 1087 When a binary box protocol sends a box, it should emit the serialized 1088 bytes of that box to its transport. 1089 """ 1090 a = amp.BinaryBoxProtocol(self) 1091 a.makeConnection(self) 1092 aBox = amp.Box({b"testKey": b"valueTest", b"someData": b"hello"}) 1093 a.makeConnection(self) 1094 a.sendBox(aBox) 1095 self.assertEqual(b"".join(self.data), aBox.serialize()) 1096 1097 def test_connectionLostStopSendingBoxes(self): 1098 """ 1099 When a binary box protocol loses its connection, it should notify its 1100 box receiver that it has stopped receiving boxes. 1101 """ 1102 a = amp.BinaryBoxProtocol(self) 1103 a.makeConnection(self) 1104 connectionFailure = Failure(RuntimeError()) 1105 a.connectionLost(connectionFailure) 1106 self.assertIs(self.stopReason, connectionFailure) 1107 1108 def test_protocolSwitch(self): 1109 """ 1110 L{BinaryBoxProtocol} has the capacity to switch to a different protocol 1111 on a box boundary. When a protocol is in the process of switching, it 1112 cannot receive traffic. 1113 """ 1114 otherProto = TestProto(None, b"outgoing data") 1115 test = self 1116 1117 class SwitchyReceiver: 1118 switched = False 1119 1120 def startReceivingBoxes(self, sender): 1121 pass 1122 1123 def ampBoxReceived(self, box): 1124 test.assertFalse(self.switched, "Should only receive one box!") 1125 self.switched = True 1126 a._lockForSwitch() 1127 a._switchTo(otherProto) 1128 1129 a = amp.BinaryBoxProtocol(SwitchyReceiver()) 1130 anyOldBox = amp.Box({b"include": b"lots", b"of": b"data"}) 1131 a.makeConnection(self) 1132 # Include a 0-length box at the beginning of the next protocol's data, 1133 # to make sure that AMP doesn't eat the data or try to deliver extra 1134 # boxes either... 1135 moreThanOneBox = anyOldBox.serialize() + b"\x00\x00Hello, world!" 1136 a.dataReceived(moreThanOneBox) 1137 self.assertIs(otherProto.transport, self) 1138 self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!") 1139 self.assertEqual(self.data, [b"outgoing data"]) 1140 a.dataReceived(b"more data") 1141 self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!more data") 1142 self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox) 1143 1144 def test_protocolSwitchEmptyBuffer(self): 1145 """ 1146 After switching to a different protocol, if no extra bytes beyond 1147 the switch box were delivered, an empty string is not passed to the 1148 switched protocol's C{dataReceived} method. 1149 """ 1150 a = amp.BinaryBoxProtocol(self) 1151 a.makeConnection(self) 1152 otherProto = TestProto(None, b"") 1153 a._switchTo(otherProto) 1154 self.assertEqual(otherProto.data, []) 1155 1156 def test_protocolSwitchInvalidStates(self): 1157 """ 1158 In order to make sure the protocol never gets any invalid data sent 1159 into the middle of a box, it must be locked for switching before it is 1160 switched. It can only be unlocked if the switch failed, and attempting 1161 to send a box while it is locked should raise an exception. 1162 """ 1163 a = amp.BinaryBoxProtocol(self) 1164 a.makeConnection(self) 1165 sampleBox = amp.Box({b"some": b"data"}) 1166 a._lockForSwitch() 1167 self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox) 1168 a._unlockFromSwitch() 1169 a.sendBox(sampleBox) 1170 self.assertEqual(b"".join(self.data), sampleBox.serialize()) 1171 a._lockForSwitch() 1172 otherProto = TestProto(None, b"outgoing data") 1173 a._switchTo(otherProto) 1174 self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch) 1175 1176 def test_protocolSwitchLoseConnection(self): 1177 """ 1178 When the protocol is switched, it should notify its nested protocol of 1179 disconnection. 1180 """ 1181 1182 class Loser(protocol.Protocol): 1183 reason = None 1184 1185 def connectionLost(self, reason): 1186 self.reason = reason 1187 1188 connectionLoser = Loser() 1189 a = amp.BinaryBoxProtocol(self) 1190 a.makeConnection(self) 1191 a._lockForSwitch() 1192 a._switchTo(connectionLoser) 1193 connectionFailure = Failure(RuntimeError()) 1194 a.connectionLost(connectionFailure) 1195 self.assertEqual(connectionLoser.reason, connectionFailure) 1196 1197 def test_protocolSwitchLoseClientConnection(self): 1198 """ 1199 When the protocol is switched, it should notify its nested client 1200 protocol factory of disconnection. 1201 """ 1202 1203 class ClientLoser: 1204 reason = None 1205 1206 def clientConnectionLost(self, connector, reason): 1207 self.reason = reason 1208 1209 a = amp.BinaryBoxProtocol(self) 1210 connectionLoser = protocol.Protocol() 1211 clientLoser = ClientLoser() 1212 a.makeConnection(self) 1213 a._lockForSwitch() 1214 a._switchTo(connectionLoser, clientLoser) 1215 connectionFailure = Failure(RuntimeError()) 1216 a.connectionLost(connectionFailure) 1217 self.assertEqual(clientLoser.reason, connectionFailure) 1218 1219 1220class AMPTests(TestCase): 1221 def test_interfaceDeclarations(self): 1222 """ 1223 The classes in the amp module ought to implement the interfaces that 1224 are declared for their benefit. 1225 """ 1226 for interface, implementation in [ 1227 (amp.IBoxSender, amp.BinaryBoxProtocol), 1228 (amp.IBoxReceiver, amp.BoxDispatcher), 1229 (amp.IResponderLocator, amp.CommandLocator), 1230 (amp.IResponderLocator, amp.SimpleStringLocator), 1231 (amp.IBoxSender, amp.AMP), 1232 (amp.IBoxReceiver, amp.AMP), 1233 (amp.IResponderLocator, amp.AMP), 1234 ]: 1235 self.assertTrue( 1236 interface.implementedBy(implementation), 1237 f"{implementation} does not implements({interface})", 1238 ) 1239 1240 def test_helloWorld(self): 1241 """ 1242 Verify that a simple command can be sent and its response received with 1243 the simple low-level string-based API. 1244 """ 1245 c, s, p = connectedServerAndClient() 1246 L = [] 1247 HELLO = b"world" 1248 c.sendHello(HELLO).addCallback(L.append) 1249 p.flush() 1250 self.assertEqual(L[0][b"hello"], HELLO) 1251 1252 def test_wireFormatRoundTrip(self): 1253 """ 1254 Verify that mixed-case, underscored and dashed arguments are mapped to 1255 their python names properly. 1256 """ 1257 c, s, p = connectedServerAndClient() 1258 L = [] 1259 HELLO = b"world" 1260 c.sendHello(HELLO).addCallback(L.append) 1261 p.flush() 1262 self.assertEqual(L[0][b"hello"], HELLO) 1263 1264 def test_helloWorldUnicode(self): 1265 """ 1266 Verify that unicode arguments can be encoded and decoded. 1267 """ 1268 c, s, p = connectedServerAndClient( 1269 ServerClass=SimpleSymmetricCommandProtocol, 1270 ClientClass=SimpleSymmetricCommandProtocol, 1271 ) 1272 L = [] 1273 HELLO = b"world" 1274 HELLO_UNICODE = "wor\u1234ld" 1275 c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append) 1276 p.flush() 1277 self.assertEqual(L[0]["hello"], HELLO) 1278 self.assertEqual(L[0]["Print"], HELLO_UNICODE) 1279 1280 def test_callRemoteStringRequiresAnswerFalse(self): 1281 """ 1282 L{BoxDispatcher.callRemoteString} returns L{None} if C{requiresAnswer} 1283 is C{False}. 1284 """ 1285 c, s, p = connectedServerAndClient() 1286 ret = c.callRemoteString(b"WTF", requiresAnswer=False) 1287 self.assertIsNone(ret) 1288 1289 def test_unknownCommandLow(self): 1290 """ 1291 Verify that unknown commands using low-level APIs will be rejected with an 1292 error, but will NOT terminate the connection. 1293 """ 1294 c, s, p = connectedServerAndClient() 1295 L = [] 1296 1297 def clearAndAdd(e): 1298 """ 1299 You can't propagate the error... 1300 """ 1301 e.trap(amp.UnhandledCommand) 1302 return "OK" 1303 1304 c.callRemoteString(b"WTF").addErrback(clearAndAdd).addCallback(L.append) 1305 p.flush() 1306 self.assertEqual(L.pop(), "OK") 1307 HELLO = b"world" 1308 c.sendHello(HELLO).addCallback(L.append) 1309 p.flush() 1310 self.assertEqual(L[0][b"hello"], HELLO) 1311 1312 def test_unknownCommandHigh(self): 1313 """ 1314 Verify that unknown commands using high-level APIs will be rejected with an 1315 error, but will NOT terminate the connection. 1316 """ 1317 c, s, p = connectedServerAndClient() 1318 L = [] 1319 1320 def clearAndAdd(e): 1321 """ 1322 You can't propagate the error... 1323 """ 1324 e.trap(amp.UnhandledCommand) 1325 return "OK" 1326 1327 c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append) 1328 p.flush() 1329 self.assertEqual(L.pop(), "OK") 1330 HELLO = b"world" 1331 c.sendHello(HELLO).addCallback(L.append) 1332 p.flush() 1333 self.assertEqual(L[0][b"hello"], HELLO) 1334 1335 def test_brokenReturnValue(self): 1336 """ 1337 It can be very confusing if you write some code which responds to a 1338 command, but gets the return value wrong. Most commonly you end up 1339 returning None instead of a dictionary. 1340 1341 Verify that if that happens, the framework logs a useful error. 1342 """ 1343 L = [] 1344 SimpleSymmetricCommandProtocol().dispatchCommand( 1345 amp.AmpBox(_command=BrokenReturn.commandName) 1346 ).addErrback(L.append) 1347 L[0].trap(amp.BadLocalReturn) 1348 self.failUnlessIn("None", repr(L[0].value)) 1349 1350 def test_unknownArgument(self): 1351 """ 1352 Verify that unknown arguments are ignored, and not passed to a Python 1353 function which can't accept them. 1354 """ 1355 c, s, p = connectedServerAndClient( 1356 ServerClass=SimpleSymmetricCommandProtocol, 1357 ClientClass=SimpleSymmetricCommandProtocol, 1358 ) 1359 L = [] 1360 HELLO = b"world" 1361 # c.sendHello(HELLO).addCallback(L.append) 1362 c.callRemote( 1363 FutureHello, hello=HELLO, bonus=b"I'm not in the book!" 1364 ).addCallback(L.append) 1365 p.flush() 1366 self.assertEqual(L[0]["hello"], HELLO) 1367 1368 def test_simpleReprs(self): 1369 """ 1370 Verify that the various Box objects repr properly, for debugging. 1371 """ 1372 self.assertEqual(type(repr(amp._SwitchBox("a"))), str) 1373 self.assertEqual(type(repr(amp.QuitBox())), str) 1374 self.assertEqual(type(repr(amp.AmpBox())), str) 1375 self.assertIn("AmpBox", repr(amp.AmpBox())) 1376 1377 def test_innerProtocolInRepr(self): 1378 """ 1379 Verify that L{AMP} objects output their innerProtocol when set. 1380 """ 1381 otherProto = TestProto(None, b"outgoing data") 1382 a = amp.AMP() 1383 a.innerProtocol = otherProto 1384 1385 self.assertEqual( 1386 repr(a), 1387 "<AMP inner <TestProto #%d> at 0x%x>" % (otherProto.instanceId, id(a)), 1388 ) 1389 1390 def test_innerProtocolNotInRepr(self): 1391 """ 1392 Verify that L{AMP} objects do not output 'inner' when no innerProtocol 1393 is set. 1394 """ 1395 a = amp.AMP() 1396 self.assertEqual(repr(a), f"<AMP at 0x{id(a):x}>") 1397 1398 @skipIf(skipSSL, "SSL not available") 1399 def test_simpleSSLRepr(self): 1400 """ 1401 L{amp._TLSBox.__repr__} returns a string. 1402 """ 1403 self.assertEqual(type(repr(amp._TLSBox())), str) 1404 1405 def test_keyTooLong(self): 1406 """ 1407 Verify that a key that is too long will immediately raise a synchronous 1408 exception. 1409 """ 1410 c, s, p = connectedServerAndClient() 1411 x = "H" * (0xFF + 1) 1412 tl = self.assertRaises(amp.TooLong, c.callRemoteString, b"Hello", **{x: b"hi"}) 1413 self.assertTrue(tl.isKey) 1414 self.assertTrue(tl.isLocal) 1415 self.assertIsNone(tl.keyName) 1416 self.assertEqual(tl.value, x.encode("ascii")) 1417 self.assertIn(str(len(x)), repr(tl)) 1418 self.assertIn("key", repr(tl)) 1419 1420 def test_valueTooLong(self): 1421 """ 1422 Verify that attempting to send value longer than 64k will immediately 1423 raise an exception. 1424 """ 1425 c, s, p = connectedServerAndClient() 1426 x = b"H" * (0xFFFF + 1) 1427 tl = self.assertRaises(amp.TooLong, c.sendHello, x) 1428 p.flush() 1429 self.assertFalse(tl.isKey) 1430 self.assertTrue(tl.isLocal) 1431 self.assertEqual(tl.keyName, b"hello") 1432 self.failUnlessIdentical(tl.value, x) 1433 self.assertIn(str(len(x)), repr(tl)) 1434 self.assertIn("value", repr(tl)) 1435 self.assertIn("hello", repr(tl)) 1436 1437 def test_helloWorldCommand(self): 1438 """ 1439 Verify that a simple command can be sent and its response received with 1440 the high-level value parsing API. 1441 """ 1442 c, s, p = connectedServerAndClient( 1443 ServerClass=SimpleSymmetricCommandProtocol, 1444 ClientClass=SimpleSymmetricCommandProtocol, 1445 ) 1446 L = [] 1447 HELLO = b"world" 1448 c.sendHello(HELLO).addCallback(L.append) 1449 p.flush() 1450 self.assertEqual(L[0]["hello"], HELLO) 1451 1452 def test_helloErrorHandling(self): 1453 """ 1454 Verify that if a known error type is raised and handled, it will be 1455 properly relayed to the other end of the connection and translated into 1456 an exception, and no error will be logged. 1457 """ 1458 L = [] 1459 c, s, p = connectedServerAndClient( 1460 ServerClass=SimpleSymmetricCommandProtocol, 1461 ClientClass=SimpleSymmetricCommandProtocol, 1462 ) 1463 HELLO = b"fuck you" 1464 c.sendHello(HELLO).addErrback(L.append) 1465 p.flush() 1466 L[0].trap(UnfriendlyGreeting) 1467 self.assertEqual(str(L[0].value), "Don't be a dick.") 1468 1469 def test_helloFatalErrorHandling(self): 1470 """ 1471 Verify that if a known, fatal error type is raised and handled, it will 1472 be properly relayed to the other end of the connection and translated 1473 into an exception, no error will be logged, and the connection will be 1474 terminated. 1475 """ 1476 L = [] 1477 c, s, p = connectedServerAndClient( 1478 ServerClass=SimpleSymmetricCommandProtocol, 1479 ClientClass=SimpleSymmetricCommandProtocol, 1480 ) 1481 HELLO = b"die" 1482 c.sendHello(HELLO).addErrback(L.append) 1483 p.flush() 1484 L.pop().trap(DeathThreat) 1485 c.sendHello(HELLO).addErrback(L.append) 1486 p.flush() 1487 L.pop().trap(error.ConnectionDone) 1488 1489 def test_helloNoErrorHandling(self): 1490 """ 1491 Verify that if an unknown error type is raised, it will be relayed to 1492 the other end of the connection and translated into an exception, it 1493 will be logged, and then the connection will be dropped. 1494 """ 1495 L = [] 1496 c, s, p = connectedServerAndClient( 1497 ServerClass=SimpleSymmetricCommandProtocol, 1498 ClientClass=SimpleSymmetricCommandProtocol, 1499 ) 1500 HELLO = THING_I_DONT_UNDERSTAND 1501 c.sendHello(HELLO).addErrback(L.append) 1502 p.flush() 1503 ure = L.pop() 1504 ure.trap(amp.UnknownRemoteError) 1505 c.sendHello(HELLO).addErrback(L.append) 1506 cl = L.pop() 1507 cl.trap(error.ConnectionDone) 1508 # The exception should have been logged. 1509 self.assertTrue(self.flushLoggedErrors(ThingIDontUnderstandError)) 1510 1511 def test_lateAnswer(self): 1512 """ 1513 Verify that a command that does not get answered until after the 1514 connection terminates will not cause any errors. 1515 """ 1516 c, s, p = connectedServerAndClient( 1517 ServerClass=SimpleSymmetricCommandProtocol, 1518 ClientClass=SimpleSymmetricCommandProtocol, 1519 ) 1520 L = [] 1521 c.callRemote(WaitForever).addErrback(L.append) 1522 p.flush() 1523 self.assertEqual(L, []) 1524 s.transport.loseConnection() 1525 p.flush() 1526 L.pop().trap(error.ConnectionDone) 1527 # Just make sure that it doesn't error... 1528 s.waiting.callback({}) 1529 return s.waiting 1530 1531 def test_requiresNoAnswer(self): 1532 """ 1533 Verify that a command that requires no answer is run. 1534 """ 1535 c, s, p = connectedServerAndClient( 1536 ServerClass=SimpleSymmetricCommandProtocol, 1537 ClientClass=SimpleSymmetricCommandProtocol, 1538 ) 1539 HELLO = b"world" 1540 c.callRemote(NoAnswerHello, hello=HELLO) 1541 p.flush() 1542 self.assertTrue(s.greeted) 1543 1544 def test_requiresNoAnswerFail(self): 1545 """ 1546 Verify that commands sent after a failed no-answer request do not complete. 1547 """ 1548 L = [] 1549 c, s, p = connectedServerAndClient( 1550 ServerClass=SimpleSymmetricCommandProtocol, 1551 ClientClass=SimpleSymmetricCommandProtocol, 1552 ) 1553 HELLO = b"fuck you" 1554 c.callRemote(NoAnswerHello, hello=HELLO) 1555 p.flush() 1556 # This should be logged locally. 1557 self.assertTrue(self.flushLoggedErrors(amp.RemoteAmpError)) 1558 HELLO = b"world" 1559 c.callRemote(Hello, hello=HELLO).addErrback(L.append) 1560 p.flush() 1561 L.pop().trap(error.ConnectionDone) 1562 self.assertFalse(s.greeted) 1563 1564 def test_requiresNoAnswerAfterFail(self): 1565 """ 1566 No-answer commands sent after the connection has been torn down do not 1567 return a L{Deferred}. 1568 """ 1569 c, s, p = connectedServerAndClient( 1570 ServerClass=SimpleSymmetricCommandProtocol, 1571 ClientClass=SimpleSymmetricCommandProtocol, 1572 ) 1573 c.transport.loseConnection() 1574 p.flush() 1575 result = c.callRemote(NoAnswerHello, hello=b"ignored") 1576 self.assertIs(result, None) 1577 1578 def test_noAnswerResponderBadAnswer(self): 1579 """ 1580 Verify that responders of requiresAnswer=False commands have to return 1581 a dictionary anyway. 1582 1583 (requiresAnswer is a hint from the _client_ - the server may be called 1584 upon to answer commands in any case, if the client wants to know when 1585 they complete.) 1586 """ 1587 c, s, p = connectedServerAndClient( 1588 ServerClass=BadNoAnswerCommandProtocol, 1589 ClientClass=SimpleSymmetricCommandProtocol, 1590 ) 1591 c.callRemote(NoAnswerHello, hello=b"hello") 1592 p.flush() 1593 le = self.flushLoggedErrors(amp.BadLocalReturn) 1594 self.assertEqual(len(le), 1) 1595 1596 def test_noAnswerResponderAskedForAnswer(self): 1597 """ 1598 Verify that responders with requiresAnswer=False will actually respond 1599 if the client sets requiresAnswer=True. In other words, verify that 1600 requiresAnswer is a hint honored only by the client. 1601 """ 1602 c, s, p = connectedServerAndClient( 1603 ServerClass=NoAnswerCommandProtocol, 1604 ClientClass=SimpleSymmetricCommandProtocol, 1605 ) 1606 L = [] 1607 c.callRemote(Hello, hello=b"Hello!").addCallback(L.append) 1608 p.flush() 1609 self.assertEqual(len(L), 1) 1610 self.assertEqual( 1611 L, [dict(hello=b"Hello!-noanswer", Print=None)] 1612 ) # Optional response argument 1613 1614 def test_ampListCommand(self): 1615 """ 1616 Test encoding of an argument that uses the AmpList encoding. 1617 """ 1618 c, s, p = connectedServerAndClient( 1619 ServerClass=SimpleSymmetricCommandProtocol, 1620 ClientClass=SimpleSymmetricCommandProtocol, 1621 ) 1622 L = [] 1623 c.callRemote(GetList, length=10).addCallback(L.append) 1624 p.flush() 1625 values = L.pop().get("body") 1626 self.assertEqual(values, [{"x": 1}] * 10) 1627 1628 def test_optionalAmpListOmitted(self): 1629 """ 1630 Sending a command with an omitted AmpList argument that is 1631 designated as optional does not raise an InvalidSignature error. 1632 """ 1633 c, s, p = connectedServerAndClient( 1634 ServerClass=SimpleSymmetricCommandProtocol, 1635 ClientClass=SimpleSymmetricCommandProtocol, 1636 ) 1637 L = [] 1638 c.callRemote(DontRejectMe, magicWord="please").addCallback(L.append) 1639 p.flush() 1640 response = L.pop().get("response") 1641 self.assertEqual(response, "list omitted") 1642 1643 def test_optionalAmpListPresent(self): 1644 """ 1645 Sanity check that optional AmpList arguments are processed normally. 1646 """ 1647 c, s, p = connectedServerAndClient( 1648 ServerClass=SimpleSymmetricCommandProtocol, 1649 ClientClass=SimpleSymmetricCommandProtocol, 1650 ) 1651 L = [] 1652 c.callRemote( 1653 DontRejectMe, magicWord="please", list=[{"name": "foo"}] 1654 ).addCallback(L.append) 1655 p.flush() 1656 response = L.pop().get("response") 1657 self.assertEqual(response, "foo accepted") 1658 1659 def test_failEarlyOnArgSending(self): 1660 """ 1661 Verify that if we pass an invalid argument list (omitting an argument), 1662 an exception will be raised. 1663 """ 1664 self.assertRaises(amp.InvalidSignature, Hello) 1665 1666 def test_doubleProtocolSwitch(self): 1667 """ 1668 As a debugging aid, a protocol system should raise a 1669 L{ProtocolSwitched} exception when asked to switch a protocol that is 1670 already switched. 1671 """ 1672 serverDeferred = defer.Deferred() 1673 serverProto = SimpleSymmetricCommandProtocol(serverDeferred) 1674 clientDeferred = defer.Deferred() 1675 clientProto = SimpleSymmetricCommandProtocol(clientDeferred) 1676 c, s, p = connectedServerAndClient( 1677 ServerClass=lambda: serverProto, ClientClass=lambda: clientProto 1678 ) 1679 1680 def switched(result): 1681 self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol) 1682 self.testSucceeded = True 1683 1684 c.switchToTestProtocol().addCallback(switched) 1685 p.flush() 1686 self.assertTrue(self.testSucceeded) 1687 1688 def test_protocolSwitch( 1689 self, 1690 switcher=SimpleSymmetricCommandProtocol, 1691 spuriousTraffic=False, 1692 spuriousError=False, 1693 ): 1694 """ 1695 Verify that it is possible to switch to another protocol mid-connection and 1696 send data to it successfully. 1697 """ 1698 self.testSucceeded = False 1699 1700 serverDeferred = defer.Deferred() 1701 serverProto = switcher(serverDeferred) 1702 clientDeferred = defer.Deferred() 1703 clientProto = switcher(clientDeferred) 1704 c, s, p = connectedServerAndClient( 1705 ServerClass=lambda: serverProto, ClientClass=lambda: clientProto 1706 ) 1707 1708 if spuriousTraffic: 1709 wfdr = [] # remote 1710 c.callRemote(WaitForever).addErrback(wfdr.append) 1711 switchDeferred = c.switchToTestProtocol() 1712 if spuriousTraffic: 1713 self.assertRaises(amp.ProtocolSwitched, c.sendHello, b"world") 1714 1715 def cbConnsLost(info): 1716 ((serverSuccess, serverData), (clientSuccess, clientData)) = info 1717 self.assertTrue(serverSuccess) 1718 self.assertTrue(clientSuccess) 1719 self.assertEqual(b"".join(serverData), SWITCH_CLIENT_DATA) 1720 self.assertEqual(b"".join(clientData), SWITCH_SERVER_DATA) 1721 self.testSucceeded = True 1722 1723 def cbSwitch(proto): 1724 return defer.DeferredList([serverDeferred, clientDeferred]).addCallback( 1725 cbConnsLost 1726 ) 1727 1728 switchDeferred.addCallback(cbSwitch) 1729 p.flush() 1730 if serverProto.maybeLater is not None: 1731 serverProto.maybeLater.callback(serverProto.maybeLaterProto) 1732 p.flush() 1733 if spuriousTraffic: 1734 # switch is done here; do this here to make sure that if we're 1735 # going to corrupt the connection, we do it before it's closed. 1736 if spuriousError: 1737 s.waiting.errback( 1738 amp.RemoteAmpError( 1739 b"SPURIOUS", "Here's some traffic in the form of an error." 1740 ) 1741 ) 1742 else: 1743 s.waiting.callback({}) 1744 p.flush() 1745 c.transport.loseConnection() # close it 1746 p.flush() 1747 self.assertTrue(self.testSucceeded) 1748 1749 def test_protocolSwitchDeferred(self): 1750 """ 1751 Verify that protocol-switching even works if the value returned from 1752 the command that does the switch is deferred. 1753 """ 1754 return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol) 1755 1756 def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol): 1757 """ 1758 Verify that if we try to switch protocols and it fails, the connection 1759 stays up and we can go back to speaking AMP. 1760 """ 1761 self.testSucceeded = False 1762 1763 serverDeferred = defer.Deferred() 1764 serverProto = switcher(serverDeferred) 1765 clientDeferred = defer.Deferred() 1766 clientProto = switcher(clientDeferred) 1767 c, s, p = connectedServerAndClient( 1768 ServerClass=lambda: serverProto, ClientClass=lambda: clientProto 1769 ) 1770 L = [] 1771 c.switchToTestProtocol(fail=True).addErrback(L.append) 1772 p.flush() 1773 L.pop().trap(UnknownProtocol) 1774 self.assertFalse(self.testSucceeded) 1775 # It's a known error, so let's send a "hello" on the same connection; 1776 # it should work. 1777 c.sendHello(b"world").addCallback(L.append) 1778 p.flush() 1779 self.assertEqual(L.pop()["hello"], b"world") 1780 1781 def test_trafficAfterSwitch(self): 1782 """ 1783 Verify that attempts to send traffic after a switch will not corrupt 1784 the nested protocol. 1785 """ 1786 return self.test_protocolSwitch(spuriousTraffic=True) 1787 1788 def test_errorAfterSwitch(self): 1789 """ 1790 Returning an error after a protocol switch should record the underlying 1791 error. 1792 """ 1793 return self.test_protocolSwitch(spuriousTraffic=True, spuriousError=True) 1794 1795 def test_quitBoxQuits(self): 1796 """ 1797 Verify that commands with a responseType of QuitBox will in fact 1798 terminate the connection. 1799 """ 1800 c, s, p = connectedServerAndClient( 1801 ServerClass=SimpleSymmetricCommandProtocol, 1802 ClientClass=SimpleSymmetricCommandProtocol, 1803 ) 1804 1805 L = [] 1806 HELLO = b"world" 1807 GOODBYE = b"everyone" 1808 c.sendHello(HELLO).addCallback(L.append) 1809 p.flush() 1810 self.assertEqual(L.pop()["hello"], HELLO) 1811 c.callRemote(Goodbye).addCallback(L.append) 1812 p.flush() 1813 self.assertEqual(L.pop()["goodbye"], GOODBYE) 1814 c.sendHello(HELLO).addErrback(L.append) 1815 L.pop().trap(error.ConnectionDone) 1816 1817 def test_basicLiteralEmit(self): 1818 """ 1819 Verify that the command dictionaries for a callRemoteN look correct 1820 after being serialized and parsed. 1821 """ 1822 c, s, p = connectedServerAndClient() 1823 L = [] 1824 s.ampBoxReceived = L.append 1825 c.callRemote( 1826 Hello, 1827 hello=b"hello test", 1828 mixedCase=b"mixed case arg test", 1829 dash_arg=b"x", 1830 underscore_arg=b"y", 1831 ) 1832 p.flush() 1833 self.assertEqual(len(L), 1) 1834 for k, v in [ 1835 (b"_command", Hello.commandName), 1836 (b"hello", b"hello test"), 1837 (b"mixedCase", b"mixed case arg test"), 1838 (b"dash-arg", b"x"), 1839 (b"underscore_arg", b"y"), 1840 ]: 1841 self.assertEqual(L[-1].pop(k), v) 1842 L[-1].pop(b"_ask") 1843 self.assertEqual(L[-1], {}) 1844 1845 def test_basicStructuredEmit(self): 1846 """ 1847 Verify that a call similar to basicLiteralEmit's is handled properly with 1848 high-level quoting and passing to Python methods, and that argument 1849 names are correctly handled. 1850 """ 1851 L = [] 1852 1853 class StructuredHello(amp.AMP): 1854 def h(self, *a, **k): 1855 L.append((a, k)) 1856 return dict(hello=b"aaa") 1857 1858 Hello.responder(h) 1859 1860 c, s, p = connectedServerAndClient(ServerClass=StructuredHello) 1861 c.callRemote( 1862 Hello, 1863 hello=b"hello test", 1864 mixedCase=b"mixed case arg test", 1865 dash_arg=b"x", 1866 underscore_arg=b"y", 1867 ).addCallback(L.append) 1868 p.flush() 1869 self.assertEqual(len(L), 2) 1870 self.assertEqual( 1871 L[0], 1872 ( 1873 (), 1874 dict( 1875 hello=b"hello test", 1876 mixedCase=b"mixed case arg test", 1877 dash_arg=b"x", 1878 underscore_arg=b"y", 1879 From=s.transport.getPeer(), 1880 # XXX - should optional arguments just not be passed? 1881 # passing None seems a little odd, looking at the way it 1882 # turns out here... -glyph 1883 Print=None, 1884 optional=None, 1885 ), 1886 ), 1887 ) 1888 self.assertEqual(L[1], dict(Print=None, hello=b"aaa")) 1889 1890 1891class PretendRemoteCertificateAuthority: 1892 def checkIsPretendRemote(self): 1893 return True 1894 1895 1896class IOSimCert: 1897 verifyCount = 0 1898 1899 def options(self, *ign): 1900 return self 1901 1902 def iosimVerify(self, otherCert): 1903 """ 1904 This isn't a real certificate, and wouldn't work on a real socket, but 1905 iosim specifies a different API so that we don't have to do any crypto 1906 math to demonstrate that the right functions get called in the right 1907 places. 1908 """ 1909 assert otherCert is self 1910 self.verifyCount += 1 1911 return True 1912 1913 1914class OKCert(IOSimCert): 1915 def options(self, x): 1916 assert x.checkIsPretendRemote() 1917 return self 1918 1919 1920class GrumpyCert(IOSimCert): 1921 def iosimVerify(self, otherCert): 1922 self.verifyCount += 1 1923 return False 1924 1925 1926class DroppyCert(IOSimCert): 1927 def __init__(self, toDrop): 1928 self.toDrop = toDrop 1929 1930 def iosimVerify(self, otherCert): 1931 self.verifyCount += 1 1932 self.toDrop.loseConnection() 1933 return True 1934 1935 1936class SecurableProto(FactoryNotifier): 1937 1938 factory = None 1939 1940 def verifyFactory(self): 1941 return [PretendRemoteCertificateAuthority()] 1942 1943 def getTLSVars(self): 1944 cert = self.certFactory() 1945 verify = self.verifyFactory() 1946 return dict(tls_localCertificate=cert, tls_verifyAuthorities=verify) 1947 1948 amp.StartTLS.responder(getTLSVars) 1949 1950 1951@skipIf(skipSSL, "SSL not available") 1952@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor") 1953class TLSTests(TestCase): 1954 def test_startingTLS(self): 1955 """ 1956 Verify that starting TLS and succeeding at handshaking sends all the 1957 notifications to all the right places. 1958 """ 1959 cli, svr, p = connectedServerAndClient( 1960 ServerClass=SecurableProto, ClientClass=SecurableProto 1961 ) 1962 1963 okc = OKCert() 1964 svr.certFactory = lambda: okc 1965 1966 cli.callRemote( 1967 amp.StartTLS, 1968 tls_localCertificate=okc, 1969 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()], 1970 ) 1971 1972 # let's buffer something to be delivered securely 1973 L = [] 1974 cli.callRemote(SecuredPing).addCallback(L.append) 1975 p.flush() 1976 # once for client once for server 1977 self.assertEqual(okc.verifyCount, 2) 1978 L = [] 1979 cli.callRemote(SecuredPing).addCallback(L.append) 1980 p.flush() 1981 self.assertEqual(L[0], {"pinged": True}) 1982 1983 def test_startTooManyTimes(self): 1984 """ 1985 Verify that the protocol will complain if we attempt to renegotiate TLS, 1986 which we don't support. 1987 """ 1988 cli, svr, p = connectedServerAndClient( 1989 ServerClass=SecurableProto, ClientClass=SecurableProto 1990 ) 1991 1992 okc = OKCert() 1993 svr.certFactory = lambda: okc 1994 1995 cli.callRemote( 1996 amp.StartTLS, 1997 tls_localCertificate=okc, 1998 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()], 1999 ) 2000 p.flush() 2001 cli.noPeerCertificate = True # this is totally fake 2002 self.assertRaises( 2003 amp.OnlyOneTLS, 2004 cli.callRemote, 2005 amp.StartTLS, 2006 tls_localCertificate=okc, 2007 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()], 2008 ) 2009 2010 def test_negotiationFailed(self): 2011 """ 2012 Verify that starting TLS and failing on both sides at handshaking sends 2013 notifications to all the right places and terminates the connection. 2014 """ 2015 2016 badCert = GrumpyCert() 2017 2018 cli, svr, p = connectedServerAndClient( 2019 ServerClass=SecurableProto, ClientClass=SecurableProto 2020 ) 2021 svr.certFactory = lambda: badCert 2022 2023 cli.callRemote(amp.StartTLS, tls_localCertificate=badCert) 2024 2025 p.flush() 2026 # once for client once for server - but both fail 2027 self.assertEqual(badCert.verifyCount, 2) 2028 d = cli.callRemote(SecuredPing) 2029 p.flush() 2030 self.assertFailure(d, iosim.NativeOpenSSLError) 2031 2032 def test_negotiationFailedByClosing(self): 2033 """ 2034 Verify that starting TLS and failing by way of a lost connection 2035 notices that it is probably an SSL problem. 2036 """ 2037 2038 cli, svr, p = connectedServerAndClient( 2039 ServerClass=SecurableProto, ClientClass=SecurableProto 2040 ) 2041 droppyCert = DroppyCert(svr.transport) 2042 svr.certFactory = lambda: droppyCert 2043 2044 cli.callRemote(amp.StartTLS, tls_localCertificate=droppyCert) 2045 2046 p.flush() 2047 2048 self.assertEqual(droppyCert.verifyCount, 2) 2049 2050 d = cli.callRemote(SecuredPing) 2051 p.flush() 2052 2053 # it might be a good idea to move this exception somewhere more 2054 # reasonable. 2055 self.assertFailure(d, error.PeerVerifyError) 2056 2057 2058class TLSNotAvailableTests(TestCase): 2059 """ 2060 Tests what happened when ssl is not available in current installation. 2061 """ 2062 2063 def setUp(self): 2064 """ 2065 Disable ssl in amp. 2066 """ 2067 self.ssl = amp.ssl 2068 amp.ssl = None 2069 2070 def tearDown(self): 2071 """ 2072 Restore ssl module. 2073 """ 2074 amp.ssl = self.ssl 2075 2076 def test_callRemoteError(self): 2077 """ 2078 Check that callRemote raises an exception when called with a 2079 L{amp.StartTLS}. 2080 """ 2081 cli, svr, p = connectedServerAndClient( 2082 ServerClass=SecurableProto, ClientClass=SecurableProto 2083 ) 2084 2085 okc = OKCert() 2086 svr.certFactory = lambda: okc 2087 2088 return self.assertFailure( 2089 cli.callRemote( 2090 amp.StartTLS, 2091 tls_localCertificate=okc, 2092 tls_verifyAuthorities=[PretendRemoteCertificateAuthority()], 2093 ), 2094 RuntimeError, 2095 ) 2096 2097 def test_messageReceivedError(self): 2098 """ 2099 When a client with SSL enabled talks to a server without SSL, it 2100 should return a meaningful error. 2101 """ 2102 svr = SecurableProto() 2103 okc = OKCert() 2104 svr.certFactory = lambda: okc 2105 box = amp.Box() 2106 box[b"_command"] = b"StartTLS" 2107 box[b"_ask"] = b"1" 2108 boxes = [] 2109 svr.sendBox = boxes.append 2110 svr.makeConnection(StringTransport()) 2111 svr.ampBoxReceived(box) 2112 self.assertEqual( 2113 boxes, 2114 [ 2115 { 2116 b"_error_code": b"TLS_ERROR", 2117 b"_error": b"1", 2118 b"_error_description": b"TLS not available", 2119 } 2120 ], 2121 ) 2122 2123 2124class InheritedError(Exception): 2125 """ 2126 This error is used to check inheritance. 2127 """ 2128 2129 2130class OtherInheritedError(Exception): 2131 """ 2132 This is a distinct error for checking inheritance. 2133 """ 2134 2135 2136class BaseCommand(amp.Command): 2137 """ 2138 This provides a command that will be subclassed. 2139 """ 2140 2141 errors: Dict[Type[Exception], bytes] = {InheritedError: b"INHERITED_ERROR"} 2142 2143 2144class InheritedCommand(BaseCommand): 2145 """ 2146 This is a command which subclasses another command but does not override 2147 anything. 2148 """ 2149 2150 2151class AddErrorsCommand(BaseCommand): 2152 """ 2153 This is a command which subclasses another command but adds errors to the 2154 list. 2155 """ 2156 2157 arguments = [(b"other", amp.Boolean())] 2158 errors: Dict[Type[Exception], bytes] = { 2159 OtherInheritedError: b"OTHER_INHERITED_ERROR" 2160 } 2161 2162 2163class NormalCommandProtocol(amp.AMP): 2164 """ 2165 This is a protocol which responds to L{BaseCommand}, and is used to test 2166 that inheritance does not interfere with the normal handling of errors. 2167 """ 2168 2169 def resp(self): 2170 raise InheritedError() 2171 2172 BaseCommand.responder(resp) 2173 2174 2175class InheritedCommandProtocol(amp.AMP): 2176 """ 2177 This is a protocol which responds to L{InheritedCommand}, and is used to 2178 test that inherited commands inherit their bases' errors if they do not 2179 respond to any of their own. 2180 """ 2181 2182 def resp(self): 2183 raise InheritedError() 2184 2185 InheritedCommand.responder(resp) 2186 2187 2188class AddedCommandProtocol(amp.AMP): 2189 """ 2190 This is a protocol which responds to L{AddErrorsCommand}, and is used to 2191 test that inherited commands can add their own new types of errors, but 2192 still respond in the same way to their parents types of errors. 2193 """ 2194 2195 def resp(self, other): 2196 if other: 2197 raise OtherInheritedError() 2198 else: 2199 raise InheritedError() 2200 2201 AddErrorsCommand.responder(resp) 2202 2203 2204class CommandInheritanceTests(TestCase): 2205 """ 2206 These tests verify that commands inherit error conditions properly. 2207 """ 2208 2209 def errorCheck(self, err, proto, cmd, **kw): 2210 """ 2211 Check that the appropriate kind of error is raised when a given command 2212 is sent to a given protocol. 2213 """ 2214 c, s, p = connectedServerAndClient(ServerClass=proto, ClientClass=proto) 2215 d = c.callRemote(cmd, **kw) 2216 d2 = self.failUnlessFailure(d, err) 2217 p.flush() 2218 return d2 2219 2220 def test_basicErrorPropagation(self): 2221 """ 2222 Verify that errors specified in a superclass are respected normally 2223 even if it has subclasses. 2224 """ 2225 return self.errorCheck(InheritedError, NormalCommandProtocol, BaseCommand) 2226 2227 def test_inheritedErrorPropagation(self): 2228 """ 2229 Verify that errors specified in a superclass command are propagated to 2230 its subclasses. 2231 """ 2232 return self.errorCheck( 2233 InheritedError, InheritedCommandProtocol, InheritedCommand 2234 ) 2235 2236 def test_inheritedErrorAddition(self): 2237 """ 2238 Verify that new errors specified in a subclass of an existing command 2239 are honored even if the superclass defines some errors. 2240 """ 2241 return self.errorCheck( 2242 OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True 2243 ) 2244 2245 def test_additionWithOriginalError(self): 2246 """ 2247 Verify that errors specified in a command's superclass are respected 2248 even if that command defines new errors itself. 2249 """ 2250 return self.errorCheck( 2251 InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False 2252 ) 2253 2254 2255def _loseAndPass(err, proto): 2256 # be specific, pass on the error to the client. 2257 err.trap(error.ConnectionLost, error.ConnectionDone) 2258 del proto.connectionLost 2259 proto.connectionLost(err) 2260 2261 2262class LiveFireBase: 2263 """ 2264 Utility for connected reactor-using tests. 2265 """ 2266 2267 def setUp(self): 2268 """ 2269 Create an amp server and connect a client to it. 2270 """ 2271 from twisted.internet import reactor 2272 2273 self.serverFactory = protocol.ServerFactory() 2274 self.serverFactory.protocol = self.serverProto 2275 self.clientFactory = protocol.ClientFactory() 2276 self.clientFactory.protocol = self.clientProto 2277 self.clientFactory.onMade = defer.Deferred() 2278 self.serverFactory.onMade = defer.Deferred() 2279 self.serverPort = reactor.listenTCP(0, self.serverFactory) 2280 self.addCleanup(self.serverPort.stopListening) 2281 self.clientConn = reactor.connectTCP( 2282 "127.0.0.1", self.serverPort.getHost().port, self.clientFactory 2283 ) 2284 self.addCleanup(self.clientConn.disconnect) 2285 2286 def getProtos(rlst): 2287 self.cli = self.clientFactory.theProto 2288 self.svr = self.serverFactory.theProto 2289 2290 dl = defer.DeferredList([self.clientFactory.onMade, self.serverFactory.onMade]) 2291 return dl.addCallback(getProtos) 2292 2293 def tearDown(self): 2294 """ 2295 Cleanup client and server connections, and check the error got at 2296 C{connectionLost}. 2297 """ 2298 L = [] 2299 for conn in self.cli, self.svr: 2300 if conn.transport is not None: 2301 # depend on amp's function connection-dropping behavior 2302 d = defer.Deferred().addErrback(_loseAndPass, conn) 2303 conn.connectionLost = d.errback 2304 conn.transport.loseConnection() 2305 L.append(d) 2306 return defer.gatherResults(L).addErrback(lambda first: first.value.subFailure) 2307 2308 2309def show(x): 2310 import sys 2311 2312 sys.stdout.write(x + "\n") 2313 sys.stdout.flush() 2314 2315 2316def tempSelfSigned(): 2317 from twisted.internet import ssl 2318 2319 sharedDN = ssl.DN(CN="shared") 2320 key = ssl.KeyPair.generate() 2321 cr = key.certificateRequest(sharedDN) 2322 sscrd = key.signCertificateRequest(sharedDN, cr, lambda dn: True, 1234567) 2323 cert = key.newCertificate(sscrd) 2324 return cert 2325 2326 2327if ssl is not None: 2328 tempcert = tempSelfSigned() 2329 2330 2331@skipIf(skipSSL, "SSL not available") 2332@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor") 2333class LiveFireTLSTests(LiveFireBase, TestCase): 2334 2335 clientProto = SecurableProto 2336 serverProto = SecurableProto 2337 2338 def test_liveFireCustomTLS(self): 2339 """ 2340 Using real, live TLS, actually negotiate a connection. 2341 2342 This also looks at the 'peerCertificate' attribute's correctness, since 2343 that's actually loaded using OpenSSL calls, but the main purpose is to 2344 make sure that we didn't miss anything obvious in iosim about TLS 2345 negotiations. 2346 """ 2347 2348 cert = tempcert 2349 2350 self.svr.verifyFactory = lambda: [cert] 2351 self.svr.certFactory = lambda: cert 2352 # only needed on the server, we specify the client below. 2353 2354 def secured(rslt): 2355 x = cert.digest() 2356 2357 def pinged(rslt2): 2358 # Interesting. OpenSSL won't even _tell_ us about the peer 2359 # cert until we negotiate. we should be able to do this in 2360 # 'secured' instead, but it looks like we can't. I think this 2361 # is a bug somewhere far deeper than here. 2362 self.assertEqual(x, self.cli.hostCertificate.digest()) 2363 self.assertEqual(x, self.cli.peerCertificate.digest()) 2364 self.assertEqual(x, self.svr.hostCertificate.digest()) 2365 self.assertEqual(x, self.svr.peerCertificate.digest()) 2366 2367 return self.cli.callRemote(SecuredPing).addCallback(pinged) 2368 2369 return self.cli.callRemote( 2370 amp.StartTLS, tls_localCertificate=cert, tls_verifyAuthorities=[cert] 2371 ).addCallback(secured) 2372 2373 2374class SlightlySmartTLS(SimpleSymmetricCommandProtocol): 2375 """ 2376 Specific implementation of server side protocol with different 2377 management of TLS. 2378 """ 2379 2380 def getTLSVars(self): 2381 """ 2382 @return: the global C{tempcert} certificate as local certificate. 2383 """ 2384 return dict(tls_localCertificate=tempcert) 2385 2386 amp.StartTLS.responder(getTLSVars) 2387 2388 2389@skipIf(skipSSL, "SSL not available") 2390@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor") 2391class PlainVanillaLiveFireTests(LiveFireBase, TestCase): 2392 2393 clientProto = SimpleSymmetricCommandProtocol 2394 serverProto = SimpleSymmetricCommandProtocol 2395 2396 def test_liveFireDefaultTLS(self): 2397 """ 2398 Verify that out of the box, we can start TLS to at least encrypt the 2399 connection, even if we don't have any certificates to use. 2400 """ 2401 2402 def secured(result): 2403 return self.cli.callRemote(SecuredPing) 2404 2405 return self.cli.callRemote(amp.StartTLS).addCallback(secured) 2406 2407 2408@skipIf(skipSSL, "SSL not available") 2409@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor") 2410class WithServerTLSVerificationTests(LiveFireBase, TestCase): 2411 2412 clientProto = SimpleSymmetricCommandProtocol 2413 serverProto = SlightlySmartTLS 2414 2415 def test_anonymousVerifyingClient(self): 2416 """ 2417 Verify that anonymous clients can verify server certificates. 2418 """ 2419 2420 def secured(result): 2421 return self.cli.callRemote(SecuredPing) 2422 2423 return self.cli.callRemote( 2424 amp.StartTLS, tls_verifyAuthorities=[tempcert] 2425 ).addCallback(secured) 2426 2427 2428class ProtocolIncludingArgument(amp.Argument): 2429 """ 2430 An L{amp.Argument} which encodes its parser and serializer 2431 arguments *including the protocol* into its parsed and serialized 2432 forms. 2433 """ 2434 2435 def fromStringProto(self, string, protocol): 2436 """ 2437 Don't decode anything; just return all possible information. 2438 2439 @return: A two-tuple of the input string and the protocol. 2440 """ 2441 return (string, protocol) 2442 2443 def toStringProto(self, obj, protocol): 2444 """ 2445 Encode identifying information about L{object} and protocol 2446 into a string for later verification. 2447 2448 @type obj: L{object} 2449 @type protocol: L{amp.AMP} 2450 """ 2451 ident = "%d:%d" % (id(obj), id(protocol)) 2452 return ident.encode("ascii") 2453 2454 2455class ProtocolIncludingCommand(amp.Command): 2456 """ 2457 A command that has argument and response schemas which use 2458 L{ProtocolIncludingArgument}. 2459 """ 2460 2461 arguments = [(b"weird", ProtocolIncludingArgument())] 2462 response = [(b"weird", ProtocolIncludingArgument())] 2463 2464 2465class MagicSchemaCommand(amp.Command): 2466 """ 2467 A command which overrides L{parseResponse}, L{parseArguments}, and 2468 L{makeResponse}. 2469 """ 2470 2471 @classmethod 2472 def parseResponse(self, strings, protocol): 2473 """ 2474 Don't do any parsing, just jam the input strings and protocol 2475 onto the C{protocol.parseResponseArguments} attribute as a 2476 two-tuple. Return the original strings. 2477 """ 2478 protocol.parseResponseArguments = (strings, protocol) 2479 return strings 2480 2481 @classmethod 2482 def parseArguments(cls, strings, protocol): 2483 """ 2484 Don't do any parsing, just jam the input strings and protocol 2485 onto the C{protocol.parseArgumentsArguments} attribute as a 2486 two-tuple. Return the original strings. 2487 """ 2488 protocol.parseArgumentsArguments = (strings, protocol) 2489 return strings 2490 2491 @classmethod 2492 def makeArguments(cls, objects, protocol): 2493 """ 2494 Don't do any serializing, just jam the input strings and protocol 2495 onto the C{protocol.makeArgumentsArguments} attribute as a 2496 two-tuple. Return the original strings. 2497 """ 2498 protocol.makeArgumentsArguments = (objects, protocol) 2499 return objects 2500 2501 2502class NoNetworkProtocol(amp.AMP): 2503 """ 2504 An L{amp.AMP} subclass which overrides private methods to avoid 2505 testing the network. It also provides a responder for 2506 L{MagicSchemaCommand} that does nothing, so that tests can test 2507 aspects of the interaction of L{amp.Command}s and L{amp.AMP}. 2508 2509 @ivar parseArgumentsArguments: Arguments that have been passed to any 2510 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2511 this protocol. 2512 2513 @ivar parseResponseArguments: Responses that have been returned from a 2514 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2515 this protocol. 2516 2517 @ivar makeArgumentsArguments: Arguments that have been serialized by any 2518 L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by 2519 this protocol. 2520 """ 2521 2522 def _sendBoxCommand(self, commandName, strings, requiresAnswer): 2523 """ 2524 Return a Deferred which fires with the original strings. 2525 """ 2526 return defer.succeed(strings) 2527 2528 MagicSchemaCommand.responder(lambda s, weird: {}) 2529 2530 2531class MyBox(dict): 2532 """ 2533 A unique dict subclass. 2534 """ 2535 2536 2537class ProtocolIncludingCommandWithDifferentCommandType(ProtocolIncludingCommand): 2538 """ 2539 A L{ProtocolIncludingCommand} subclass whose commandType is L{MyBox} 2540 """ 2541 2542 commandType = MyBox # type: ignore[assignment] 2543 2544 2545class CommandTests(TestCase): 2546 """ 2547 Tests for L{amp.Argument} and L{amp.Command}. 2548 """ 2549 2550 def test_argumentInterface(self): 2551 """ 2552 L{Argument} instances provide L{amp.IArgumentType}. 2553 """ 2554 self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument())) 2555 2556 def test_parseResponse(self): 2557 """ 2558 There should be a class method of Command which accepts a 2559 mapping of argument names to serialized forms and returns a 2560 similar mapping whose values have been parsed via the 2561 Command's response schema. 2562 """ 2563 protocol = object() 2564 result = b"whatever" 2565 strings = {b"weird": result} 2566 self.assertEqual( 2567 ProtocolIncludingCommand.parseResponse(strings, protocol), 2568 {"weird": (result, protocol)}, 2569 ) 2570 2571 def test_callRemoteCallsParseResponse(self): 2572 """ 2573 Making a remote call on a L{amp.Command} subclass which 2574 overrides the C{parseResponse} method should call that 2575 C{parseResponse} method to get the response. 2576 """ 2577 client = NoNetworkProtocol() 2578 thingy = b"weeoo" 2579 response = client.callRemote(MagicSchemaCommand, weird=thingy) 2580 2581 def gotResponse(ign): 2582 self.assertEqual(client.parseResponseArguments, ({"weird": thingy}, client)) 2583 2584 response.addCallback(gotResponse) 2585 return response 2586 2587 def test_parseArguments(self): 2588 """ 2589 There should be a class method of L{amp.Command} which accepts 2590 a mapping of argument names to serialized forms and returns a 2591 similar mapping whose values have been parsed via the 2592 command's argument schema. 2593 """ 2594 protocol = object() 2595 result = b"whatever" 2596 strings = {b"weird": result} 2597 self.assertEqual( 2598 ProtocolIncludingCommand.parseArguments(strings, protocol), 2599 {"weird": (result, protocol)}, 2600 ) 2601 2602 def test_responderCallsParseArguments(self): 2603 """ 2604 Making a remote call on a L{amp.Command} subclass which 2605 overrides the C{parseArguments} method should call that 2606 C{parseArguments} method to get the arguments. 2607 """ 2608 protocol = NoNetworkProtocol() 2609 responder = protocol.locateResponder(MagicSchemaCommand.commandName) 2610 argument = object() 2611 response = responder(dict(weird=argument)) 2612 response.addCallback( 2613 lambda ign: self.assertEqual( 2614 protocol.parseArgumentsArguments, ({"weird": argument}, protocol) 2615 ) 2616 ) 2617 return response 2618 2619 def test_makeArguments(self): 2620 """ 2621 There should be a class method of L{amp.Command} which accepts 2622 a mapping of argument names to objects and returns a similar 2623 mapping whose values have been serialized via the command's 2624 argument schema. 2625 """ 2626 protocol = object() 2627 argument = object() 2628 objects = {"weird": argument} 2629 ident = "%d:%d" % (id(argument), id(protocol)) 2630 self.assertEqual( 2631 ProtocolIncludingCommand.makeArguments(objects, protocol), 2632 {b"weird": ident.encode("ascii")}, 2633 ) 2634 2635 def test_makeArgumentsUsesCommandType(self): 2636 """ 2637 L{amp.Command.makeArguments}'s return type should be the type 2638 of the result of L{amp.Command.commandType}. 2639 """ 2640 protocol = object() 2641 objects = {"weird": b"whatever"} 2642 2643 result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments( 2644 objects, protocol 2645 ) 2646 self.assertIs(type(result), MyBox) 2647 2648 def test_callRemoteCallsMakeArguments(self): 2649 """ 2650 Making a remote call on a L{amp.Command} subclass which 2651 overrides the C{makeArguments} method should call that 2652 C{makeArguments} method to get the response. 2653 """ 2654 client = NoNetworkProtocol() 2655 argument = object() 2656 response = client.callRemote(MagicSchemaCommand, weird=argument) 2657 2658 def gotResponse(ign): 2659 self.assertEqual( 2660 client.makeArgumentsArguments, ({"weird": argument}, client) 2661 ) 2662 2663 response.addCallback(gotResponse) 2664 return response 2665 2666 def test_extraArgumentsDisallowed(self): 2667 """ 2668 L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects 2669 dictionary passed to it includes a key which does not correspond to the 2670 Python identifier for a defined argument. 2671 """ 2672 self.assertRaises( 2673 amp.InvalidSignature, 2674 Hello.makeArguments, 2675 dict(hello="hello", bogusArgument=object()), 2676 None, 2677 ) 2678 2679 def test_wireSpellingDisallowed(self): 2680 """ 2681 If a command argument conflicts with a Python keyword, the 2682 untransformed argument name is not allowed as a key in the dictionary 2683 passed to L{Command.makeArguments}. If it is supplied, 2684 L{amp.InvalidSignature} is raised. 2685 2686 This may be a pointless implementation restriction which may be lifted. 2687 The current behavior is tested to verify that such arguments are not 2688 silently dropped on the floor (the previous behavior). 2689 """ 2690 self.assertRaises( 2691 amp.InvalidSignature, 2692 Hello.makeArguments, 2693 dict(hello="required", **{"print": "print value"}), 2694 None, 2695 ) 2696 2697 def test_commandNameDefaultsToClassNameAsByteString(self): 2698 """ 2699 A L{Command} subclass without a defined C{commandName} that's 2700 not a byte string. 2701 """ 2702 2703 class NewCommand(amp.Command): 2704 """ 2705 A new command. 2706 """ 2707 2708 self.assertEqual(b"NewCommand", NewCommand.commandName) 2709 2710 def test_commandNameMustBeAByteString(self): 2711 """ 2712 A L{Command} subclass cannot be defined with a C{commandName} that's 2713 not a byte string. 2714 """ 2715 error = self.assertRaises( 2716 TypeError, type, "NewCommand", (amp.Command,), {"commandName": "FOO"} 2717 ) 2718 self.assertRegex( 2719 str(error), "^Command names must be byte strings, got: u?'FOO'$" 2720 ) 2721 2722 def test_commandArgumentsMustBeNamedWithByteStrings(self): 2723 """ 2724 A L{Command} subclass's C{arguments} must have byte string names. 2725 """ 2726 error = self.assertRaises( 2727 TypeError, 2728 type, 2729 "NewCommand", 2730 (amp.Command,), 2731 {"arguments": [("foo", None)]}, 2732 ) 2733 self.assertRegex( 2734 str(error), "^Argument names must be byte strings, got: u?'foo'$" 2735 ) 2736 2737 def test_commandResponseMustBeNamedWithByteStrings(self): 2738 """ 2739 A L{Command} subclass's C{response} must have byte string names. 2740 """ 2741 error = self.assertRaises( 2742 TypeError, type, "NewCommand", (amp.Command,), {"response": [("foo", None)]} 2743 ) 2744 self.assertRegex( 2745 str(error), "^Response names must be byte strings, got: u?'foo'$" 2746 ) 2747 2748 def test_commandErrorsIsConvertedToDict(self): 2749 """ 2750 A L{Command} subclass's C{errors} is coerced into a C{dict}. 2751 """ 2752 2753 class NewCommand(amp.Command): 2754 errors = [(ZeroDivisionError, b"ZDE")] 2755 2756 self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.errors) 2757 2758 def test_commandErrorsMustUseBytesForOnWireRepresentation(self): 2759 """ 2760 A L{Command} subclass's C{errors} must map exceptions to byte strings. 2761 """ 2762 error = self.assertRaises( 2763 TypeError, 2764 type, 2765 "NewCommand", 2766 (amp.Command,), 2767 {"errors": [(ZeroDivisionError, "foo")]}, 2768 ) 2769 self.assertRegex(str(error), "^Error names must be byte strings, got: u?'foo'$") 2770 2771 def test_commandFatalErrorsIsConvertedToDict(self): 2772 """ 2773 A L{Command} subclass's C{fatalErrors} is coerced into a C{dict}. 2774 """ 2775 2776 class NewCommand(amp.Command): 2777 fatalErrors = [(ZeroDivisionError, b"ZDE")] 2778 2779 self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.fatalErrors) 2780 2781 def test_commandFatalErrorsMustUseBytesForOnWireRepresentation(self): 2782 """ 2783 A L{Command} subclass's C{fatalErrors} must map exceptions to byte 2784 strings. 2785 """ 2786 error = self.assertRaises( 2787 TypeError, 2788 type, 2789 "NewCommand", 2790 (amp.Command,), 2791 {"fatalErrors": [(ZeroDivisionError, "foo")]}, 2792 ) 2793 self.assertRegex( 2794 str(error), "^Fatal error names must be byte strings, " "got: u?'foo'$" 2795 ) 2796 2797 2798class ListOfTestsMixin: 2799 """ 2800 Base class for testing L{ListOf}, a parameterized zero-or-more argument 2801 type. 2802 2803 @ivar elementType: Subclasses should set this to an L{Argument} 2804 instance. The tests will make a L{ListOf} using this. 2805 2806 @ivar strings: Subclasses should set this to a dictionary mapping some 2807 number of keys -- as BYTE strings -- to the correct serialized form 2808 for some example values. These should agree with what L{elementType} 2809 produces/accepts. 2810 2811 @ivar objects: Subclasses should set this to a dictionary with the same 2812 keys as C{strings} -- as NATIVE strings -- and with values which are 2813 the lists which should serialize to the values in the C{strings} 2814 dictionary. 2815 """ 2816 2817 def test_toBox(self): 2818 """ 2819 L{ListOf.toBox} extracts the list of objects from the C{objects} 2820 dictionary passed to it, using the C{name} key also passed to it, 2821 serializes each of the elements in that list using the L{Argument} 2822 instance previously passed to its initializer, combines the serialized 2823 results, and inserts the result into the C{strings} dictionary using 2824 the same C{name} key. 2825 """ 2826 stringList = amp.ListOf(self.elementType) 2827 strings = amp.AmpBox() 2828 for key in self.objects: 2829 stringList.toBox(key.encode("ascii"), strings, self.objects.copy(), None) 2830 self.assertEqual(strings, self.strings) 2831 2832 def test_fromBox(self): 2833 """ 2834 L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}. 2835 """ 2836 stringList = amp.ListOf(self.elementType) 2837 objects = {} 2838 for key in self.strings: 2839 stringList.fromBox(key, self.strings.copy(), objects, None) 2840 self.assertEqual(objects, self.objects) 2841 2842 2843class ListOfStringsTests(TestCase, ListOfTestsMixin): 2844 """ 2845 Tests for L{ListOf} combined with L{amp.String}. 2846 """ 2847 2848 elementType = amp.String() 2849 2850 strings = { 2851 b"empty": b"", 2852 b"single": b"\x00\x03foo", 2853 b"multiple": b"\x00\x03bar\x00\x03baz\x00\x04quux", 2854 } 2855 2856 objects = {"empty": [], "single": [b"foo"], "multiple": [b"bar", b"baz", b"quux"]} 2857 2858 2859class ListOfIntegersTests(TestCase, ListOfTestsMixin): 2860 """ 2861 Tests for L{ListOf} combined with L{amp.Integer}. 2862 """ 2863 2864 elementType = amp.Integer() 2865 2866 huge = ( 2867 9999999999999999999999999999999999999999999999999999999999 2868 * 9999999999999999999999999999999999999999999999999999999999 2869 ) 2870 2871 strings = { 2872 b"empty": b"", 2873 b"single": b"\x00\x0210", 2874 b"multiple": b"\x00\x011\x00\x0220\x00\x03500", 2875 b"huge": b"\x00\x74%d" % (huge,), 2876 b"negative": b"\x00\x02-1", 2877 } 2878 2879 objects = { 2880 "empty": [], 2881 "single": [10], 2882 "multiple": [1, 20, 500], 2883 "huge": [huge], 2884 "negative": [-1], 2885 } 2886 2887 2888class ListOfUnicodeTests(TestCase, ListOfTestsMixin): 2889 """ 2890 Tests for L{ListOf} combined with L{amp.Unicode}. 2891 """ 2892 2893 elementType = amp.Unicode() 2894 2895 strings = { 2896 b"empty": b"", 2897 b"single": b"\x00\x03foo", 2898 b"multiple": b"\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world", 2899 } 2900 2901 objects = { 2902 "empty": [], 2903 "single": ["foo"], 2904 "multiple": ["\N{SNOWMAN}", "Hello", "world"], 2905 } 2906 2907 2908class ListOfDecimalTests(TestCase, ListOfTestsMixin): 2909 """ 2910 Tests for L{ListOf} combined with L{amp.Decimal}. 2911 """ 2912 2913 elementType = amp.Decimal() 2914 2915 strings = { 2916 b"empty": b"", 2917 b"single": b"\x00\x031.1", 2918 b"extreme": b"\x00\x08Infinity\x00\x09-Infinity", 2919 b"scientist": b"\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7" 2920 b"\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7", 2921 b"engineer": ( 2922 b"\x00\x04" 2923 + decimal.Decimal("0e6").to_eng_string().encode("ascii") 2924 + b"\x00\x06" 2925 + decimal.Decimal("1.5E-9").to_eng_string().encode("ascii") 2926 ), 2927 } 2928 2929 objects = { 2930 "empty": [], 2931 "single": [decimal.Decimal("1.1")], 2932 "extreme": [ 2933 decimal.Decimal("Infinity"), 2934 decimal.Decimal("-Infinity"), 2935 ], 2936 # exarkun objected to AMP supporting engineering notation because 2937 # it was redundant, until we realised that 1E6 has less precision 2938 # than 1000000 and is represented differently. But they compare 2939 # and even hash equally. There were tears. 2940 "scientist": [ 2941 decimal.Decimal("3.141E5"), 2942 decimal.Decimal("3.141e-5"), 2943 decimal.Decimal("3.141E-7"), 2944 decimal.Decimal("-3.141e5"), 2945 decimal.Decimal("-3.141E-5"), 2946 decimal.Decimal("-3.141e-7"), 2947 ], 2948 "engineer": [ 2949 decimal.Decimal("0e6"), 2950 decimal.Decimal("1.5E-9"), 2951 ], 2952 } 2953 2954 2955class ListOfDecimalNanTests(TestCase, ListOfTestsMixin): 2956 """ 2957 Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values. 2958 """ 2959 2960 elementType = amp.Decimal() 2961 2962 strings = { 2963 b"nan": b"\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN", 2964 } 2965 2966 objects = { 2967 "nan": [ 2968 decimal.Decimal("NaN"), 2969 decimal.Decimal("-NaN"), 2970 decimal.Decimal("sNaN"), 2971 decimal.Decimal("-sNaN"), 2972 ] 2973 } 2974 2975 def test_fromBox(self): 2976 """ 2977 L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}. 2978 """ 2979 # Helpers. Decimal.is_{qnan,snan,signed}() are new in 2.6 (or 2.5.2, 2980 # but who's counting). 2981 def is_qnan(decimal): 2982 return "NaN" in str(decimal) and "sNaN" not in str(decimal) 2983 2984 def is_snan(decimal): 2985 return "sNaN" in str(decimal) 2986 2987 def is_signed(decimal): 2988 return "-" in str(decimal) 2989 2990 # NaN values have unusual equality semantics, so this method is 2991 # overridden to compare the resulting objects in a way which works with 2992 # NaNs. 2993 stringList = amp.ListOf(self.elementType) 2994 objects = {} 2995 for key in self.strings: 2996 stringList.fromBox(key, self.strings.copy(), objects, None) 2997 n = objects["nan"] 2998 self.assertTrue(is_qnan(n[0]) and not is_signed(n[0])) 2999 self.assertTrue(is_qnan(n[1]) and is_signed(n[1])) 3000 self.assertTrue(is_snan(n[2]) and not is_signed(n[2])) 3001 self.assertTrue(is_snan(n[3]) and is_signed(n[3])) 3002 3003 3004class DecimalTests(TestCase): 3005 """ 3006 Tests for L{amp.Decimal}. 3007 """ 3008 3009 def test_nonDecimal(self): 3010 """ 3011 L{amp.Decimal.toString} raises L{ValueError} if passed an object which 3012 is not an instance of C{decimal.Decimal}. 3013 """ 3014 argument = amp.Decimal() 3015 self.assertRaises(ValueError, argument.toString, "1.234") 3016 self.assertRaises(ValueError, argument.toString, 1.234) 3017 self.assertRaises(ValueError, argument.toString, 1234) 3018 3019 3020class FloatTests(TestCase): 3021 """ 3022 Tests for L{amp.Float}. 3023 """ 3024 3025 def test_nonFloat(self): 3026 """ 3027 L{amp.Float.toString} raises L{ValueError} if passed an object which 3028 is not a L{float}. 3029 """ 3030 argument = amp.Float() 3031 self.assertRaises(ValueError, argument.toString, "1.234") 3032 self.assertRaises(ValueError, argument.toString, b"1.234") 3033 self.assertRaises(ValueError, argument.toString, 1234) 3034 3035 def test_float(self): 3036 """ 3037 L{amp.Float.toString} returns a bytestring when it is given a L{float}. 3038 """ 3039 argument = amp.Float() 3040 self.assertEqual(argument.toString(1.234), b"1.234") 3041 3042 3043class ListOfDateTimeTests(TestCase, ListOfTestsMixin): 3044 """ 3045 Tests for L{ListOf} combined with L{amp.DateTime}. 3046 """ 3047 3048 elementType = amp.DateTime() 3049 3050 strings = { 3051 b"christmas": b"\x00\x202010-12-25T00:00:00.000000-00:00" 3052 b"\x00\x202010-12-25T00:00:00.000000-00:00", 3053 b"christmas in eu": b"\x00\x202010-12-25T00:00:00.000000+01:00", 3054 b"christmas in iran": b"\x00\x202010-12-25T00:00:00.000000+03:30", 3055 b"christmas in nyc": b"\x00\x202010-12-25T00:00:00.000000-05:00", 3056 b"previous tests": b"\x00\x202010-12-25T00:00:00.000000+03:19" 3057 b"\x00\x202010-12-25T00:00:00.000000-06:59", 3058 } 3059 3060 objects = { 3061 "christmas": [ 3062 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=amp.utc), 3063 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 0, 0)), 3064 ], 3065 "christmas in eu": [ 3066 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 1, 0)), 3067 ], 3068 "christmas in iran": [ 3069 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 30)), 3070 ], 3071 "christmas in nyc": [ 3072 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 5, 0)), 3073 ], 3074 "previous tests": [ 3075 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 19)), 3076 datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 6, 59)), 3077 ], 3078 } 3079 3080 3081class ListOfOptionalTests(TestCase): 3082 """ 3083 Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands 3084 via the 'optional' flag. 3085 """ 3086 3087 def test_requiredArgumentWithNoneValueRaisesTypeError(self): 3088 """ 3089 L{ListOf.toBox} raises C{TypeError} when passed a value of L{None} 3090 for the argument. 3091 """ 3092 stringList = amp.ListOf(amp.Integer()) 3093 self.assertRaises( 3094 TypeError, 3095 stringList.toBox, 3096 b"omitted", 3097 amp.AmpBox(), 3098 {"omitted": None}, 3099 None, 3100 ) 3101 3102 def test_optionalArgumentWithNoneValueOmitted(self): 3103 """ 3104 L{ListOf.toBox} silently omits serializing any argument with a 3105 value of L{None} that is designated as optional for the protocol. 3106 """ 3107 stringList = amp.ListOf(amp.Integer(), optional=True) 3108 strings = amp.AmpBox() 3109 stringList.toBox(b"omitted", strings, {b"omitted": None}, None) 3110 self.assertEqual(strings, {}) 3111 3112 def test_requiredArgumentWithKeyMissingRaisesKeyError(self): 3113 """ 3114 L{ListOf.toBox} raises C{KeyError} if the argument's key is not 3115 present in the objects dictionary. 3116 """ 3117 stringList = amp.ListOf(amp.Integer()) 3118 self.assertRaises( 3119 KeyError, 3120 stringList.toBox, 3121 b"ommited", 3122 amp.AmpBox(), 3123 {"someOtherKey": 0}, 3124 None, 3125 ) 3126 3127 def test_optionalArgumentWithKeyMissingOmitted(self): 3128 """ 3129 L{ListOf.toBox} silently omits serializing any argument designated 3130 as optional whose key is not present in the objects dictionary. 3131 """ 3132 stringList = amp.ListOf(amp.Integer(), optional=True) 3133 stringList.toBox(b"ommited", amp.AmpBox(), {b"someOtherKey": 0}, None) 3134 3135 def test_omittedOptionalArgumentDeserializesAsNone(self): 3136 """ 3137 L{ListOf.fromBox} correctly reverses the operation performed by 3138 L{ListOf.toBox} for optional arguments. 3139 """ 3140 stringList = amp.ListOf(amp.Integer(), optional=True) 3141 objects = {} 3142 stringList.fromBox(b"omitted", {}, objects, None) 3143 self.assertEqual(objects, {"omitted": None}) 3144 3145 3146@implementer(interfaces.IUNIXTransport) 3147class UNIXStringTransport: 3148 """ 3149 An in-memory implementation of L{interfaces.IUNIXTransport} which collects 3150 all data given to it for later inspection. 3151 3152 @ivar _queue: A C{list} of the data which has been given to this transport, 3153 eg via C{write} or C{sendFileDescriptor}. Elements are two-tuples of a 3154 string (identifying the destination of the data) and the data itself. 3155 """ 3156 3157 def __init__(self, descriptorFuzz): 3158 """ 3159 @param descriptorFuzz: An offset to apply to descriptors. 3160 @type descriptorFuzz: C{int} 3161 """ 3162 self._fuzz = descriptorFuzz 3163 self._queue = [] 3164 3165 def sendFileDescriptor(self, descriptor): 3166 self._queue.append(("fileDescriptorReceived", descriptor + self._fuzz)) 3167 3168 def write(self, data): 3169 self._queue.append(("dataReceived", data)) 3170 3171 def writeSequence(self, seq): 3172 for data in seq: 3173 self.write(data) 3174 3175 def loseConnection(self): 3176 self._queue.append(("connectionLost", Failure(error.ConnectionLost()))) 3177 3178 def getHost(self): 3179 return address.UNIXAddress("/tmp/some-path") 3180 3181 def getPeer(self): 3182 return address.UNIXAddress("/tmp/another-path") 3183 3184 3185# Minimal evidence that we got the signatures right 3186verifyClass(interfaces.ITransport, UNIXStringTransport) 3187verifyClass(interfaces.IUNIXTransport, UNIXStringTransport) 3188 3189 3190class DescriptorTests(TestCase): 3191 """ 3192 Tests for L{amp.Descriptor}, an argument type for passing a file descriptor 3193 over an AMP connection over a UNIX domain socket. 3194 """ 3195 3196 def setUp(self): 3197 self.fuzz = 3 3198 self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz) 3199 self.protocol = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator())) 3200 self.protocol.makeConnection(self.transport) 3201 3202 def test_fromStringProto(self): 3203 """ 3204 L{Descriptor.fromStringProto} constructs a file descriptor value by 3205 extracting a previously received file descriptor corresponding to the 3206 wire value of the argument from the L{_DescriptorExchanger} state of the 3207 protocol passed to it. 3208 3209 This is a whitebox test which involves direct L{_DescriptorExchanger} 3210 state inspection. 3211 """ 3212 argument = amp.Descriptor() 3213 self.protocol.fileDescriptorReceived(5) 3214 self.protocol.fileDescriptorReceived(3) 3215 self.protocol.fileDescriptorReceived(1) 3216 self.assertEqual(5, argument.fromStringProto("0", self.protocol)) 3217 self.assertEqual(3, argument.fromStringProto("1", self.protocol)) 3218 self.assertEqual(1, argument.fromStringProto("2", self.protocol)) 3219 self.assertEqual({}, self.protocol._descriptors) 3220 3221 def test_toStringProto(self): 3222 """ 3223 To send a file descriptor, L{Descriptor.toStringProto} uses the 3224 L{IUNIXTransport.sendFileDescriptor} implementation of the transport of 3225 the protocol passed to it to copy the file descriptor. Each subsequent 3226 descriptor sent over a particular AMP connection is assigned the next 3227 integer value, starting from 0. The base ten string representation of 3228 this value is the byte encoding of the argument. 3229 3230 This is a whitebox test which involves direct L{_DescriptorExchanger} 3231 state inspection and mutation. 3232 """ 3233 argument = amp.Descriptor() 3234 self.assertEqual(b"0", argument.toStringProto(2, self.protocol)) 3235 self.assertEqual( 3236 ("fileDescriptorReceived", 2 + self.fuzz), self.transport._queue.pop(0) 3237 ) 3238 self.assertEqual(b"1", argument.toStringProto(4, self.protocol)) 3239 self.assertEqual( 3240 ("fileDescriptorReceived", 4 + self.fuzz), self.transport._queue.pop(0) 3241 ) 3242 self.assertEqual(b"2", argument.toStringProto(6, self.protocol)) 3243 self.assertEqual( 3244 ("fileDescriptorReceived", 6 + self.fuzz), self.transport._queue.pop(0) 3245 ) 3246 self.assertEqual({}, self.protocol._descriptors) 3247 3248 def test_roundTrip(self): 3249 """ 3250 L{amp.Descriptor.fromBox} can interpret an L{amp.AmpBox} constructed by 3251 L{amp.Descriptor.toBox} to reconstruct a file descriptor value. 3252 """ 3253 name = "alpha" 3254 nameAsBytes = name.encode("ascii") 3255 strings = {} 3256 descriptor = 17 3257 sendObjects = {name: descriptor} 3258 3259 argument = amp.Descriptor() 3260 argument.toBox(nameAsBytes, strings, sendObjects.copy(), self.protocol) 3261 3262 receiver = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator())) 3263 for event in self.transport._queue: 3264 getattr(receiver, event[0])(*event[1:]) 3265 3266 receiveObjects = {} 3267 argument.fromBox(nameAsBytes, strings.copy(), receiveObjects, receiver) 3268 3269 # Make sure we got the descriptor. Adjust by fuzz to be more convincing 3270 # of having gone through L{IUNIXTransport.sendFileDescriptor}, not just 3271 # converted to a string and then parsed back into an integer. 3272 self.assertEqual(descriptor + self.fuzz, receiveObjects[name]) 3273 3274 3275class DateTimeTests(TestCase): 3276 """ 3277 Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}. 3278 """ 3279 3280 string = b"9876-01-23T12:34:56.054321-01:23" 3281 tzinfo = tz("-", 1, 23) 3282 object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo) 3283 3284 def test_invalidString(self): 3285 """ 3286 L{amp.DateTime.fromString} raises L{ValueError} when passed a string 3287 which does not represent a timestamp in the proper format. 3288 """ 3289 d = amp.DateTime() 3290 self.assertRaises(ValueError, d.fromString, "abc") 3291 3292 def test_invalidDatetime(self): 3293 """ 3294 L{amp.DateTime.toString} raises L{ValueError} when passed a naive 3295 datetime (a datetime with no timezone information). 3296 """ 3297 d = amp.DateTime() 3298 self.assertRaises( 3299 ValueError, d.toString, datetime.datetime(2010, 12, 25, 0, 0, 0) 3300 ) 3301 3302 def test_fromString(self): 3303 """ 3304 L{amp.DateTime.fromString} returns a C{datetime.datetime} with all of 3305 its fields populated from the string passed to it. 3306 """ 3307 argument = amp.DateTime() 3308 value = argument.fromString(self.string) 3309 self.assertEqual(value, self.object) 3310 3311 def test_toString(self): 3312 """ 3313 L{amp.DateTime.toString} returns a C{str} in the wire format including 3314 all of the information from the C{datetime.datetime} passed into it, 3315 including the timezone offset. 3316 """ 3317 argument = amp.DateTime() 3318 value = argument.toString(self.object) 3319 self.assertEqual(value, self.string) 3320 3321 3322class UTCTests(TestCase): 3323 """ 3324 Tests for L{amp.utc}. 3325 """ 3326 3327 def test_tzname(self): 3328 """ 3329 L{amp.utc.tzname} returns C{"+00:00"}. 3330 """ 3331 self.assertEqual(amp.utc.tzname(None), "+00:00") 3332 3333 def test_dst(self): 3334 """ 3335 L{amp.utc.dst} returns a zero timedelta. 3336 """ 3337 self.assertEqual(amp.utc.dst(None), datetime.timedelta(0)) 3338 3339 def test_utcoffset(self): 3340 """ 3341 L{amp.utc.utcoffset} returns a zero timedelta. 3342 """ 3343 self.assertEqual(amp.utc.utcoffset(None), datetime.timedelta(0)) 3344 3345 def test_badSign(self): 3346 """ 3347 L{amp._FixedOffsetTZInfo.fromSignHoursMinutes} raises L{ValueError} if 3348 passed an offset sign other than C{'+'} or C{'-'}. 3349 """ 3350 self.assertRaises(ValueError, tz, "?", 0, 0) 3351 3352 3353class RemoteAmpErrorTests(TestCase): 3354 """ 3355 Tests for L{amp.RemoteAmpError}. 3356 """ 3357 3358 def test_stringMessage(self): 3359 """ 3360 L{amp.RemoteAmpError} renders the given C{errorCode} (C{bytes}) and 3361 C{description} into a native string. 3362 """ 3363 error = amp.RemoteAmpError(b"BROKEN", "Something has broken") 3364 self.assertEqual("Code<BROKEN>: Something has broken", str(error)) 3365 3366 def test_stringMessageReplacesNonAsciiText(self): 3367 """ 3368 When C{errorCode} contains non-ASCII characters, L{amp.RemoteAmpError} 3369 renders then as backslash-escape sequences. 3370 """ 3371 error = amp.RemoteAmpError(b"BROKEN-\xff", "Something has broken") 3372 self.assertEqual("Code<BROKEN-\\xff>: Something has broken", str(error)) 3373 3374 def test_stringMessageWithLocalFailure(self): 3375 """ 3376 L{amp.RemoteAmpError} renders local errors with a "(local)" marker and 3377 a brief traceback. 3378 """ 3379 failure = Failure(Exception("Something came loose")) 3380 error = amp.RemoteAmpError(b"BROKEN", "Something has broken", local=failure) 3381 self.assertRegex( 3382 str(error), 3383 ( 3384 "^Code<BROKEN> [(]local[)]: Something has broken\n" 3385 "Traceback [(]failure with no frames[)]: " 3386 "<.+Exception.>: Something came loose\n" 3387 ), 3388 ) 3389