1# Copyright (c) 2005 Divmod, Inc.
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6Tests for L{twisted.protocols.amp}.
7"""
8
9
10import datetime
11import decimal
12from typing import Dict, Type
13from unittest import skipIf
14
15from zope.interface import implementer
16from zope.interface.verify import verifyClass, verifyObject
17
18from twisted.internet import address, defer, error, interfaces, protocol, reactor
19from twisted.protocols import amp
20from twisted.python import filepath
21from twisted.python.failure import Failure
22from twisted.test import iosim
23from twisted.test.proto_helpers import StringTransport
24from twisted.trial.unittest import TestCase
25
# Optional TLS support: ``ssl`` is the twisted.internet.ssl module when it can
# be imported and reports itself as supported, and None otherwise.
try:
    from twisted.internet import ssl as _ssl
except ImportError:
    ssl = None
else:
    ssl = _ssl if _ssl.supported else None

# True when the ssl module is unavailable or unsupported.
skipSSL = ssl is None

# True when the installed reactor does not provide IReactorSSL.
reactorLacksSSL = not interfaces.IReactorSSL.providedBy(reactor)
45
46
# Shorthand constructor for amp's fixed-offset tzinfo objects. Judging by the
# name, it takes a sign, an hour count and a minute count — confirm against
# amp._FixedOffsetTZInfo.fromSignHoursMinutes.
tz = amp._FixedOffsetTZInfo.fromSignHoursMinutes
48
49
class TestProto(protocol.Protocol):
    """
    A trivial protocol for use in testing where a L{Protocol} is expected.

    When connected it writes C{dataToSend} to its transport, collects every
    chunk it receives in C{data}, and fires C{onConnLost} with that list when
    the connection is lost.

    @ivar instanceId: the id of this instance, taken from a class-wide counter
    @ivar onConnLost: deferred that will be fired with C{data} when the
        connection is lost
    @ivar dataToSend: data to send on the protocol
    @ivar data: the chunks of data received so far, in arrival order
    """

    instanceCount = 0

    def __init__(self, onConnLost, dataToSend):
        assert isinstance(dataToSend, bytes), repr(dataToSend)
        self.onConnLost = onConnLost
        self.dataToSend = dataToSend
        self.instanceId = TestProto.instanceCount
        TestProto.instanceCount += 1

    def connectionMade(self):
        # Start a fresh receive buffer for this connection, then send.
        self.data = []
        self.transport.write(self.dataToSend)

    def dataReceived(self, data):
        self.data.append(data)

    def connectionLost(self, reason):
        self.onConnLost.callback(self.data)

    def __repr__(self) -> str:
        """
        Custom repr for testing to avoid coupling amp tests with repr from
        L{Protocol}

        Returns a string which contains a unique identifier that can be looked
        up using the instanceId property::

            <TestProto #3>
        """
        return "<TestProto #%d>" % (self.instanceId,)
89
90
class SimpleSymmetricProtocol(amp.AMP):
    """
    An AMP protocol that both sends and answers the raw C{hello} command,
    without using an L{amp.Command} subclass.
    """

    def sendHello(self, text):
        """
        Send a raw C{hello} command carrying C{text}.

        @return: whatever C{callRemoteString} returns.
        """
        pending = self.callRemoteString(b"hello", hello=text)
        return pending

    def amp_HELLO(self, box):
        """
        Echo the C{hello} value of C{box} back in a new L{amp.Box}.
        """
        greeting = box[b"hello"]
        return amp.Box(hello=greeting)
97
98
class UnfriendlyGreeting(Exception):
    """Greeting was insufficiently kind."""


class DeathThreat(Exception):
    """
    Request was a threat the responder will not tolerate; L{Hello} lists it
    among its C{fatalErrors} rather than its ordinary C{errors}.
    """


class UnknownProtocol(Exception):
    """Asked to switch to the wrong protocol."""
109
110
class TransportPeer(amp.Argument):
    """
    A pseudo-argument whose value comes from the receiving protocol rather
    than from the box: it decodes to the peer address of C{proto.transport},
    and contributes nothing when a box is built.

    This serves as some informal documentation for how to get variables from
    the protocol or your environment and pass them to methods as arguments.
    """

    def retrieve(self, d, name, proto):
        # The box is not consulted; an empty placeholder is returned.
        return b""

    def fromStringProto(self, notAString, proto):
        transport = proto.transport
        return transport.getPeer()

    def toBox(self, name, strings, objects, proto):
        # Deliberately add nothing to the outgoing box.
        return None
122
123
class Hello(amp.Command):
    """
    The main test command: a required C{hello} string plus several optional
    arguments, answered with a C{hello} string and optional C{print} text.
    """

    commandName = b"hello"

    arguments = [
        (b"hello", amp.String()),
        (b"optional", amp.Boolean(optional=True)),
        (b"print", amp.Unicode(optional=True)),
        # Filled in on the receiving side from the transport's peer; its
        # toBox adds nothing to outgoing boxes.
        (b"from", TransportPeer(optional=True)),
        # The next three exercise argument names that are camel case, contain
        # a dash, or contain an underscore (responders here receive them as
        # mixedCase, dash_arg and underscore_arg).
        (b"mixedCase", amp.String(optional=True)),
        (b"dash-arg", amp.String(optional=True)),
        (b"underscore_arg", amp.String(optional=True)),
    ]

    response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]

    # Exceptions a responder may raise, each mapped to the error code reported
    # to the caller.
    errors: Dict[Type[Exception], bytes] = {UnfriendlyGreeting: b"UNFRIENDLY"}

    # Like ``errors``, but declared fatal (presumably AMP drops the connection
    # after reporting these — confirm in the amp documentation).
    fatalErrors: Dict[Type[Exception], bytes] = {DeathThreat: b"DEAD"}
143
144
class NoAnswerHello(Hello):
    """
    L{Hello} under the same command name and arguments, but with
    C{requiresAnswer} set to False (presumably so the caller does not expect
    an answer box).
    """

    commandName = Hello.commandName
    requiresAnswer = False
148
149
class FutureHello(amp.Command):
    """
    A "newer" revision of the C{hello} command: the same command name with
    an extra optional C{bonus} argument. Presumably used to check that peers
    using the older L{Hello} definition still interoperate.
    """

    commandName = b"hello"

    arguments = [
        (b"hello", amp.String()),
        (b"optional", amp.Boolean(optional=True)),
        (b"print", amp.Unicode(optional=True)),
        (b"from", TransportPeer(optional=True)),
        # Additional arguments should generally be added at the end, and be
        # optional, so older definitions of the command can still be served.
        (b"bonus", amp.String(optional=True)),
    ]

    response = [(b"hello", amp.String()), (b"print", amp.Unicode(optional=True))]

    errors: Dict[Type[Exception], bytes] = {UnfriendlyGreeting: b"UNFRIENDLY"}
167
168
class WTF(amp.Command):
    """
    An example of an invalid command.

    It declares no C{commandName}, arguments, or response of its own.
    Presumably tests send it to a peer that has no responder for it.
    """
173
174
class BrokenReturn(amp.Command):
    """
    An example of a perfectly good command, but the handler is going to return
    None instead of a response dictionary.
    """

    commandName = b"broken_return"
181
182
class Goodbye(amp.Command):
    """
    A command answered with a C{goodbye} string, sent as an L{amp.QuitBox}
    (presumably closing the connection after the answer — confirm).
    """

    # commandName left blank on purpose: this tests implicit command names.
    response = [(b"goodbye", amp.String())]
    responseType = amp.QuitBox
187
188
class WaitForever(amp.Command):
    """
    A command whose responder returns a L{Deferred} that the responder itself
    never fires, so the answer is delayed until something else fires it.
    """

    commandName = b"wait_forever"
191
192
class GetList(amp.Command):
    """
    Ask for C{length} items; the response C{body} is an L{amp.AmpList} of
    boxes that each hold an integer C{x}.
    """

    commandName = b"getlist"
    arguments = [(b"length", amp.Integer())]
    response = [(b"body", amp.AmpList([(b"x", amp.Integer())]))]
197
198
class DontRejectMe(amp.Command):
    """
    A command with a required Unicode C{magicWord} and an optional
    L{amp.AmpList} argument C{list}, which callers may omit entirely.
    """

    commandName = b"dontrejectme"
    arguments = [
        (b"magicWord", amp.Unicode()),
        (b"list", amp.AmpList([(b"name", amp.Unicode())], optional=True)),
    ]
    response = [(b"response", amp.Unicode())]
206
207
class SecuredPing(amp.Command):
    """
    A ping whose responder in this module only answers when the transport
    provides L{ISSLTransport}; otherwise it raises L{DeathThreat}.
    """

    # XXX TODO: actually make this refuse to send over an insecure connection
    response = [(b"pinged", amp.Boolean())]
211
212
class TestSwitchProto(amp.ProtocolSwitchCommand):
    """
    Ask the peer to switch the connection to the protocol named C{name}.
    Responders signal an unrecognized name with L{UnknownProtocol}.
    """

    commandName = b"Switch-Proto"

    arguments = [
        (b"name", amp.String()),
    ]
    errors: Dict[Type[Exception], bytes] = {UnknownProtocol: b"UNKNOWN"}
220
221
class SingleUseFactory(protocol.ClientFactory):
    """
    A client factory that hands out one pre-built protocol instance exactly
    once.

    @ivar proto: the protocol still to be handed out, or L{None} once
        L{buildProtocol} has been called.
    @ivar reasonFailed: the reason passed to L{clientConnectionFailed}, or
        L{None} if that has not been called.
    """

    reasonFailed = None

    def __init__(self, proto):
        proto.factory = self
        self.proto = proto

    def buildProtocol(self, addr):
        """
        Return the stored protocol and forget it; later calls return L{None}.
        """
        built = self.proto
        self.proto = None
        return built

    def clientConnectionFailed(self, connector, reason):
        """
        Remember why the connection attempt failed.
        """
        self.reasonFailed = reason
236
237
# A greeting the Hello responders in this module refuse to interpret.
THING_I_DONT_UNDERSTAND = b"gwebol nargo"


class ThingIDontUnderstandError(Exception):
    """
    Raised by the C{hello} responder when it receives
    L{THING_I_DONT_UNDERSTAND}; it is not among L{Hello}'s declared errors.
    """
243
244
class FactoryNotifier(amp.AMP):
    """
    An AMP protocol that, once connected, stores itself on its factory as
    C{theProto} and calls C{factory.onMade.callback(None)} when the factory
    has an C{onMade} attribute. It also answers L{SecuredPing}, but only over
    transports providing L{ISSLTransport}.
    """

    factory = None

    def connectionMade(self):
        factory = self.factory
        if factory is None:
            return
        factory.theProto = self
        if hasattr(factory, "onMade"):
            factory.onMade.callback(None)

    def emitpong(self):
        if interfaces.ISSLTransport.providedBy(self.transport):
            return {"pinged": True}
        raise DeathThreat("only send secure pings over secure channels")

    SecuredPing.responder(emitpong)
262
263
class SimpleSymmetricCommandProtocol(FactoryNotifier):
    """
    A symmetric AMP protocol, usable on either end of a connection, with
    responders for L{Hello}, L{GetList}, L{DontRejectMe}, L{WaitForever},
    L{Goodbye}, L{TestSwitchProto} and L{BrokenReturn}.

    @ivar onConnLost: handed to every L{TestProto} created for a protocol
        switch.
    @ivar greeted: becomes C{True} once a L{Hello} has been answered.
    @ivar maybeLater: L{None} here; subclasses that answer protocol switches
        asynchronously store their pending L{Deferred} in it.
    @ivar waiting: the L{Deferred} most recently handed out for
        L{WaitForever}.
    """

    maybeLater = None
    greeted = False

    def __init__(self, onConnLost=None):
        amp.AMP.__init__(self)
        self.onConnLost = onConnLost

    def sendHello(self, text):
        return self.callRemote(Hello, hello=text)

    def sendUnicodeHello(self, text, translation):
        return self.callRemote(Hello, hello=text, Print=translation)

    def cmdHello(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        # 'From' is produced locally by TransportPeer, never sent by the peer.
        assert From == self.transport.getPeer()
        if hello == THING_I_DONT_UNDERSTAND:
            raise ThingIDontUnderstandError()
        if hello.startswith(b"fuck"):
            raise UnfriendlyGreeting("Don't be a dick.")
        if hello == b"die":
            raise DeathThreat("aieeeeeeeee")
        answer = {"hello": hello}
        if Print is not None:
            answer["Print"] = Print
        self.greeted = True
        return answer

    Hello.responder(cmdHello)

    def cmdGetlist(self, length):
        # The same dict object repeated ``length`` times.
        body = [{"x": 1}] * length
        return {"body": body}

    GetList.responder(cmdGetlist)

    def okiwont(self, magicWord, list=None):
        if list is not None:
            response = "%s accepted" % (list[0]["name"])
        else:
            response = "list omitted"
        return {"response": response}

    DontRejectMe.responder(okiwont)

    def waitforit(self):
        pending = defer.Deferred()
        self.waiting = pending
        return pending

    WaitForever.responder(waitforit)

    def saybye(self):
        return {"goodbye": b"everyone"}

    Goodbye.responder(saybye)

    def switchToTestProtocol(self, fail=False):
        name = b"no-proto" if fail else b"test-proto"
        newProto = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
        pending = self.callRemote(TestSwitchProto, SingleUseFactory(newProto), name=name)
        pending.addCallback(lambda ignored: newProto)
        return pending

    def switchit(self, name):
        if name != b"test-proto":
            raise UnknownProtocol(name)
        return TestProto(self.onConnLost, SWITCH_SERVER_DATA)

    TestSwitchProto.responder(switchit)

    def donothing(self):
        # Deliberately answer BrokenReturn with None instead of a dict.
        return None

    BrokenReturn.responder(donothing)
350
351
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    Like L{SimpleSymmetricCommandProtocol}, but answers L{TestSwitchProto}
    asynchronously. The responder returns a L{Deferred} stored as
    C{maybeLater}, with the protocol to switch to prepared as
    C{maybeLaterProto}, presumably for the test to fire later.
    """

    def switchit(self, name):
        """
        Prepare the switch to C{b"test-proto"} and return a not-yet-fired
        L{Deferred}.

        @raise UnknownProtocol: if C{name} is not C{b"test-proto"}. This
            matches the base class, rather than returning L{None}, which is
            not a protocol.
        """
        if name != b"test-proto":
            raise UnknownProtocol(name)
        self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        self.maybeLater = defer.Deferred()
        return self.maybeLater

    TestSwitchProto.responder(switchit)
360
361
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A protocol whose L{NoAnswerHello} responder misbehaves by returning
    L{None} instead of a response dictionary.
    """

    def badResponder(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        """
        This responder does nothing and forgets to return a dictionary.

        It accepts the full set of arguments L{NoAnswerHello} inherits from
        L{Hello}, and implicitly returns L{None}.
        """

    NoAnswerHello.responder(badResponder)
378
379
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A protocol whose L{NoAnswerHello} responder returns a proper response:
    the greeting with C{b"-noanswer"} appended.
    """

    def goodNoAnswerResponder(
        self,
        hello,
        From,
        optional=None,
        Print=None,
        mixedCase=None,
        dash_arg=None,
        underscore_arg=None,
    ):
        answered = hello + b"-noanswer"
        return {"hello": answered}

    NoAnswerHello.responder(goodNoAnswerResponder)
394
395
def connectedServerAndClient(
    ServerClass=SimpleSymmetricProtocol, ClientClass=SimpleSymmetricProtocol, *a, **kw
):
    """
    Connect an instance of C{ServerClass} to an instance of C{ClientClass}
    through L{iosim.connectedServerAndClient}, forwarding any extra positional
    and keyword arguments unchanged.

    @return: a 3-tuple: (client, server, pump)
    """
    connect = iosim.connectedServerAndClient
    return connect(ServerClass, ClientClass, *a, **kw)
401
402
class TotallyDumbProtocol(protocol.Protocol):
    """
    A protocol that only accumulates every byte it receives in C{buf}.
    """

    buf = b""

    def dataReceived(self, data):
        self.buf = self.buf + data
408
409
class LiteralAmp(amp.AMP):
    """
    An AMP protocol that records every box it receives in C{boxes}, instead
    of dispatching it.

    NOTE(review): C{__init__} does not call C{amp.AMP.__init__}; presumably
    AMP tolerates subclasses that skip the upcall — confirm before relying on
    other AMP state here.
    """

    def __init__(self):
        self.boxes = []

    def ampBoxReceived(self, box):
        self.boxes.append(box)
417
418
class AmpBoxTests(TestCase):
    """
    Test a few essential properties of AMP boxes, mostly with respect to
    serialization correctness.
    """

    def test_serializeStr(self):
        """
        Serializing a box whose key and value are C{bytes} yields C{bytes}.
        """
        serialized = amp.AmpBox(key=b"value").serialize()
        self.assertEqual(type(serialized), bytes)

    def test_serializeUnicodeKeyRaises(self):
        """
        Serializing a box built from a native-string key raises L{TypeError}.

        NOTE(review): the value is also a native string here, so this may be
        exercising the value check rather than the key check — confirm.
        """
        box = amp.AmpBox(**{"key": "value"})
        self.assertRaises(TypeError, box.serialize)

    def test_serializeUnicodeValueRaises(self):
        """
        Serializing a box with a native-string value raises L{TypeError}.
        """
        box = amp.AmpBox(key="value")
        self.assertRaises(TypeError, box.serialize)
446
447
class ParsingTests(TestCase):
    """
    Tests for parsing and emitting AMP argument values, and for boxes
    surviving a serialize/parse round trip.
    """

    def test_booleanValues(self):
        """
        Verify that the Boolean parser parses 'True' and 'False', but nothing
        else.
        """
        b = amp.Boolean()
        self.assertTrue(b.fromString(b"True"))
        self.assertFalse(b.fromString(b"False"))
        self.assertRaises(TypeError, b.fromString, b"ninja")
        self.assertRaises(TypeError, b.fromString, b"true")
        self.assertRaises(TypeError, b.fromString, b"TRUE")
        self.assertEqual(b.toString(True), b"True")
        self.assertEqual(b.toString(False), b"False")

    def test_pathValueRoundTrip(self):
        """
        Verify the 'Path' argument can parse and emit a file path.
        """
        fp = filepath.FilePath(self.mktemp())
        p = amp.Path()
        s = p.toString(fp)
        v = p.fromString(s)
        self.assertIsNot(fp, v)  # sanity check
        self.assertEqual(fp, v)

    def test_sillyEmptyThing(self):
        """
        Test that empty boxes raise an error; they aren't supposed to be sent
        on purpose.
        """
        a = amp.AMP()
        # The exception instance assertRaises returns is not a meaningful
        # test result, so it is not returned.
        self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())

    def test_ParsingRoundTrip(self):
        """
        Verify that various kinds of data make it through the encode/parse
        round-trip unharmed.
        """
        client, server, pump = connectedServerAndClient(
            ClientClass=LiteralAmp, ServerClass=LiteralAmp
        )

        SIMPLE = (b"simple", b"test")
        CE = (b"ceq", b": ")
        CR = (b"crtest", b"test\r")
        LF = (b"lftest", b"hello\n")
        NEWLINE = (b"newline", b"test\r\none\r\ntwo")
        NEWLINE2 = (b"newline2", b"test\r\none\r\n two")
        BODYTEST = (b"body", b"blah\r\n\r\ntesttest")

        testData = [
            [SIMPLE],
            [SIMPLE, BODYTEST],
            [SIMPLE, CE],
            [SIMPLE, CR],
            [SIMPLE, CE, CR, LF],
            [CE, CR, LF],
            [SIMPLE, NEWLINE, CE, NEWLINE2],
            [BODYTEST, SIMPLE, NEWLINE],
        ]

        for test in testData:
            jb = amp.Box()
            jb.update(dict(test))
            jb._sendTo(client)
            pump.flush()
            # The server side records each received box; the latest one must
            # equal what was sent.
            self.assertEqual(server.boxes[-1], jb)
516
517
class FakeLocator:
    """
    This is a fake implementation of the interface implied by
    L{CommandLocator}.

    @ivar commands: a mapping from command names to responder callables;
        tests may also populate it directly.
    """

    def __init__(self, **commands):
        """
        Remember the given keyword arguments (if any) as a set of responders.

        @param commands: initial responders, keyed by command name.
        """
        self.commands = dict(commands)

    def locateResponder(self, commandName):
        """
        Look up and return the responder registered under C{commandName}.

        @raise KeyError: if no responder is registered under that name.
        """
        return self.commands[commandName]
536
537
class FakeSender:
    """
    This is a fake implementation of the 'box sender' interface implied by
    L{AMP}.

    @ivar sentBoxes: every box passed to L{sendBox}, in order.
    @ivar unhandledErrors: failures passed to L{unhandledError} while they
        were expected.
    @ivar expectedErrors: how many more errors may arrive before one is
        re-raised.
    """

    def __init__(self):
        """
        Create a fake sender with no sent boxes, no recorded errors and no
        expected errors.
        """
        self.sentBoxes = []
        self.unhandledErrors = []
        self.expectedErrors = 0

    def expectError(self):
        """
        Expect one error, so that the test doesn't fail.
        """
        self.expectedErrors += 1

    def sendBox(self, box):
        """
        Record C{box} in C{sentBoxes} without sending it anywhere.
        """
        self.sentBoxes.append(box)

    def unhandledError(self, failure):
        """
        Record an expected failure, or re-raise it immediately (for easier
        debugging) if no error was expected.

        Each call consumes one expectation, so the counter stays negative
        after an unexpected error.
        """
        self.expectedErrors -= 1
        if self.expectedErrors < 0:
            failure.raiseException()
        else:
            self.unhandledErrors.append(failure)
574
575
class CommandDispatchTests(TestCase):
    """
    L{amp.BoxDispatcher} converts incoming AMP boxes into calls to the
    responders found by its locator. It also converts C{callRemote} requests,
    and the answers to them, into outgoing boxes and L{Deferred} results.

    Note: Originally, AMP's factoring was such that many tests for this
    functionality are now implemented as full round-trip tests in L{AMPTests}.
    Future tests should be written at this level instead, to ensure API
    compatibility and to provide more granular, readable units of test
    coverage.
    """

    def setUp(self):
        """
        Create a dispatcher that looks up responders in a L{FakeLocator} and
        sends boxes to a L{FakeSender}.
        """
        self.locator = FakeLocator()
        self.sender = FakeSender()
        self.dispatcher = amp.BoxDispatcher(self.locator)
        self.dispatcher.startReceivingBoxes(self.sender)

    def test_receivedAsk(self):
        """
        When L{amp.BoxDispatcher.ampBoxReceived} receives an '_ask' box, it
        should locate the responder named by the box's '_command' key and
        call it with the box.
        """
        received = []

        def thunk(box):
            received.append(box)
            return amp.Box({"hello": "goodbye"})

        input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
        self.locator.commands["hello"] = thunk
        self.dispatcher.ampBoxReceived(input)
        self.assertEqual(received, [input])

    def test_sendUnhandledError(self):
        """
        L{amp.BoxDispatcher} should relay its unhandled errors in responding
        to boxes to its boxSender.
        """
        err = RuntimeError("something went wrong, oh no")
        self.sender.expectError()
        self.dispatcher.unhandledError(Failure(err))
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)

    def test_unhandledSerializationError(self):
        """
        Errors during serialization ought to be relayed to the sender's
        unhandledError method.
        """
        err = RuntimeError("something undefined went wrong")

        def thunk(result):
            class BrokenBox(amp.Box):
                def _sendTo(self, proto):
                    raise err

            return BrokenBox()

        self.locator.commands["hello"] = thunk
        input = amp.Box(_command="hello", _ask="test-command-id", hello="world")
        self.sender.expectError()
        self.dispatcher.ampBoxReceived(input)
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)

    def test_callRemote(self):
        """
        L{amp.BoxDispatcher.callRemote} should emit a properly formatted
        '_ask' box to its boxSender and record an outstanding L{Deferred}.
        When a corresponding '_answer' packet is received, the L{Deferred}
        should be fired, and the results translated via the given
        L{Command}'s response de-serialization.
        """
        result = self.dispatcher.callRemote(Hello, hello=b"world")
        self.assertEqual(
            self.sender.sentBoxes,
            [amp.AmpBox(_command=b"hello", _ask=b"1", hello=b"world")],
        )
        answers = []
        result.addCallback(answers.append)
        self.assertEqual(answers, [])
        self.dispatcher.ampBoxReceived(
            amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
        )
        self.assertEqual(answers, [dict(hello=b"yay", Print="ignored")])

    def _localCallbackErrorLoggingTest(self, callResult):
        """
        Verify that C{callResult} completes with a L{None} result and that
        exactly one unhandled L{ZeroDivisionError} has been logged.
        """
        finalResult = []
        callResult.addBoth(finalResult.append)

        self.assertEqual(1, len(self.sender.unhandledErrors))
        self.assertIsInstance(self.sender.unhandledErrors[0].value, ZeroDivisionError)

        self.assertEqual([None], finalResult)

    def test_callRemoteSuccessLocalCallbackErrorLogging(self):
        """
        If the last callback on the L{Deferred} returned by C{callRemote} (added
        by application code calling C{callRemote}) fails, the failure is passed
        to the sender's C{unhandledError} method.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello=b"world")
        callResult.addCallback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(
            amp.AmpBox({b"hello": b"yay", b"print": b"ignored", b"_answer": b"1"})
        )

        self._localCallbackErrorLoggingTest(callResult)

    def test_callRemoteErrorLocalCallbackErrorLogging(self):
        """
        Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the
        case where the L{Deferred} returned by C{callRemote} fails.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello=b"world")
        callResult.addErrback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(
            amp.AmpBox(
                {
                    b"_error": b"1",
                    b"_error_code": b"bugs",
                    b"_error_description": b"stuff",
                }
            )
        )

        self._localCallbackErrorLoggingTest(callResult)
717
718
class SimpleGreeting(amp.Command):
    """
    A very simple greeting command that uses a few basic argument types.

    It carries a Unicode C{greeting} and an integer C{cookie}, and is answered
    with an integer C{cookieplus}.
    """

    commandName = b"simple"
    arguments = [(b"greeting", amp.Unicode()), (b"cookie", amp.Integer())]
    response = [(b"cookieplus", amp.Integer())]
727
728
class TestLocator(amp.CommandLocator):
    """
    A locator which implements a responder to the 'simple' command.

    @ivar greetings: every C{(greeting, cookie)} pair received, in order.
    """

    def __init__(self):
        self.greetings = []

    @SimpleGreeting.responder
    def greetingResponder(self, greeting, cookie):
        """
        Record the greeting and answer with C{cookie + 3}.
        """
        self.greetings.append((greeting, cookie))
        return {"cookieplus": cookie + 3}
742
743
class OverridingLocator(TestLocator):
    """
    A locator which overrides the responder to the 'simple' command.
    """

    @SimpleGreeting.responder
    def greetingResponder(self, greeting, cookie):
        """
        Return a different cookieplus than L{TestLocator.greetingResponder}:
        C{cookie + 4} instead of C{cookie + 3}.
        """
        self.greetings.append((greeting, cookie))
        return {"cookieplus": cookie + 4}
757
758
class InheritingLocator(OverridingLocator):
    """
    This locator should inherit the responder from L{OverridingLocator}.

    It defines nothing itself, so its 'simple' responder should answer with
    C{cookie + 4}, as L{OverridingLocator}'s does.
    """
763
764
class OverrideLocatorAMP(amp.AMP):
    """
    An AMP subclass that overrides the deprecated C{lookupFunction} hook. It
    returns a canned responder for C{b"custom"} and defers to the base class
    for every other name.

    @ivar customResponder: the sentinel returned for C{b"custom"}.
    @ivar greetings: every C{(greeting, cookie)} pair received, in order.
    """

    def __init__(self):
        amp.AMP.__init__(self)
        self.customResponder = object()
        self.expectations = {b"custom": self.customResponder}
        self.greetings = []

    def lookupFunction(self, name):
        """
        Override the deprecated lookupFunction function.
        """
        if name not in self.expectations:
            return super().lookupFunction(name)
        return self.expectations[name]

    @SimpleGreeting.responder
    def greetingResponder(self, greeting, cookie):
        self.greetings.append((greeting, cookie))
        return {"cookieplus": cookie + 3}
787
788
class CommandLocatorTests(TestCase):
    """
    The CommandLocator should enable users to specify responders to commands as
    functions that take structured objects, annotated with metadata.
    """

    def _checkSimpleGreeting(self, locatorClass, expected):
        """
        Instantiate C{locatorClass}, look up its responder for the I{simple}
        command, call it with a C{SimpleGreeting<"ni hao", 5>} box, and check
        that the answer's C{cookieplus} equals C{expected}.

        @return: the result of the responder, with the check attached as a
            callback.
        """
        responder = locatorClass().locateResponder(b"simple")
        answer = responder(amp.Box(greeting=b"ni hao", cookie=b"5"))
        expectedBox = amp.AmpBox(cookieplus=b"%d" % (expected,))

        def check(values):
            self.assertEqual(values, expectedBox)

        return answer.addCallback(check)

    def test_responderDecorator(self):
        """
        A method on a L{CommandLocator} subclass decorated with a L{Command}
        subclass's L{responder} decorator should be returned from
        locateResponder, wrapped in logic to serialize and deserialize its
        arguments.
        """
        return self._checkSimpleGreeting(TestLocator, 8)

    def test_responderOverriding(self):
        """
        L{CommandLocator} subclasses can override a responder inherited from
        a base class by using the L{Command.responder} decorator to register
        a new responder method.
        """
        return self._checkSimpleGreeting(OverridingLocator, 9)

    def test_responderInheritance(self):
        """
        Responder lookup follows the same rules as normal method lookup
        rules, particularly with respect to inheritance.
        """
        return self._checkSimpleGreeting(InheritingLocator, 9)

    def test_lookupFunctionDeprecatedOverride(self):
        """
        Subclasses which override locateResponder under its old name,
        lookupFunction, should have the override invoked instead.  (This tests
        an AMP subclass, because in the version of the code that could invoke
        this deprecated code path, there was no L{CommandLocator}.)
        """
        locator = OverrideLocatorAMP()
        customResult = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.locateResponder(b"custom"),
        )
        self.assertEqual(locator.customResponder, customResult)
        # Names the override does not know about must still reach the
        # base-class lookup.
        simpleResponder = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.locateResponder(b"simple"),
        )
        answer = simpleResponder(amp.Box(greeting=b"ni hao", cookie=b"5"))

        def check(values):
            self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))

        return answer.addCallback(check)

    def test_lookupFunctionDeprecatedInvoke(self):
        """
        Invoking locateResponder under its old name, lookupFunction, should
        emit a deprecation warning, but do the same thing.
        """
        locator = TestLocator()
        responder = self.assertWarns(
            PendingDeprecationWarning,
            "Call locateResponder, not lookupFunction.",
            __file__,
            lambda: locator.lookupFunction(b"simple"),
        )
        answer = responder(amp.Box(greeting=b"ni hao", cookie=b"5"))

        def check(values):
            self.assertEqual(values, amp.AmpBox(cookieplus=b"8"))

        return answer.addCallback(check)
882
883
# Payloads written by the client-side and server-side TestProto instances
# created for a protocol switch.
SWITCH_CLIENT_DATA, SWITCH_SERVER_DATA = b"Success!", b"No, really.  Success."
886
887
888class BinaryProtocolTests(TestCase):
889    """
890    Tests for L{amp.BinaryBoxProtocol}.
891
892    @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
893        which was passed to it.
894    """
895
896    def setUp(self):
897        """
898        Keep track of all boxes received by this test in its capacity as an
899        L{IBoxReceiver} implementor.
900        """
901        self.boxes = []
902        self.data = []
903
904    def startReceivingBoxes(self, sender):
905        """
906        Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
907        value passed in.
908        """
909        self._boxSender = sender
910
911    def ampBoxReceived(self, box):
912        """
913        A box was received by the protocol.
914        """
915        self.boxes.append(box)
916
917    stopReason = None
918
919    def stopReceivingBoxes(self, reason):
920        """
921        Record the reason that we stopped receiving boxes.
922        """
923        self.stopReason = reason
924
925    # fake ITransport
926    def getPeer(self):
927        return "no peer"
928
929    def getHost(self):
930        return "no host"
931
932    def write(self, data):
933        self.assertIsInstance(data, bytes)
934        self.data.append(data)
935
936    def test_startReceivingBoxes(self):
937        """
938        When L{amp.BinaryBoxProtocol} is connected to a transport, it calls
939        C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the
940        L{IBoxSender} parameter.
941        """
942        protocol = amp.BinaryBoxProtocol(self)
943        protocol.makeConnection(None)
944        self.assertIs(self._boxSender, protocol)
945
946    def test_sendBoxInStartReceivingBoxes(self):
947        """
948        The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
949        connected to a transport can call C{sendBox} on the L{IBoxSender}
950        passed to it before C{startReceivingBoxes} returns and have that box
951        sent.
952        """
953
954        class SynchronouslySendingReceiver:
955            def startReceivingBoxes(self, sender):
956                sender.sendBox(amp.Box({b"foo": b"bar"}))
957
958        transport = StringTransport()
959        protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
960        protocol.makeConnection(transport)
961        self.assertEqual(transport.value(), b"\x00\x03foo\x00\x03bar\x00\x00")
962
963    def test_receiveBoxStateMachine(self):
964        """
965        When a binary box protocol receives:
966            * a key
967            * a value
968            * an empty string
969        it should emit a box and send it to its boxReceiver.
970        """
971        a = amp.BinaryBoxProtocol(self)
972        a.stringReceived(b"hello")
973        a.stringReceived(b"world")
974        a.stringReceived(b"")
975        self.assertEqual(self.boxes, [amp.AmpBox(hello=b"world")])
976
977    def test_firstBoxFirstKeyExcessiveLength(self):
978        """
979        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
980        the first a key it receives is larger than 255.
981        """
982        transport = StringTransport()
983        protocol = amp.BinaryBoxProtocol(self)
984        protocol.makeConnection(transport)
985        protocol.dataReceived(b"\x01\x00")
986        self.assertTrue(transport.disconnecting)
987
988    def test_firstBoxSubsequentKeyExcessiveLength(self):
989        """
990        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
991        a subsequent key in the first box it receives is larger than 255.
992        """
993        transport = StringTransport()
994        protocol = amp.BinaryBoxProtocol(self)
995        protocol.makeConnection(transport)
996        protocol.dataReceived(b"\x00\x01k\x00\x01v")
997        self.assertFalse(transport.disconnecting)
998        protocol.dataReceived(b"\x01\x00")
999        self.assertTrue(transport.disconnecting)
1000
1001    def test_subsequentBoxFirstKeyExcessiveLength(self):
1002        """
1003        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
1004        the first key in a subsequent box it receives is larger than 255.
1005        """
1006        transport = StringTransport()
1007        protocol = amp.BinaryBoxProtocol(self)
1008        protocol.makeConnection(transport)
1009        protocol.dataReceived(b"\x00\x01k\x00\x01v\x00\x00")
1010        self.assertFalse(transport.disconnecting)
1011        protocol.dataReceived(b"\x01\x00")
1012        self.assertTrue(transport.disconnecting)
1013
1014    def test_excessiveKeyFailure(self):
1015        """
1016        If L{amp.BinaryBoxProtocol} disconnects because it received a key
1017        length prefix which was too large, the L{IBoxReceiver}'s
1018        C{stopReceivingBoxes} method is called with a L{TooLong} failure.
1019        """
1020        protocol = amp.BinaryBoxProtocol(self)
1021        protocol.makeConnection(StringTransport())
1022        protocol.dataReceived(b"\x01\x00")
1023        protocol.connectionLost(
1024            Failure(error.ConnectionDone("simulated connection done"))
1025        )
1026        self.stopReason.trap(amp.TooLong)
1027        self.assertTrue(self.stopReason.value.isKey)
1028        self.assertFalse(self.stopReason.value.isLocal)
1029        self.assertIsNone(self.stopReason.value.value)
1030        self.assertIsNone(self.stopReason.value.keyName)
1031
1032    def test_unhandledErrorWithTransport(self):
1033        """
1034        L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it
1035        and disconnects its transport.
1036        """
1037        transport = StringTransport()
1038        protocol = amp.BinaryBoxProtocol(self)
1039        protocol.makeConnection(transport)
1040        protocol.unhandledError(Failure(RuntimeError("Fake error")))
1041        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
1042        self.assertTrue(transport.disconnecting)
1043
1044    def test_unhandledErrorWithoutTransport(self):
1045        """
1046        L{amp.BinaryBoxProtocol.unhandledError} completes without error when
1047        there is no associated transport.
1048        """
1049        protocol = amp.BinaryBoxProtocol(self)
1050        protocol.makeConnection(StringTransport())
1051        protocol.connectionLost(Failure(Exception("Simulated")))
1052        protocol.unhandledError(Failure(RuntimeError("Fake error")))
1053        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
1054
1055    def test_receiveBoxData(self):
1056        """
1057        When a binary box protocol receives the serialized form of an AMP box,
1058        it should emit a similar box to its boxReceiver.
1059        """
1060        a = amp.BinaryBoxProtocol(self)
1061        a.dataReceived(
1062            amp.Box(
1063                {b"testKey": b"valueTest", b"anotherKey": b"anotherValue"}
1064            ).serialize()
1065        )
1066        self.assertEqual(
1067            self.boxes,
1068            [amp.Box({b"testKey": b"valueTest", b"anotherKey": b"anotherValue"})],
1069        )
1070
1071    def test_receiveLongerBoxData(self):
1072        """
1073        An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
1074        values of up to (2 ** 16 - 1) bytes.
1075        """
1076        length = 2 ** 16 - 1
1077        value = b"x" * length
1078        transport = StringTransport()
1079        protocol = amp.BinaryBoxProtocol(self)
1080        protocol.makeConnection(transport)
1081        protocol.dataReceived(amp.Box({"k": value}).serialize())
1082        self.assertEqual(self.boxes, [amp.Box({"k": value})])
1083        self.assertFalse(transport.disconnecting)
1084
1085    def test_sendBox(self):
1086        """
1087        When a binary box protocol sends a box, it should emit the serialized
1088        bytes of that box to its transport.
1089        """
1090        a = amp.BinaryBoxProtocol(self)
1091        a.makeConnection(self)
1092        aBox = amp.Box({b"testKey": b"valueTest", b"someData": b"hello"})
1093        a.makeConnection(self)
1094        a.sendBox(aBox)
1095        self.assertEqual(b"".join(self.data), aBox.serialize())
1096
1097    def test_connectionLostStopSendingBoxes(self):
1098        """
1099        When a binary box protocol loses its connection, it should notify its
1100        box receiver that it has stopped receiving boxes.
1101        """
1102        a = amp.BinaryBoxProtocol(self)
1103        a.makeConnection(self)
1104        connectionFailure = Failure(RuntimeError())
1105        a.connectionLost(connectionFailure)
1106        self.assertIs(self.stopReason, connectionFailure)
1107
    def test_protocolSwitch(self):
        """
        L{BinaryBoxProtocol} has the capacity to switch to a different protocol
        on a box boundary.  When a protocol is in the process of switching, it
        cannot receive traffic.
        """
        otherProto = TestProto(None, b"outgoing data")
        test = self

        class SwitchyReceiver:
            # Box receiver that switches protocols as soon as the first box
            # arrives, and fails the test if a second box is ever delivered.
            switched = False

            def startReceivingBoxes(self, sender):
                pass

            def ampBoxReceived(self, box):
                test.assertFalse(self.switched, "Should only receive one box!")
                self.switched = True
                # ``a`` is resolved from the enclosing scope when this runs, so
                # it is the BinaryBoxProtocol assigned just below the class.
                a._lockForSwitch()
                a._switchTo(otherProto)

        a = amp.BinaryBoxProtocol(SwitchyReceiver())
        anyOldBox = amp.Box({b"include": b"lots", b"of": b"data"})
        # This test case serves as the transport.
        a.makeConnection(self)
        # Include a 0-length box at the beginning of the next protocol's data,
        # to make sure that AMP doesn't eat the data or try to deliver extra
        # boxes either...
        moreThanOneBox = anyOldBox.serialize() + b"\x00\x00Hello, world!"
        a.dataReceived(moreThanOneBox)
        # The nested protocol now owns the transport and got the leftover
        # bytes; TestProto.connectionMade wrote its dataToSend to the transport.
        self.assertIs(otherProto.transport, self)
        self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!")
        self.assertEqual(self.data, [b"outgoing data"])
        # Further bytes are forwarded straight to the nested protocol...
        a.dataReceived(b"more data")
        self.assertEqual(b"".join(otherProto.data), b"\x00\x00Hello, world!more data")
        # ...and the AMP layer itself can no longer send boxes.
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
1143
1144    def test_protocolSwitchEmptyBuffer(self):
1145        """
1146        After switching to a different protocol, if no extra bytes beyond
1147        the switch box were delivered, an empty string is not passed to the
1148        switched protocol's C{dataReceived} method.
1149        """
1150        a = amp.BinaryBoxProtocol(self)
1151        a.makeConnection(self)
1152        otherProto = TestProto(None, b"")
1153        a._switchTo(otherProto)
1154        self.assertEqual(otherProto.data, [])
1155
    def test_protocolSwitchInvalidStates(self):
        """
        In order to make sure the protocol never gets any invalid data sent
        into the middle of a box, it must be locked for switching before it is
        switched.  It can only be unlocked if the switch failed, and attempting
        to send a box while it is locked should raise an exception.
        """
        a = amp.BinaryBoxProtocol(self)
        a.makeConnection(self)
        sampleBox = amp.Box({b"some": b"data"})
        # While locked for a pending switch, sending a box is refused.
        a._lockForSwitch()
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
        # Unlocking (as after a failed switch) makes sending work again.
        a._unlockFromSwitch()
        a.sendBox(sampleBox)
        self.assertEqual(b"".join(self.data), sampleBox.serialize())
        # Once the switch has actually happened, unlocking is an error.
        a._lockForSwitch()
        otherProto = TestProto(None, b"outgoing data")
        a._switchTo(otherProto)
        self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
1175
1176    def test_protocolSwitchLoseConnection(self):
1177        """
1178        When the protocol is switched, it should notify its nested protocol of
1179        disconnection.
1180        """
1181
1182        class Loser(protocol.Protocol):
1183            reason = None
1184
1185            def connectionLost(self, reason):
1186                self.reason = reason
1187
1188        connectionLoser = Loser()
1189        a = amp.BinaryBoxProtocol(self)
1190        a.makeConnection(self)
1191        a._lockForSwitch()
1192        a._switchTo(connectionLoser)
1193        connectionFailure = Failure(RuntimeError())
1194        a.connectionLost(connectionFailure)
1195        self.assertEqual(connectionLoser.reason, connectionFailure)
1196
1197    def test_protocolSwitchLoseClientConnection(self):
1198        """
1199        When the protocol is switched, it should notify its nested client
1200        protocol factory of disconnection.
1201        """
1202
1203        class ClientLoser:
1204            reason = None
1205
1206            def clientConnectionLost(self, connector, reason):
1207                self.reason = reason
1208
1209        a = amp.BinaryBoxProtocol(self)
1210        connectionLoser = protocol.Protocol()
1211        clientLoser = ClientLoser()
1212        a.makeConnection(self)
1213        a._lockForSwitch()
1214        a._switchTo(connectionLoser, clientLoser)
1215        connectionFailure = Failure(RuntimeError())
1216        a.connectionLost(connectionFailure)
1217        self.assertEqual(clientLoser.reason, connectionFailure)
1218
1219
1220class AMPTests(TestCase):
1221    def test_interfaceDeclarations(self):
1222        """
1223        The classes in the amp module ought to implement the interfaces that
1224        are declared for their benefit.
1225        """
1226        for interface, implementation in [
1227            (amp.IBoxSender, amp.BinaryBoxProtocol),
1228            (amp.IBoxReceiver, amp.BoxDispatcher),
1229            (amp.IResponderLocator, amp.CommandLocator),
1230            (amp.IResponderLocator, amp.SimpleStringLocator),
1231            (amp.IBoxSender, amp.AMP),
1232            (amp.IBoxReceiver, amp.AMP),
1233            (amp.IResponderLocator, amp.AMP),
1234        ]:
1235            self.assertTrue(
1236                interface.implementedBy(implementation),
1237                f"{implementation} does not implements({interface})",
1238            )
1239
1240    def test_helloWorld(self):
1241        """
1242        Verify that a simple command can be sent and its response received with
1243        the simple low-level string-based API.
1244        """
1245        c, s, p = connectedServerAndClient()
1246        L = []
1247        HELLO = b"world"
1248        c.sendHello(HELLO).addCallback(L.append)
1249        p.flush()
1250        self.assertEqual(L[0][b"hello"], HELLO)
1251
1252    def test_wireFormatRoundTrip(self):
1253        """
1254        Verify that mixed-case, underscored and dashed arguments are mapped to
1255        their python names properly.
1256        """
1257        c, s, p = connectedServerAndClient()
1258        L = []
1259        HELLO = b"world"
1260        c.sendHello(HELLO).addCallback(L.append)
1261        p.flush()
1262        self.assertEqual(L[0][b"hello"], HELLO)
1263
1264    def test_helloWorldUnicode(self):
1265        """
1266        Verify that unicode arguments can be encoded and decoded.
1267        """
1268        c, s, p = connectedServerAndClient(
1269            ServerClass=SimpleSymmetricCommandProtocol,
1270            ClientClass=SimpleSymmetricCommandProtocol,
1271        )
1272        L = []
1273        HELLO = b"world"
1274        HELLO_UNICODE = "wor\u1234ld"
1275        c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
1276        p.flush()
1277        self.assertEqual(L[0]["hello"], HELLO)
1278        self.assertEqual(L[0]["Print"], HELLO_UNICODE)
1279
1280    def test_callRemoteStringRequiresAnswerFalse(self):
1281        """
1282        L{BoxDispatcher.callRemoteString} returns L{None} if C{requiresAnswer}
1283        is C{False}.
1284        """
1285        c, s, p = connectedServerAndClient()
1286        ret = c.callRemoteString(b"WTF", requiresAnswer=False)
1287        self.assertIsNone(ret)
1288
1289    def test_unknownCommandLow(self):
1290        """
1291        Verify that unknown commands using low-level APIs will be rejected with an
1292        error, but will NOT terminate the connection.
1293        """
1294        c, s, p = connectedServerAndClient()
1295        L = []
1296
1297        def clearAndAdd(e):
1298            """
1299            You can't propagate the error...
1300            """
1301            e.trap(amp.UnhandledCommand)
1302            return "OK"
1303
1304        c.callRemoteString(b"WTF").addErrback(clearAndAdd).addCallback(L.append)
1305        p.flush()
1306        self.assertEqual(L.pop(), "OK")
1307        HELLO = b"world"
1308        c.sendHello(HELLO).addCallback(L.append)
1309        p.flush()
1310        self.assertEqual(L[0][b"hello"], HELLO)
1311
1312    def test_unknownCommandHigh(self):
1313        """
1314        Verify that unknown commands using high-level APIs will be rejected with an
1315        error, but will NOT terminate the connection.
1316        """
1317        c, s, p = connectedServerAndClient()
1318        L = []
1319
1320        def clearAndAdd(e):
1321            """
1322            You can't propagate the error...
1323            """
1324            e.trap(amp.UnhandledCommand)
1325            return "OK"
1326
1327        c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
1328        p.flush()
1329        self.assertEqual(L.pop(), "OK")
1330        HELLO = b"world"
1331        c.sendHello(HELLO).addCallback(L.append)
1332        p.flush()
1333        self.assertEqual(L[0][b"hello"], HELLO)
1334
1335    def test_brokenReturnValue(self):
1336        """
1337        It can be very confusing if you write some code which responds to a
1338        command, but gets the return value wrong.  Most commonly you end up
1339        returning None instead of a dictionary.
1340
1341        Verify that if that happens, the framework logs a useful error.
1342        """
1343        L = []
1344        SimpleSymmetricCommandProtocol().dispatchCommand(
1345            amp.AmpBox(_command=BrokenReturn.commandName)
1346        ).addErrback(L.append)
1347        L[0].trap(amp.BadLocalReturn)
1348        self.failUnlessIn("None", repr(L[0].value))
1349
1350    def test_unknownArgument(self):
1351        """
1352        Verify that unknown arguments are ignored, and not passed to a Python
1353        function which can't accept them.
1354        """
1355        c, s, p = connectedServerAndClient(
1356            ServerClass=SimpleSymmetricCommandProtocol,
1357            ClientClass=SimpleSymmetricCommandProtocol,
1358        )
1359        L = []
1360        HELLO = b"world"
1361        # c.sendHello(HELLO).addCallback(L.append)
1362        c.callRemote(
1363            FutureHello, hello=HELLO, bonus=b"I'm not in the book!"
1364        ).addCallback(L.append)
1365        p.flush()
1366        self.assertEqual(L[0]["hello"], HELLO)
1367
1368    def test_simpleReprs(self):
1369        """
1370        Verify that the various Box objects repr properly, for debugging.
1371        """
1372        self.assertEqual(type(repr(amp._SwitchBox("a"))), str)
1373        self.assertEqual(type(repr(amp.QuitBox())), str)
1374        self.assertEqual(type(repr(amp.AmpBox())), str)
1375        self.assertIn("AmpBox", repr(amp.AmpBox()))
1376
1377    def test_innerProtocolInRepr(self):
1378        """
1379        Verify that L{AMP} objects output their innerProtocol when set.
1380        """
1381        otherProto = TestProto(None, b"outgoing data")
1382        a = amp.AMP()
1383        a.innerProtocol = otherProto
1384
1385        self.assertEqual(
1386            repr(a),
1387            "<AMP inner <TestProto #%d> at 0x%x>" % (otherProto.instanceId, id(a)),
1388        )
1389
1390    def test_innerProtocolNotInRepr(self):
1391        """
1392        Verify that L{AMP} objects do not output 'inner' when no innerProtocol
1393        is set.
1394        """
1395        a = amp.AMP()
1396        self.assertEqual(repr(a), f"<AMP at 0x{id(a):x}>")
1397
1398    @skipIf(skipSSL, "SSL not available")
1399    def test_simpleSSLRepr(self):
1400        """
1401        L{amp._TLSBox.__repr__} returns a string.
1402        """
1403        self.assertEqual(type(repr(amp._TLSBox())), str)
1404
1405    def test_keyTooLong(self):
1406        """
1407        Verify that a key that is too long will immediately raise a synchronous
1408        exception.
1409        """
1410        c, s, p = connectedServerAndClient()
1411        x = "H" * (0xFF + 1)
1412        tl = self.assertRaises(amp.TooLong, c.callRemoteString, b"Hello", **{x: b"hi"})
1413        self.assertTrue(tl.isKey)
1414        self.assertTrue(tl.isLocal)
1415        self.assertIsNone(tl.keyName)
1416        self.assertEqual(tl.value, x.encode("ascii"))
1417        self.assertIn(str(len(x)), repr(tl))
1418        self.assertIn("key", repr(tl))
1419
1420    def test_valueTooLong(self):
1421        """
1422        Verify that attempting to send value longer than 64k will immediately
1423        raise an exception.
1424        """
1425        c, s, p = connectedServerAndClient()
1426        x = b"H" * (0xFFFF + 1)
1427        tl = self.assertRaises(amp.TooLong, c.sendHello, x)
1428        p.flush()
1429        self.assertFalse(tl.isKey)
1430        self.assertTrue(tl.isLocal)
1431        self.assertEqual(tl.keyName, b"hello")
1432        self.failUnlessIdentical(tl.value, x)
1433        self.assertIn(str(len(x)), repr(tl))
1434        self.assertIn("value", repr(tl))
1435        self.assertIn("hello", repr(tl))
1436
1437    def test_helloWorldCommand(self):
1438        """
1439        Verify that a simple command can be sent and its response received with
1440        the high-level value parsing API.
1441        """
1442        c, s, p = connectedServerAndClient(
1443            ServerClass=SimpleSymmetricCommandProtocol,
1444            ClientClass=SimpleSymmetricCommandProtocol,
1445        )
1446        L = []
1447        HELLO = b"world"
1448        c.sendHello(HELLO).addCallback(L.append)
1449        p.flush()
1450        self.assertEqual(L[0]["hello"], HELLO)
1451
1452    def test_helloErrorHandling(self):
1453        """
1454        Verify that if a known error type is raised and handled, it will be
1455        properly relayed to the other end of the connection and translated into
1456        an exception, and no error will be logged.
1457        """
1458        L = []
1459        c, s, p = connectedServerAndClient(
1460            ServerClass=SimpleSymmetricCommandProtocol,
1461            ClientClass=SimpleSymmetricCommandProtocol,
1462        )
1463        HELLO = b"fuck you"
1464        c.sendHello(HELLO).addErrback(L.append)
1465        p.flush()
1466        L[0].trap(UnfriendlyGreeting)
1467        self.assertEqual(str(L[0].value), "Don't be a dick.")
1468
1469    def test_helloFatalErrorHandling(self):
1470        """
1471        Verify that if a known, fatal error type is raised and handled, it will
1472        be properly relayed to the other end of the connection and translated
1473        into an exception, no error will be logged, and the connection will be
1474        terminated.
1475        """
1476        L = []
1477        c, s, p = connectedServerAndClient(
1478            ServerClass=SimpleSymmetricCommandProtocol,
1479            ClientClass=SimpleSymmetricCommandProtocol,
1480        )
1481        HELLO = b"die"
1482        c.sendHello(HELLO).addErrback(L.append)
1483        p.flush()
1484        L.pop().trap(DeathThreat)
1485        c.sendHello(HELLO).addErrback(L.append)
1486        p.flush()
1487        L.pop().trap(error.ConnectionDone)
1488
    def test_helloNoErrorHandling(self):
        """
        Verify that if an unknown error type is raised, it will be relayed to
        the other end of the connection and translated into an exception, it
        will be logged, and then the connection will be dropped.
        """
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        HELLO = THING_I_DONT_UNDERSTAND
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        # The unexpected server-side exception reaches the client as an
        # UnknownRemoteError.
        ure = L.pop()
        ure.trap(amp.UnknownRemoteError)
        # No flush before popping: this test expects the second call to have
        # failed already, because the connection has been dropped.
        c.sendHello(HELLO).addErrback(L.append)
        cl = L.pop()
        cl.trap(error.ConnectionDone)
        # The exception should have been logged.
        self.assertTrue(self.flushLoggedErrors(ThingIDontUnderstandError))
1510
    def test_lateAnswer(self):
        """
        Verify that a command that does not get answered until after the
        connection terminates will not cause any errors.
        """
        c, s, p = connectedServerAndClient(
            ServerClass=SimpleSymmetricCommandProtocol,
            ClientClass=SimpleSymmetricCommandProtocol,
        )
        L = []
        c.callRemote(WaitForever).addErrback(L.append)
        p.flush()
        # The server has not answered, so the call is still pending.
        self.assertEqual(L, [])
        # Drop the connection from the server side while the call is pending;
        # the client's pending call then fails with ConnectionDone.
        s.transport.loseConnection()
        p.flush()
        L.pop().trap(error.ConnectionDone)
        # Just make sure that it doesn't error...
        # (s.waiting is presumably the Deferred held by the WaitForever
        # responder; answer it only now, after the connection is gone, and
        # return it so trial waits on it and reports any failure in its chain.)
        s.waiting.callback({})
        return s.waiting
1530
1531    def test_requiresNoAnswer(self):
1532        """
1533        Verify that a command that requires no answer is run.
1534        """
1535        c, s, p = connectedServerAndClient(
1536            ServerClass=SimpleSymmetricCommandProtocol,
1537            ClientClass=SimpleSymmetricCommandProtocol,
1538        )
1539        HELLO = b"world"
1540        c.callRemote(NoAnswerHello, hello=HELLO)
1541        p.flush()
1542        self.assertTrue(s.greeted)
1543
1544    def test_requiresNoAnswerFail(self):
1545        """
1546        Verify that commands sent after a failed no-answer request do not complete.
1547        """
1548        L = []
1549        c, s, p = connectedServerAndClient(
1550            ServerClass=SimpleSymmetricCommandProtocol,
1551            ClientClass=SimpleSymmetricCommandProtocol,
1552        )
1553        HELLO = b"fuck you"
1554        c.callRemote(NoAnswerHello, hello=HELLO)
1555        p.flush()
1556        # This should be logged locally.
1557        self.assertTrue(self.flushLoggedErrors(amp.RemoteAmpError))
1558        HELLO = b"world"
1559        c.callRemote(Hello, hello=HELLO).addErrback(L.append)
1560        p.flush()
1561        L.pop().trap(error.ConnectionDone)
1562        self.assertFalse(s.greeted)
1563
1564    def test_requiresNoAnswerAfterFail(self):
1565        """
1566        No-answer commands sent after the connection has been torn down do not
1567        return a L{Deferred}.
1568        """
1569        c, s, p = connectedServerAndClient(
1570            ServerClass=SimpleSymmetricCommandProtocol,
1571            ClientClass=SimpleSymmetricCommandProtocol,
1572        )
1573        c.transport.loseConnection()
1574        p.flush()
1575        result = c.callRemote(NoAnswerHello, hello=b"ignored")
1576        self.assertIs(result, None)
1577
1578    def test_noAnswerResponderBadAnswer(self):
1579        """
1580        Verify that responders of requiresAnswer=False commands have to return
1581        a dictionary anyway.
1582
1583        (requiresAnswer is a hint from the _client_ - the server may be called
1584        upon to answer commands in any case, if the client wants to know when
1585        they complete.)
1586        """
1587        c, s, p = connectedServerAndClient(
1588            ServerClass=BadNoAnswerCommandProtocol,
1589            ClientClass=SimpleSymmetricCommandProtocol,
1590        )
1591        c.callRemote(NoAnswerHello, hello=b"hello")
1592        p.flush()
1593        le = self.flushLoggedErrors(amp.BadLocalReturn)
1594        self.assertEqual(len(le), 1)
1595
1596    def test_noAnswerResponderAskedForAnswer(self):
1597        """
1598        Verify that responders with requiresAnswer=False will actually respond
1599        if the client sets requiresAnswer=True.  In other words, verify that
1600        requiresAnswer is a hint honored only by the client.
1601        """
1602        c, s, p = connectedServerAndClient(
1603            ServerClass=NoAnswerCommandProtocol,
1604            ClientClass=SimpleSymmetricCommandProtocol,
1605        )
1606        L = []
1607        c.callRemote(Hello, hello=b"Hello!").addCallback(L.append)
1608        p.flush()
1609        self.assertEqual(len(L), 1)
1610        self.assertEqual(
1611            L, [dict(hello=b"Hello!-noanswer", Print=None)]
1612        )  # Optional response argument
1613
1614    def test_ampListCommand(self):
1615        """
1616        Test encoding of an argument that uses the AmpList encoding.
1617        """
1618        c, s, p = connectedServerAndClient(
1619            ServerClass=SimpleSymmetricCommandProtocol,
1620            ClientClass=SimpleSymmetricCommandProtocol,
1621        )
1622        L = []
1623        c.callRemote(GetList, length=10).addCallback(L.append)
1624        p.flush()
1625        values = L.pop().get("body")
1626        self.assertEqual(values, [{"x": 1}] * 10)
1627
1628    def test_optionalAmpListOmitted(self):
1629        """
1630        Sending a command with an omitted AmpList argument that is
1631        designated as optional does not raise an InvalidSignature error.
1632        """
1633        c, s, p = connectedServerAndClient(
1634            ServerClass=SimpleSymmetricCommandProtocol,
1635            ClientClass=SimpleSymmetricCommandProtocol,
1636        )
1637        L = []
1638        c.callRemote(DontRejectMe, magicWord="please").addCallback(L.append)
1639        p.flush()
1640        response = L.pop().get("response")
1641        self.assertEqual(response, "list omitted")
1642
1643    def test_optionalAmpListPresent(self):
1644        """
1645        Sanity check that optional AmpList arguments are processed normally.
1646        """
1647        c, s, p = connectedServerAndClient(
1648            ServerClass=SimpleSymmetricCommandProtocol,
1649            ClientClass=SimpleSymmetricCommandProtocol,
1650        )
1651        L = []
1652        c.callRemote(
1653            DontRejectMe, magicWord="please", list=[{"name": "foo"}]
1654        ).addCallback(L.append)
1655        p.flush()
1656        response = L.pop().get("response")
1657        self.assertEqual(response, "foo accepted")
1658
1659    def test_failEarlyOnArgSending(self):
1660        """
1661        Verify that if we pass an invalid argument list (omitting an argument),
1662        an exception will be raised.
1663        """
1664        self.assertRaises(amp.InvalidSignature, Hello)
1665
1666    def test_doubleProtocolSwitch(self):
1667        """
1668        As a debugging aid, a protocol system should raise a
1669        L{ProtocolSwitched} exception when asked to switch a protocol that is
1670        already switched.
1671        """
1672        serverDeferred = defer.Deferred()
1673        serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
1674        clientDeferred = defer.Deferred()
1675        clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
1676        c, s, p = connectedServerAndClient(
1677            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
1678        )
1679
1680        def switched(result):
1681            self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
1682            self.testSucceeded = True
1683
1684        c.switchToTestProtocol().addCallback(switched)
1685        p.flush()
1686        self.assertTrue(self.testSucceeded)
1687
    def test_protocolSwitch(
        self,
        switcher=SimpleSymmetricCommandProtocol,
        spuriousTraffic=False,
        spuriousError=False,
    ):
        """
        Verify that it is possible to switch to another protocol mid-connection and
        send data to it successfully.

        Other tests reuse this scenario through the keyword parameters:
        C{switcher} is the protocol class instantiated on both ends;
        C{spuriousTraffic} issues a L{WaitForever} call before switching and
        answers it only after the switch; C{spuriousError} makes that late
        answer an error instead of a result.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(clientDeferred)
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
        )

        if spuriousTraffic:
            wfdr = []  # remote
            c.callRemote(WaitForever).addErrback(wfdr.append)
        switchDeferred = c.switchToTestProtocol()
        if spuriousTraffic:
            # While a switch is in progress, ordinary AMP commands are refused.
            self.assertRaises(amp.ProtocolSwitched, c.sendHello, b"world")

        def cbConnsLost(info):
            # Each side's nested protocol must have received the data the other
            # side sent (SWITCH_CLIENT_DATA / SWITCH_SERVER_DATA).
            ((serverSuccess, serverData), (clientSuccess, clientData)) = info
            self.assertTrue(serverSuccess)
            self.assertTrue(clientSuccess)
            self.assertEqual(b"".join(serverData), SWITCH_CLIENT_DATA)
            self.assertEqual(b"".join(clientData), SWITCH_SERVER_DATA)
            self.testSucceeded = True

        def cbSwitch(proto):
            # Wait on the Deferreds handed to both switcher instances.
            return defer.DeferredList([serverDeferred, clientDeferred]).addCallback(
                cbConnsLost
            )

        switchDeferred.addCallback(cbSwitch)
        p.flush()
        # NOTE(review): maybeLater is presumably set by switchers whose switch
        # response is deferred; complete that pending switch here if present.
        if serverProto.maybeLater is not None:
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
            p.flush()
        if spuriousTraffic:
            # switch is done here; do this here to make sure that if we're
            # going to corrupt the connection, we do it before it's closed.
            if spuriousError:
                s.waiting.errback(
                    amp.RemoteAmpError(
                        b"SPURIOUS", "Here's some traffic in the form of an error."
                    )
                )
            else:
                s.waiting.callback({})
            p.flush()
        c.transport.loseConnection()  # close it
        p.flush()
        self.assertTrue(self.testSucceeded)
1748
1749    def test_protocolSwitchDeferred(self):
1750        """
1751        Verify that protocol-switching even works if the value returned from
1752        the command that does the switch is deferred.
1753        """
1754        return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
1755
1756    def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
1757        """
1758        Verify that if we try to switch protocols and it fails, the connection
1759        stays up and we can go back to speaking AMP.
1760        """
1761        self.testSucceeded = False
1762
1763        serverDeferred = defer.Deferred()
1764        serverProto = switcher(serverDeferred)
1765        clientDeferred = defer.Deferred()
1766        clientProto = switcher(clientDeferred)
1767        c, s, p = connectedServerAndClient(
1768            ServerClass=lambda: serverProto, ClientClass=lambda: clientProto
1769        )
1770        L = []
1771        c.switchToTestProtocol(fail=True).addErrback(L.append)
1772        p.flush()
1773        L.pop().trap(UnknownProtocol)
1774        self.assertFalse(self.testSucceeded)
1775        # It's a known error, so let's send a "hello" on the same connection;
1776        # it should work.
1777        c.sendHello(b"world").addCallback(L.append)
1778        p.flush()
1779        self.assertEqual(L.pop()["hello"], b"world")
1780
1781    def test_trafficAfterSwitch(self):
1782        """
1783        Verify that attempts to send traffic after a switch will not corrupt
1784        the nested protocol.
1785        """
1786        return self.test_protocolSwitch(spuriousTraffic=True)
1787
1788    def test_errorAfterSwitch(self):
1789        """
1790        Returning an error after a protocol switch should record the underlying
1791        error.
1792        """
1793        return self.test_protocolSwitch(spuriousTraffic=True, spuriousError=True)
1794
1795    def test_quitBoxQuits(self):
1796        """
1797        Verify that commands with a responseType of QuitBox will in fact
1798        terminate the connection.
1799        """
1800        c, s, p = connectedServerAndClient(
1801            ServerClass=SimpleSymmetricCommandProtocol,
1802            ClientClass=SimpleSymmetricCommandProtocol,
1803        )
1804
1805        L = []
1806        HELLO = b"world"
1807        GOODBYE = b"everyone"
1808        c.sendHello(HELLO).addCallback(L.append)
1809        p.flush()
1810        self.assertEqual(L.pop()["hello"], HELLO)
1811        c.callRemote(Goodbye).addCallback(L.append)
1812        p.flush()
1813        self.assertEqual(L.pop()["goodbye"], GOODBYE)
1814        c.sendHello(HELLO).addErrback(L.append)
1815        L.pop().trap(error.ConnectionDone)
1816
1817    def test_basicLiteralEmit(self):
1818        """
1819        Verify that the command dictionaries for a callRemoteN look correct
1820        after being serialized and parsed.
1821        """
1822        c, s, p = connectedServerAndClient()
1823        L = []
1824        s.ampBoxReceived = L.append
1825        c.callRemote(
1826            Hello,
1827            hello=b"hello test",
1828            mixedCase=b"mixed case arg test",
1829            dash_arg=b"x",
1830            underscore_arg=b"y",
1831        )
1832        p.flush()
1833        self.assertEqual(len(L), 1)
1834        for k, v in [
1835            (b"_command", Hello.commandName),
1836            (b"hello", b"hello test"),
1837            (b"mixedCase", b"mixed case arg test"),
1838            (b"dash-arg", b"x"),
1839            (b"underscore_arg", b"y"),
1840        ]:
1841            self.assertEqual(L[-1].pop(k), v)
1842        L[-1].pop(b"_ask")
1843        self.assertEqual(L[-1], {})
1844
1845    def test_basicStructuredEmit(self):
1846        """
1847        Verify that a call similar to basicLiteralEmit's is handled properly with
1848        high-level quoting and passing to Python methods, and that argument
1849        names are correctly handled.
1850        """
1851        L = []
1852
1853        class StructuredHello(amp.AMP):
1854            def h(self, *a, **k):
1855                L.append((a, k))
1856                return dict(hello=b"aaa")
1857
1858            Hello.responder(h)
1859
1860        c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
1861        c.callRemote(
1862            Hello,
1863            hello=b"hello test",
1864            mixedCase=b"mixed case arg test",
1865            dash_arg=b"x",
1866            underscore_arg=b"y",
1867        ).addCallback(L.append)
1868        p.flush()
1869        self.assertEqual(len(L), 2)
1870        self.assertEqual(
1871            L[0],
1872            (
1873                (),
1874                dict(
1875                    hello=b"hello test",
1876                    mixedCase=b"mixed case arg test",
1877                    dash_arg=b"x",
1878                    underscore_arg=b"y",
1879                    From=s.transport.getPeer(),
1880                    # XXX - should optional arguments just not be passed?
1881                    # passing None seems a little odd, looking at the way it
1882                    # turns out here... -glyph
1883                    Print=None,
1884                    optional=None,
1885                ),
1886            ),
1887        )
1888        self.assertEqual(L[1], dict(Print=None, hello=b"aaa"))
1889
1890
class PretendRemoteCertificateAuthority:
    """
    Stand-in for a certificate authority; L{OKCert.options} asserts that the
    object it is given answers C{checkIsPretendRemote} truthfully.
    """

    def checkIsPretendRemote(self):
        """
        @return: always C{True}.
        """
        isPretend = True
        return isPretend
1894
1895
class IOSimCert:
    """
    A fake certificate for use with L{iosim}.  No cryptography happens; it
    only counts how many times it is asked to verify a peer.

    @ivar verifyCount: number of calls made to L{iosimVerify} on this
        instance.
    """

    verifyCount = 0

    def options(self, *ign):
        """
        Ignore every argument.

        @return: this object itself.
        """
        return self

    def iosimVerify(self, otherCert):
        """
        This isn't a real certificate and wouldn't work on a real socket.
        iosim calls this API instead, so the tests can show that the right
        functions get called in the right places without any crypto math.

        @param otherCert: the peer's certificate; must be this very object.
        @return: C{True}
        """
        assert otherCert is self
        self.verifyCount = self.verifyCount + 1
        return True
1912
1913
class OKCert(IOSimCert):
    """
    A fake certificate whose C{options} insists on being handed an object
    that reports itself as a pretend remote authority.
    """

    def options(self, x):
        """
        @param x: an object whose C{checkIsPretendRemote} must return a true
            value.
        @return: this object itself.
        """
        isPretend = x.checkIsPretendRemote()
        assert isPretend
        return self
1918
1919
class GrumpyCert(IOSimCert):
    """
    A fake certificate that rejects every peer it is asked to verify.
    """

    def iosimVerify(self, otherCert):
        """
        Count the call and report that verification failed.

        @return: C{False}
        """
        self.verifyCount = self.verifyCount + 1
        return False
1924
1925
class DroppyCert(IOSimCert):
    """
    A fake certificate that drops a connection whenever it is asked to
    verify a peer, while still reporting success.
    """

    def __init__(self, toDrop):
        """
        @param toDrop: an object with a C{loseConnection} method; the tests
            in this module pass a server transport.
        """
        self.toDrop = toDrop

    def iosimVerify(self, otherCert):
        """
        Count the call, close C{toDrop}, and claim verification succeeded.

        @return: C{True}
        """
        self.verifyCount = self.verifyCount + 1
        self.toDrop.loseConnection()
        return True
1934
1935
class SecurableProto(FactoryNotifier):
    """
    A protocol that answers L{amp.StartTLS} with parameters obtained from
    per-instance factories.

    C{certFactory} is not defined on this class; the tests in this module
    assign it on each instance before starting TLS.  C{verifyFactory} may
    also be replaced per instance.
    """

    factory = None

    def verifyFactory(self):
        """
        @return: a one-element list holding a fresh
            L{PretendRemoteCertificateAuthority}.
        """
        return [PretendRemoteCertificateAuthority()]

    def getTLSVars(self):
        """
        Build the keyword arguments for the StartTLS response, calling
        C{certFactory} before C{verifyFactory}.
        """
        return dict(
            tls_localCertificate=self.certFactory(),
            tls_verifyAuthorities=self.verifyFactory(),
        )

    amp.StartTLS.responder(getTLSVars)
1949
1950
@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class TLSTests(TestCase):
    """
    Tests for L{amp.StartTLS} over an in-memory L{iosim} connection.  They
    use fake certificates (L{OKCert}, L{GrumpyCert}, L{DroppyCert}) that
    only count verification calls.
    """

    def test_startingTLS(self):
        """
        Verify that starting TLS and succeeding at handshaking sends all the
        notifications to all the right places.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        cli.callRemote(
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )

        # let's buffer something to be delivered securely
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        # once for client once for server
        self.assertEqual(okc.verifyCount, 2)
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0], {"pinged": True})

    def test_startTooManyTimes(self):
        """
        Verify that the protocol will complain if we attempt to renegotiate TLS,
        which we don't support.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        cli.callRemote(
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )
        p.flush()
        cli.noPeerCertificate = True  # this is totally fake
        self.assertRaises(
            amp.OnlyOneTLS,
            cli.callRemote,
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
        )

    def test_negotiationFailed(self):
        """
        Verify that starting TLS and failing on both sides at handshaking sends
        notifications to all the right places and terminates the connection.
        """
        badCert = GrumpyCert()

        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )
        svr.certFactory = lambda: badCert

        cli.callRemote(amp.StartTLS, tls_localCertificate=badCert)

        p.flush()
        # once for client once for server - but both fail
        self.assertEqual(badCert.verifyCount, 2)
        d = cli.callRemote(SecuredPing)
        p.flush()
        # Return the Deferred so trial itself checks the failure type instead
        # of relying on an unhandled-error report at garbage collection.
        return self.assertFailure(d, iosim.NativeOpenSSLError)

    def test_negotiationFailedByClosing(self):
        """
        Verify that starting TLS and failing by way of a lost connection
        notices that it is probably an SSL problem.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )
        droppyCert = DroppyCert(svr.transport)
        svr.certFactory = lambda: droppyCert

        cli.callRemote(amp.StartTLS, tls_localCertificate=droppyCert)

        p.flush()

        self.assertEqual(droppyCert.verifyCount, 2)

        d = cli.callRemote(SecuredPing)
        p.flush()

        # it might be a good idea to move this exception somewhere more
        # reasonable.
        # Returned so that trial verifies the failure explicitly.
        return self.assertFailure(d, error.PeerVerifyError)
2056
2057
class TLSNotAvailableTests(TestCase):
    """
    Tests what happens when SSL is not available in the current
    installation.
    """

    def setUp(self):
        """
        Disable SSL in L{amp} for the duration of each test.  C{self.patch}
        restores the original C{amp.ssl} during test cleanup, so no explicit
        tearDown is needed.
        """
        self.patch(amp, "ssl", None)

    def test_callRemoteError(self):
        """
        When SSL is unavailable, C{callRemote} with L{amp.StartTLS} returns a
        L{Deferred} that fails with L{RuntimeError}.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto, ClientClass=SecurableProto
        )

        okc = OKCert()
        svr.certFactory = lambda: okc

        return self.assertFailure(
            cli.callRemote(
                amp.StartTLS,
                tls_localCertificate=okc,
                tls_verifyAuthorities=[PretendRemoteCertificateAuthority()],
            ),
            RuntimeError,
        )

    def test_messageReceivedError(self):
        """
        When a client with SSL enabled talks to a server without SSL, it
        should return a meaningful error.
        """
        svr = SecurableProto()
        okc = OKCert()
        svr.certFactory = lambda: okc
        box = amp.Box()
        box[b"_command"] = b"StartTLS"
        box[b"_ask"] = b"1"
        boxes = []
        svr.sendBox = boxes.append
        svr.makeConnection(StringTransport())
        svr.ampBoxReceived(box)
        self.assertEqual(
            boxes,
            [
                {
                    b"_error_code": b"TLS_ERROR",
                    b"_error": b"1",
                    b"_error_description": b"TLS not available",
                }
            ],
        )
2122
2123
class InheritedError(Exception):
    """
    Raised by the command-inheritance test protocols; declared on
    L{BaseCommand} so its subclasses can be checked for inheriting it.
    """
2128
2129
class OtherInheritedError(Exception):
    """
    A second, unrelated error type, used to check that a command subclass
    can declare errors of its own.
    """
2134
2135
class BaseCommand(amp.Command):
    """
    A command intended to be subclassed.  It declares one error,
    L{InheritedError}, sent on the wire as C{INHERITED_ERROR}.
    """

    errors: Dict[Type[Exception], bytes] = {InheritedError: b"INHERITED_ERROR"}
2142
2143
class InheritedCommand(BaseCommand):
    """
    A subclass of L{BaseCommand} that overrides nothing at all, so
    everything it has comes from its base.
    """
2149
2150
class AddErrorsCommand(BaseCommand):
    """
    A subclass of L{BaseCommand} that takes a boolean C{other} argument and
    lists one extra error type.  Only the new error appears here;
    L{CommandInheritanceTests} checks that the base command's errors still
    apply.
    """

    arguments = [(b"other", amp.Boolean())]
    errors: Dict[Type[Exception], bytes] = {
        OtherInheritedError: b"OTHER_INHERITED_ERROR"
    }
2161
2162
class NormalCommandProtocol(amp.AMP):
    """
    Responds to L{BaseCommand} by failing.  It is the baseline showing that
    having subclasses does not disturb a command's own error handling.
    """

    def resp(self):
        """
        Always fail with L{InheritedError}.
        """
        raise InheritedError()

    BaseCommand.responder(resp)
2173
2174
class InheritedCommandProtocol(amp.AMP):
    """
    Responds to L{InheritedCommand} by failing.  It is used to show that a
    command subclass that declares no errors of its own reuses those of its
    base.
    """

    def resp(self):
        """
        Always fail with L{InheritedError}.
        """
        raise InheritedError()

    InheritedCommand.responder(resp)
2186
2187
class AddedCommandProtocol(amp.AMP):
    """
    Responds to L{AddErrorsCommand}.  It is used to show that a subclassed
    command handles both the errors it adds and the errors of its parent.
    """

    def resp(self, other):
        """
        Fail with L{OtherInheritedError} when C{other} is true, and with
        L{InheritedError} otherwise.
        """
        exceptionType = OtherInheritedError if other else InheritedError
        raise exceptionType()

    AddErrorsCommand.responder(resp)
2202
2203
class CommandInheritanceTests(TestCase):
    """
    These tests verify that commands inherit error conditions properly.
    """

    def errorCheck(self, err, proto, cmd, **kw):
        """
        Check that the appropriate kind of error is raised when a given
        command is sent to a given protocol.

        @param err: the exception type the remote call is expected to fail
            with.
        @param proto: the protocol class used for both server and client.
        @param cmd: the L{amp.Command} subclass to send.
        @param kw: keyword arguments for C{cmd}.

        @return: a L{Deferred} that succeeds only if the call failed with
            C{err}.
        """
        c, s, p = connectedServerAndClient(ServerClass=proto, ClientClass=proto)
        d = c.callRemote(cmd, **kw)
        # assertFailure is the current spelling of the legacy
        # failUnlessFailure alias.
        d2 = self.assertFailure(d, err)
        p.flush()
        return d2

    def test_basicErrorPropagation(self):
        """
        Verify that errors specified in a superclass are respected normally
        even if it has subclasses.
        """
        return self.errorCheck(InheritedError, NormalCommandProtocol, BaseCommand)

    def test_inheritedErrorPropagation(self):
        """
        Verify that errors specified in a superclass command are propagated to
        its subclasses.
        """
        return self.errorCheck(
            InheritedError, InheritedCommandProtocol, InheritedCommand
        )

    def test_inheritedErrorAddition(self):
        """
        Verify that new errors specified in a subclass of an existing command
        are honored even if the superclass defines some errors.
        """
        return self.errorCheck(
            OtherInheritedError, AddedCommandProtocol, AddErrorsCommand, other=True
        )

    def test_additionWithOriginalError(self):
        """
        Verify that errors specified in a command's superclass are respected
        even if that command defines new errors itself.
        """
        return self.errorCheck(
            InheritedError, AddedCommandProtocol, AddErrorsCommand, other=False
        )
2253
2254
def _loseAndPass(err, proto):
    """
    Errback used by L{LiveFireBase.tearDown}.  It accepts only a clean
    connection loss and forwards that failure to C{proto}'s original
    C{connectionLost}.

    @param err: the L{Failure} delivered to the temporarily patched
        C{connectionLost}.  Anything other than L{error.ConnectionLost} or
        L{error.ConnectionDone} is re-raised by C{trap}.
    @param proto: the protocol whose C{connectionLost} was overridden with
        an instance attribute.
    """
    # be specific, pass on the error to the client.
    err.trap(error.ConnectionLost, error.ConnectionDone)
    # Remove the instance-level override first, so that the lookup below
    # finds the class's own connectionLost again.
    del proto.connectionLost
    proto.connectionLost(err)
2260
2261
class LiveFireBase:
    """
    Mixin for tests that need a real reactor connection.  It starts a TCP
    server built from C{serverProto}, connects a client built from
    C{clientProto}, and exposes the two protocol instances as C{self.svr}
    and C{self.cli}.
    """

    def setUp(self):
        """
        Listen on an ephemeral port, connect a client to it, and wait until
        both factories report their protocol through C{onMade}.
        """
        from twisted.internet import reactor

        serverFactory = protocol.ServerFactory()
        serverFactory.protocol = self.serverProto
        serverFactory.onMade = defer.Deferred()
        clientFactory = protocol.ClientFactory()
        clientFactory.protocol = self.clientProto
        clientFactory.onMade = defer.Deferred()
        self.serverFactory = serverFactory
        self.clientFactory = clientFactory

        self.serverPort = reactor.listenTCP(0, serverFactory)
        self.addCleanup(self.serverPort.stopListening)
        listeningPort = self.serverPort.getHost().port
        self.clientConn = reactor.connectTCP("127.0.0.1", listeningPort, clientFactory)
        self.addCleanup(self.clientConn.disconnect)

        def rememberProtocols(ignored):
            self.cli = clientFactory.theProto
            self.svr = serverFactory.theProto

        bothMade = defer.DeferredList([clientFactory.onMade, serverFactory.onMade])
        return bothMade.addCallback(rememberProtocols)

    def tearDown(self):
        """
        Close each still-connected side and wait for its C{connectionLost}.
        If a side loses its connection for an unexpected reason, that
        failure is reported.
        """
        pending = []
        for conn in (self.cli, self.svr):
            if conn.transport is None:
                continue
            # Temporarily route connectionLost into a Deferred; _loseAndPass
            # removes the override and forwards to the real method.  This
            # relies on amp's connection-dropping behaviour.
            lost = defer.Deferred().addErrback(_loseAndPass, conn)
            conn.connectionLost = lost.errback
            conn.transport.loseConnection()
            pending.append(lost)

        def unwrapFirstError(firstError):
            # gatherResults wraps the first failure; report the original.
            return firstError.value.subFailure

        return defer.gatherResults(pending).addErrback(unwrapFirstError)
2307
2308
def show(x):
    """
    Debugging aid: write C{x} and a newline to standard output, then flush
    it straight away.

    @param x: the text to write; it must be a L{str}.
    """
    import sys

    out = sys.stdout
    out.write(x + "\n")
    out.flush()
2314
2315
def tempSelfSigned():
    """
    Generate a throwaway certificate signed by its own key, with subject and
    issuer both C{CN=shared}.

    @return: the certificate produced by C{KeyPair.newCertificate}.
    """
    from twisted.internet import ssl

    distinguishedName = ssl.DN(CN="shared")
    keyPair = ssl.KeyPair.generate()
    request = keyPair.certificateRequest(distinguishedName)
    # The lambda is presumably the DN-verification callback; it accepts
    # anything.
    signedRequest = keyPair.signCertificateRequest(
        distinguishedName, request, lambda dn: True, 1234567
    )
    return keyPair.newCertificate(signedRequest)
2325
2326
# Build the certificate used by the live-fire TLS tests once, at import time.
# When SSL is unavailable, tempcert is left undefined.  The test cases that
# use it are decorated with skipIf(skipSSL, ...), so they never read it.
if ssl is not None:
    tempcert = tempSelfSigned()
2329
2330
@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class LiveFireTLSTests(LiveFireBase, TestCase):
    """
    Negotiate real TLS over a loopback TCP connection between two
    L{SecurableProto} instances.
    """

    clientProto = SecurableProto
    serverProto = SecurableProto

    def test_liveFireCustomTLS(self):
        """
        Using real, live TLS, actually negotiate a connection.

        This also checks that the 'peerCertificate' attribute is correct,
        since it is loaded through OpenSSL calls.  The main purpose, though,
        is to make sure iosim did not miss anything obvious about TLS
        negotiation.
        """
        cert = tempcert

        # Only the server needs these factories; the client passes its
        # parameters to StartTLS directly below.
        self.svr.verifyFactory = lambda: [cert]
        self.svr.certFactory = lambda: cert

        def checkCertificates(ignored, expectedDigest):
            # Interesting.  OpenSSL won't even _tell_ us about the peer cert
            # until we negotiate.  We should be able to do this in 'secured'
            # instead, but it looks like we can't.  I think this is a bug
            # somewhere far deeper than here.
            for proto in (self.cli, self.svr):
                self.assertEqual(expectedDigest, proto.hostCertificate.digest())
                self.assertEqual(expectedDigest, proto.peerCertificate.digest())

        def secured(ignored):
            digest = cert.digest()
            ping = self.cli.callRemote(SecuredPing)
            return ping.addCallback(checkCertificates, digest)

        starting = self.cli.callRemote(
            amp.StartTLS, tls_localCertificate=cert, tls_verifyAuthorities=[cert]
        )
        return starting.addCallback(secured)
2372
2373
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
    """
    A server-side protocol that answers L{amp.StartTLS} by offering the
    module's shared C{tempcert}, without naming any verification
    authorities.
    """

    def getTLSVars(self):
        """
        @return: StartTLS response arguments whose only entry is C{tempcert}
            as the local certificate.
        """
        return {"tls_localCertificate": tempcert}

    amp.StartTLS.responder(getTLSVars)
2387
2388
@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class PlainVanillaLiveFireTests(LiveFireBase, TestCase):
    """
    Live TLS between two unmodified L{SimpleSymmetricCommandProtocol}
    instances.
    """

    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SimpleSymmetricCommandProtocol

    def test_liveFireDefaultTLS(self):
        """
        Out of the box, TLS can be started and used to encrypt the
        connection, even when neither side supplies any certificates.
        """
        starting = self.cli.callRemote(amp.StartTLS)
        starting.addCallback(lambda ignored: self.cli.callRemote(SecuredPing))
        return starting
2406
2407
@skipIf(skipSSL, "SSL not available")
@skipIf(reactorLacksSSL, "This test case requires SSL support in the reactor")
class WithServerTLSVerificationTests(LiveFireBase, TestCase):
    """
    Live TLS where only the server (L{SlightlySmartTLS}) presents a
    certificate.
    """

    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SlightlySmartTLS

    def test_anonymousVerifyingClient(self):
        """
        A client with no certificate of its own can still verify the
        server's certificate.
        """
        starting = self.cli.callRemote(amp.StartTLS, tls_verifyAuthorities=[tempcert])
        starting.addCallback(lambda ignored: self.cli.callRemote(SecuredPing))
        return starting
2426
2427
class ProtocolIncludingArgument(amp.Argument):
    """
    An L{amp.Argument} whose parsed and serialized forms both reveal the
    arguments it was called with, I{including the protocol}.
    """

    def fromStringProto(self, string, protocol):
        """
        Perform no decoding; hand back everything that was passed in.

        @return: the two-tuple C{(string, protocol)}.
        """
        return (string, protocol)

    def toStringProto(self, obj, protocol):
        """
        Serialize the identities of C{obj} and C{protocol} so that a test
        can check which objects were passed in.

        @type obj: L{object}
        @type protocol: L{amp.AMP}
        @return: ASCII bytes of the form C{b"<id(obj)>:<id(protocol)>"}.
        """
        return f"{id(obj)}:{id(protocol)}".encode("ascii")
2453
2454
class ProtocolIncludingCommand(amp.Command):
    """
    A command whose single argument and single response value, both named
    C{weird}, use L{ProtocolIncludingArgument}.
    """

    arguments = [(b"weird", ProtocolIncludingArgument())]
    response = [(b"weird", ProtocolIncludingArgument())]
2463
2464
class MagicSchemaCommand(amp.Command):
    """
    A command which overrides L{parseResponse}, L{parseArguments}, and
    L{makeArguments}.  Each override records what it received on the
    protocol so that tests can inspect it.
    """

    @classmethod
    def parseResponse(cls, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseResponseArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseResponseArguments = (strings, protocol)
        return strings

    @classmethod
    def parseArguments(cls, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseArgumentsArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseArgumentsArguments = (strings, protocol)
        return strings

    @classmethod
    def makeArguments(cls, objects, protocol):
        """
        Don't do any serializing, just jam the input objects and protocol
        onto the C{protocol.makeArgumentsArguments} attribute as a
        two-tuple. Return the original objects.
        """
        protocol.makeArgumentsArguments = (objects, protocol)
        return objects
2500
2501
class NoNetworkProtocol(amp.AMP):
    """
    An L{amp.AMP} subclass that overrides a private method so it never
    touches the network.  It also registers a do-nothing responder for
    L{MagicSchemaCommand}, so tests can examine how L{amp.Command}s and
    L{amp.AMP} interact.

    @ivar parseArgumentsArguments: what L{MagicSchemaCommand.parseArguments}
        recorded, if this protocol has handled that command.

    @ivar parseResponseArguments: what L{MagicSchemaCommand.parseResponse}
        recorded, if this protocol has handled that command.

    @ivar makeArgumentsArguments: what L{MagicSchemaCommand.makeArguments}
        recorded, if this protocol has handled that command.
    """

    def _sendBoxCommand(self, commandName, strings, requiresAnswer):
        """
        Skip the network entirely.

        @return: an already-fired L{Deferred} whose result is C{strings}.
        """
        answered = defer.succeed(strings)
        return answered

    MagicSchemaCommand.responder(lambda self, weird: {})
2529
2530
class MyBox(dict):
    """
    A distinct subclass of L{dict}, so that a test can tell by type whether
    a command's C{commandType} was used.
    """
2535
2536
class ProtocolIncludingCommandWithDifferentCommandType(ProtocolIncludingCommand):
    """
    Identical to L{ProtocolIncludingCommand}, except that its C{commandType}
    is L{MyBox} instead of the default.
    """

    commandType = MyBox  # type: ignore[assignment]
2543
2544
2545class CommandTests(TestCase):
2546    """
2547    Tests for L{amp.Argument} and L{amp.Command}.
2548    """
2549
2550    def test_argumentInterface(self):
2551        """
2552        L{Argument} instances provide L{amp.IArgumentType}.
2553        """
2554        self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument()))
2555
2556    def test_parseResponse(self):
2557        """
2558        There should be a class method of Command which accepts a
2559        mapping of argument names to serialized forms and returns a
2560        similar mapping whose values have been parsed via the
2561        Command's response schema.
2562        """
2563        protocol = object()
2564        result = b"whatever"
2565        strings = {b"weird": result}
2566        self.assertEqual(
2567            ProtocolIncludingCommand.parseResponse(strings, protocol),
2568            {"weird": (result, protocol)},
2569        )
2570
2571    def test_callRemoteCallsParseResponse(self):
2572        """
2573        Making a remote call on a L{amp.Command} subclass which
2574        overrides the C{parseResponse} method should call that
2575        C{parseResponse} method to get the response.
2576        """
2577        client = NoNetworkProtocol()
2578        thingy = b"weeoo"
2579        response = client.callRemote(MagicSchemaCommand, weird=thingy)
2580
2581        def gotResponse(ign):
2582            self.assertEqual(client.parseResponseArguments, ({"weird": thingy}, client))
2583
2584        response.addCallback(gotResponse)
2585        return response
2586
2587    def test_parseArguments(self):
2588        """
2589        There should be a class method of L{amp.Command} which accepts
2590        a mapping of argument names to serialized forms and returns a
2591        similar mapping whose values have been parsed via the
2592        command's argument schema.
2593        """
2594        protocol = object()
2595        result = b"whatever"
2596        strings = {b"weird": result}
2597        self.assertEqual(
2598            ProtocolIncludingCommand.parseArguments(strings, protocol),
2599            {"weird": (result, protocol)},
2600        )
2601
2602    def test_responderCallsParseArguments(self):
2603        """
2604        Making a remote call on a L{amp.Command} subclass which
2605        overrides the C{parseArguments} method should call that
2606        C{parseArguments} method to get the arguments.
2607        """
2608        protocol = NoNetworkProtocol()
2609        responder = protocol.locateResponder(MagicSchemaCommand.commandName)
2610        argument = object()
2611        response = responder(dict(weird=argument))
2612        response.addCallback(
2613            lambda ign: self.assertEqual(
2614                protocol.parseArgumentsArguments, ({"weird": argument}, protocol)
2615            )
2616        )
2617        return response
2618
2619    def test_makeArguments(self):
2620        """
2621        There should be a class method of L{amp.Command} which accepts
2622        a mapping of argument names to objects and returns a similar
2623        mapping whose values have been serialized via the command's
2624        argument schema.
2625        """
2626        protocol = object()
2627        argument = object()
2628        objects = {"weird": argument}
2629        ident = "%d:%d" % (id(argument), id(protocol))
2630        self.assertEqual(
2631            ProtocolIncludingCommand.makeArguments(objects, protocol),
2632            {b"weird": ident.encode("ascii")},
2633        )
2634
2635    def test_makeArgumentsUsesCommandType(self):
2636        """
2637        L{amp.Command.makeArguments}'s return type should be the type
2638        of the result of L{amp.Command.commandType}.
2639        """
2640        protocol = object()
2641        objects = {"weird": b"whatever"}
2642
2643        result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
2644            objects, protocol
2645        )
2646        self.assertIs(type(result), MyBox)
2647
2648    def test_callRemoteCallsMakeArguments(self):
2649        """
2650        Making a remote call on a L{amp.Command} subclass which
2651        overrides the C{makeArguments} method should call that
2652        C{makeArguments} method to get the response.
2653        """
2654        client = NoNetworkProtocol()
2655        argument = object()
2656        response = client.callRemote(MagicSchemaCommand, weird=argument)
2657
2658        def gotResponse(ign):
2659            self.assertEqual(
2660                client.makeArgumentsArguments, ({"weird": argument}, client)
2661            )
2662
2663        response.addCallback(gotResponse)
2664        return response
2665
2666    def test_extraArgumentsDisallowed(self):
2667        """
2668        L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
2669        dictionary passed to it includes a key which does not correspond to the
2670        Python identifier for a defined argument.
2671        """
2672        self.assertRaises(
2673            amp.InvalidSignature,
2674            Hello.makeArguments,
2675            dict(hello="hello", bogusArgument=object()),
2676            None,
2677        )
2678
2679    def test_wireSpellingDisallowed(self):
2680        """
2681        If a command argument conflicts with a Python keyword, the
2682        untransformed argument name is not allowed as a key in the dictionary
2683        passed to L{Command.makeArguments}.  If it is supplied,
2684        L{amp.InvalidSignature} is raised.
2685
2686        This may be a pointless implementation restriction which may be lifted.
2687        The current behavior is tested to verify that such arguments are not
2688        silently dropped on the floor (the previous behavior).
2689        """
2690        self.assertRaises(
2691            amp.InvalidSignature,
2692            Hello.makeArguments,
2693            dict(hello="required", **{"print": "print value"}),
2694            None,
2695        )
2696
2697    def test_commandNameDefaultsToClassNameAsByteString(self):
2698        """
2699        A L{Command} subclass without a defined C{commandName} that's
2700        not a byte string.
2701        """
2702
2703        class NewCommand(amp.Command):
2704            """
2705            A new command.
2706            """
2707
2708        self.assertEqual(b"NewCommand", NewCommand.commandName)
2709
2710    def test_commandNameMustBeAByteString(self):
2711        """
2712        A L{Command} subclass cannot be defined with a C{commandName} that's
2713        not a byte string.
2714        """
2715        error = self.assertRaises(
2716            TypeError, type, "NewCommand", (amp.Command,), {"commandName": "FOO"}
2717        )
2718        self.assertRegex(
2719            str(error), "^Command names must be byte strings, got: u?'FOO'$"
2720        )
2721
2722    def test_commandArgumentsMustBeNamedWithByteStrings(self):
2723        """
2724        A L{Command} subclass's C{arguments} must have byte string names.
2725        """
2726        error = self.assertRaises(
2727            TypeError,
2728            type,
2729            "NewCommand",
2730            (amp.Command,),
2731            {"arguments": [("foo", None)]},
2732        )
2733        self.assertRegex(
2734            str(error), "^Argument names must be byte strings, got: u?'foo'$"
2735        )
2736
2737    def test_commandResponseMustBeNamedWithByteStrings(self):
2738        """
2739        A L{Command} subclass's C{response} must have byte string names.
2740        """
2741        error = self.assertRaises(
2742            TypeError, type, "NewCommand", (amp.Command,), {"response": [("foo", None)]}
2743        )
2744        self.assertRegex(
2745            str(error), "^Response names must be byte strings, got: u?'foo'$"
2746        )
2747
2748    def test_commandErrorsIsConvertedToDict(self):
2749        """
2750        A L{Command} subclass's C{errors} is coerced into a C{dict}.
2751        """
2752
2753        class NewCommand(amp.Command):
2754            errors = [(ZeroDivisionError, b"ZDE")]
2755
2756        self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.errors)
2757
2758    def test_commandErrorsMustUseBytesForOnWireRepresentation(self):
2759        """
2760        A L{Command} subclass's C{errors} must map exceptions to byte strings.
2761        """
2762        error = self.assertRaises(
2763            TypeError,
2764            type,
2765            "NewCommand",
2766            (amp.Command,),
2767            {"errors": [(ZeroDivisionError, "foo")]},
2768        )
2769        self.assertRegex(str(error), "^Error names must be byte strings, got: u?'foo'$")
2770
2771    def test_commandFatalErrorsIsConvertedToDict(self):
2772        """
2773        A L{Command} subclass's C{fatalErrors} is coerced into a C{dict}.
2774        """
2775
2776        class NewCommand(amp.Command):
2777            fatalErrors = [(ZeroDivisionError, b"ZDE")]
2778
2779        self.assertEqual({ZeroDivisionError: b"ZDE"}, NewCommand.fatalErrors)
2780
2781    def test_commandFatalErrorsMustUseBytesForOnWireRepresentation(self):
2782        """
2783        A L{Command} subclass's C{fatalErrors} must map exceptions to byte
2784        strings.
2785        """
2786        error = self.assertRaises(
2787            TypeError,
2788            type,
2789            "NewCommand",
2790            (amp.Command,),
2791            {"fatalErrors": [(ZeroDivisionError, "foo")]},
2792        )
2793        self.assertRegex(
2794            str(error), "^Fatal error names must be byte strings, " "got: u?'foo'$"
2795        )
2796
2797
class ListOfTestsMixin:
    """
    Shared tests for L{ListOf}, the argument type that carries a list of zero
    or more values of a single element type.

    Concrete subclasses mix this into a L{TestCase} and supply three class
    attributes:

    @ivar elementType: An L{Argument} instance; each test wraps it in a
        L{ListOf}.

    @ivar strings: A mapping from keys given as BYTE strings to the
        serialized form that L{elementType} is expected to produce (and
        accept) for the matching entry of C{objects}.

    @ivar objects: A mapping with the same keys as C{strings}, but given as
        NATIVE strings, whose values are the lists that serialize to the
        corresponding values of C{strings}.
    """

    def test_toBox(self):
        """
        For every key of C{objects}, L{ListOf.toBox} takes the list stored
        under that key, serializes each element with the wrapped L{Argument},
        and stores the combined bytes in the box under the byte-string form of
        the same key; the finished box equals C{strings}.
        """
        listArgument = amp.ListOf(self.elementType)
        box = amp.AmpBox()
        for pythonName in self.objects:
            wireName = pythonName.encode("ascii")
            # toBox may consume entries, so hand it a fresh copy every time.
            listArgument.toBox(wireName, box, self.objects.copy(), None)
        self.assertEqual(box, self.strings)

    def test_fromBox(self):
        """
        L{ListOf.fromBox} undoes L{ListOf.toBox}: decoding every entry of
        C{strings} reproduces C{objects}.
        """
        listArgument = amp.ListOf(self.elementType)
        decoded = {}
        for wireName in self.strings:
            listArgument.fromBox(wireName, self.strings.copy(), decoded, None)
        self.assertEqual(decoded, self.objects)
2841
2842
class ListOfStringsTests(TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} wrapped around L{amp.String} elements.
    """

    elementType = amp.String()

    # Each element appears on the wire after a two-byte length prefix.
    strings = {
        b"empty": b"",
        b"single": b"\x00\x03" b"foo",
        b"multiple": b"\x00\x03" b"bar" b"\x00\x03" b"baz" b"\x00\x04" b"quux",
    }

    objects = {
        "empty": [],
        "single": [b"foo"],
        "multiple": [b"bar", b"baz", b"quux"],
    }
2857
2858
class ListOfIntegersTests(TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} wrapped around L{amp.Integer} elements, including a
    negative value and an integer far wider than any fixed-width type.
    """

    elementType = amp.Integer()

    # A run of nines, squared.  Its decimal rendering is 116 digits long,
    # which is where the 0x00 0x74 length prefix below comes from.
    huge = 9999999999999999999999999999999999999999999999999999999999**2

    strings = {
        b"empty": b"",
        b"single": b"\x00\x02" b"10",
        b"multiple": b"\x00\x01" b"1" b"\x00\x02" b"20" b"\x00\x03" b"500",
        b"huge": b"\x00\x74" + b"%d" % (huge,),
        b"negative": b"\x00\x02" b"-1",
    }

    objects = {
        "empty": [],
        "single": [10],
        "multiple": [1, 20, 500],
        "huge": [huge],
        "negative": [-1],
    }
2886
2887
class ListOfUnicodeTests(TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} wrapped around L{amp.Unicode} elements, which are
    carried on the wire as UTF-8.
    """

    elementType = amp.Unicode()

    objects = {
        "empty": [],
        "single": ["foo"],
        "multiple": ["\N{SNOWMAN}", "Hello", "world"],
    }

    # The snowman encodes to three UTF-8 bytes, hence its 0x00 0x03 prefix.
    strings = {
        b"empty": b"",
        b"single": b"\x00\x03" b"foo",
        b"multiple": b"\x00\x03" b"\xe2\x98\x83"
        b"\x00\x05" b"Hello"
        b"\x00\x05" b"world",
    }
2906
2907
class ListOfDecimalTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Decimal}.

    Each serialized element in C{strings} is preceded by a two-byte length;
    for example, the eight bytes of C{3.141E+5} are introduced by the length
    bytes 0x00 0x08.
    """

    elementType = amp.Decimal()

    # Expected wire forms.  The wire form need not match the spelling used to
    # build the corresponding value in C{objects}: for instance
    # decimal.Decimal("3.141e-5") is expected to appear as 0.00003141.
    strings = {
        b"empty": b"",
        b"single": b"\x00\x031.1",
        b"extreme": b"\x00\x08Infinity\x00\x09-Infinity",
        b"scientist": b"\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7"
        b"\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7",
        # Built from Decimal.to_eng_string() instead of being spelled out;
        # the hard-coded length prefixes (4 and 6) must agree with the
        # strings it returns.
        b"engineer": (
            b"\x00\x04"
            + decimal.Decimal("0e6").to_eng_string().encode("ascii")
            + b"\x00\x06"
            + decimal.Decimal("1.5E-9").to_eng_string().encode("ascii")
        ),
    }

    objects = {
        "empty": [],
        "single": [decimal.Decimal("1.1")],
        "extreme": [
            decimal.Decimal("Infinity"),
            decimal.Decimal("-Infinity"),
        ],
        # exarkun objected to AMP supporting engineering notation because
        # it was redundant, until we realised that 1E6 has less precision
        # than 1000000 and is represented differently.  But they compare
        # and even hash equally.  There were tears.
        "scientist": [
            decimal.Decimal("3.141E5"),
            decimal.Decimal("3.141e-5"),
            decimal.Decimal("3.141E-7"),
            decimal.Decimal("-3.141e5"),
            decimal.Decimal("-3.141E-5"),
            decimal.Decimal("-3.141e-7"),
        ],
        "engineer": [
            decimal.Decimal("0e6"),
            decimal.Decimal("1.5E-9"),
        ],
    }
2953
2954
class ListOfDecimalNanTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values.
    """

    elementType = amp.Decimal()

    strings = {
        b"nan": b"\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN",
    }

    objects = {
        "nan": [
            decimal.Decimal("NaN"),
            decimal.Decimal("-NaN"),
            decimal.Decimal("sNaN"),
            decimal.Decimal("-sNaN"),
        ]
    }

    def test_fromBox(self):
        """
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.

        NaN values do not compare equal to themselves, so instead of comparing
        the decoded list against C{objects}, this override checks each
        decoded value for the expected kind of NaN (quiet or signaling) and
        the expected sign, using the C{decimal.Decimal} predicates.
        """
        stringList = amp.ListOf(self.elementType)
        objects = {}
        for key in self.strings:
            stringList.fromBox(key, self.strings.copy(), objects, None)

        # Unpacking also checks that exactly four values were decoded.
        quiet, negativeQuiet, signaling, negativeSignaling = objects["nan"]

        self.assertTrue(quiet.is_qnan())
        self.assertFalse(quiet.is_signed())

        self.assertTrue(negativeQuiet.is_qnan())
        self.assertTrue(negativeQuiet.is_signed())

        self.assertTrue(signaling.is_snan())
        self.assertFalse(signaling.is_signed())

        self.assertTrue(negativeSignaling.is_snan())
        self.assertTrue(negativeSignaling.is_signed())
3002
3003
class DecimalTests(TestCase):
    """
    Tests for L{amp.Decimal}.
    """

    def test_nonDecimal(self):
        """
        L{amp.Decimal.toString} refuses, with L{ValueError}, any value that is
        not a C{decimal.Decimal} instance -- including strings, floats, and
        integers that merely look numeric.
        """
        toString = amp.Decimal().toString
        for notADecimal in ("1.234", 1.234, 1234):
            self.assertRaises(ValueError, toString, notADecimal)
3018
3019
class FloatTests(TestCase):
    """
    Tests for L{amp.Float}.
    """

    def test_nonFloat(self):
        """
        L{amp.Float.toString} refuses, with L{ValueError}, any value that is
        not a L{float}, whether text, bytes, or an integer.
        """
        toString = amp.Float().toString
        for notAFloat in ("1.234", b"1.234", 1234):
            self.assertRaises(ValueError, toString, notAFloat)

    def test_float(self):
        """
        Given a L{float}, L{amp.Float.toString} produces its byte-string
        rendering.
        """
        serialized = amp.Float().toString(1.234)
        self.assertEqual(serialized, b"1.234")
3041
3042
class ListOfDateTimeTests(TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.DateTime}.

    Every serialized timestamp in C{strings} is 32 bytes long, so each
    element is introduced by the length bytes 0x00 0x20.
    """

    elementType = amp.DateTime()

    # Both amp.utc and a "+" zero offset are expected to serialize with a
    # "-00:00" suffix (see the "christmas" entry).
    strings = {
        b"christmas": b"\x00\x202010-12-25T00:00:00.000000-00:00"
        b"\x00\x202010-12-25T00:00:00.000000-00:00",
        b"christmas in eu": b"\x00\x202010-12-25T00:00:00.000000+01:00",
        b"christmas in iran": b"\x00\x202010-12-25T00:00:00.000000+03:30",
        b"christmas in nyc": b"\x00\x202010-12-25T00:00:00.000000-05:00",
        b"previous tests": b"\x00\x202010-12-25T00:00:00.000000+03:19"
        b"\x00\x202010-12-25T00:00:00.000000-06:59",
    }

    objects = {
        "christmas": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=amp.utc),
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 0, 0)),
        ],
        "christmas in eu": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 1, 0)),
        ],
        # A half-hour offset.
        "christmas in iran": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 30)),
        ],
        "christmas in nyc": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 5, 0)),
        ],
        # Offsets whose minute part is not a multiple of 30, one per sign.
        "previous tests": [
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("+", 3, 19)),
            datetime.datetime(2010, 12, 25, 0, 0, 0, tzinfo=tz("-", 6, 59)),
        ],
    }
3079
3080
class ListOfOptionalTests(TestCase):
    """
    Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands
    via the 'optional' flag.

    L{ListOf.toBox} is given the wire name as a byte string, but it looks the
    value up in the objects dictionary by the corresponding native-string
    Python identifier, so the objects dictionaries below use C{str} keys.
    """

    def test_requiredArgumentWithNoneValueRaisesTypeError(self):
        """
        L{ListOf.toBox} raises C{TypeError} when passed a value of L{None}
        for the argument.
        """
        stringList = amp.ListOf(amp.Integer())
        self.assertRaises(
            TypeError,
            stringList.toBox,
            b"omitted",
            amp.AmpBox(),
            {"omitted": None},
            None,
        )

    def test_optionalArgumentWithNoneValueOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument with a
        value of L{None} that is designated as optional for the protocol.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        strings = amp.AmpBox()
        # Key by the native-string identifier (as the required-argument test
        # does) so that the None value is actually found; a bytes key would
        # only exercise the missing-key path.
        stringList.toBox(b"omitted", strings, {"omitted": None}, None)
        self.assertEqual(strings, {})

    def test_requiredArgumentWithKeyMissingRaisesKeyError(self):
        """
        L{ListOf.toBox} raises C{KeyError} if the argument's key is not
        present in the objects dictionary.
        """
        stringList = amp.ListOf(amp.Integer())
        self.assertRaises(
            KeyError,
            stringList.toBox,
            b"omitted",
            amp.AmpBox(),
            {"someOtherKey": 0},
            None,
        )

    def test_optionalArgumentWithKeyMissingOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument designated
        as optional whose key is not present in the objects dictionary.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        strings = amp.AmpBox()
        stringList.toBox(b"omitted", strings, {"someOtherKey": 0}, None)
        # Nothing may have been written for the absent argument.
        self.assertEqual(strings, {})

    def test_omittedOptionalArgumentDeserializesAsNone(self):
        """
        L{ListOf.fromBox} correctly reverses the operation performed by
        L{ListOf.toBox} for optional arguments.
        """
        stringList = amp.ListOf(amp.Integer(), optional=True)
        objects = {}
        stringList.fromBox(b"omitted", {}, objects, None)
        self.assertEqual(objects, {"omitted": None})
3144
3145
@implementer(interfaces.IUNIXTransport)
class UNIXStringTransport:
    """
    In-memory stand-in for an L{interfaces.IUNIXTransport} that records
    everything handed to it so tests can inspect it afterwards.

    @ivar _queue: A C{list} of recorded events, in the order they happened
        (via C{write}, C{sendFileDescriptor}, and so on).  Each event is a
        two-tuple: the name of the receiving-side method the event stands
        for, and the value to pass to it.
    """

    def __init__(self, descriptorFuzz):
        """
        @param descriptorFuzz: Amount added to every descriptor passed to
            C{sendFileDescriptor} before it is recorded.
        @type descriptorFuzz: C{int}
        """
        self._queue = []
        self._fuzz = descriptorFuzz

    def sendFileDescriptor(self, descriptor):
        shifted = descriptor + self._fuzz
        self._queue.append(("fileDescriptorReceived", shifted))

    def write(self, data):
        self._queue.append(("dataReceived", data))

    def writeSequence(self, seq):
        for chunk in seq:
            self.write(chunk)

    def loseConnection(self):
        reason = Failure(error.ConnectionLost())
        self._queue.append(("connectionLost", reason))

    def getHost(self):
        return address.UNIXAddress("/tmp/some-path")

    def getPeer(self):
        return address.UNIXAddress("/tmp/another-path")
3183
3184
# Minimal evidence that we got the signatures right: these checks run once,
# when the module is imported, against UNIXStringTransport for both the
# general ITransport interface and the UNIX-specific IUNIXTransport one.
verifyClass(interfaces.ITransport, UNIXStringTransport)
verifyClass(interfaces.IUNIXTransport, UNIXStringTransport)
3188
3189
class DescriptorTests(TestCase):
    """
    Tests for L{amp.Descriptor}, the argument type used to pass a file
    descriptor across an AMP connection running over a UNIX domain socket.
    """

    def setUp(self):
        self.fuzz = 3
        self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz)
        dispatcher = amp.BoxDispatcher(amp.CommandLocator())
        self.protocol = amp.BinaryBoxProtocol(dispatcher)
        self.protocol.makeConnection(self.transport)

    def test_fromStringProto(self):
        """
        L{Descriptor.fromStringProto} turns a wire value into the file
        descriptor that the protocol's L{_DescriptorExchanger} state received
        at that position, and consumes it from that state.

        This is a whitebox test which looks directly at
        L{_DescriptorExchanger} state.
        """
        argument = amp.Descriptor()
        receivedInOrder = [5, 3, 1]
        for fd in receivedInOrder:
            self.protocol.fileDescriptorReceived(fd)
        for position, fd in enumerate(receivedInOrder):
            decoded = argument.fromStringProto(str(position), self.protocol)
            self.assertEqual(fd, decoded)
        self.assertEqual({}, self.protocol._descriptors)

    def test_toStringProto(self):
        """
        L{Descriptor.toStringProto} copies the file descriptor across with
        the L{IUNIXTransport.sendFileDescriptor} method of the protocol's
        transport.  Successive descriptors sent over one AMP connection are
        numbered 0, 1, 2, ... and the argument's wire encoding is that number
        written in base ten.

        This is a whitebox test which looks directly at, and modifies,
        L{_DescriptorExchanger} state.
        """
        argument = amp.Descriptor()
        for expectedWire, fd in [(b"0", 2), (b"1", 4), (b"2", 6)]:
            self.assertEqual(expectedWire, argument.toStringProto(fd, self.protocol))
            sent = self.transport._queue.pop(0)
            self.assertEqual(("fileDescriptorReceived", fd + self.fuzz), sent)
        self.assertEqual({}, self.protocol._descriptors)

    def test_roundTrip(self):
        """
        An L{amp.AmpBox} filled in by L{amp.Descriptor.toBox} can be read back
        by L{amp.Descriptor.fromBox} to recover the file descriptor.
        """
        name = "alpha"
        wireName = name.encode("ascii")
        descriptor = 17
        sendObjects = {name: descriptor}
        strings = {}

        argument = amp.Descriptor()
        argument.toBox(wireName, strings, sendObjects.copy(), self.protocol)

        # Replay every event the sending transport recorded on a fresh
        # receiving protocol.
        receiver = amp.BinaryBoxProtocol(amp.BoxDispatcher(amp.CommandLocator()))
        for methodName, *args in self.transport._queue:
            getattr(receiver, methodName)(*args)

        receiveObjects = {}
        argument.fromBox(wireName, strings.copy(), receiveObjects, receiver)

        # The fuzz offset shows the value really travelled through
        # sendFileDescriptor rather than just being formatted and re-parsed.
        self.assertEqual(descriptor + self.fuzz, receiveObjects[name])
3273
3274
class DateTimeTests(TestCase):
    """
    Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}.
    """

    # One wire value and the timezone-aware datetime it stands for, shared
    # by the conversion tests in both directions.
    string = b"9876-01-23T12:34:56.054321-01:23"
    tzinfo = tz("-", 1, 23)
    object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo=tzinfo)

    def test_invalidString(self):
        """
        Handing L{amp.DateTime.fromString} text that is not a timestamp in
        the expected format raises L{ValueError}.
        """
        parse = amp.DateTime().fromString
        self.assertRaises(ValueError, parse, "abc")

    def test_invalidDatetime(self):
        """
        L{amp.DateTime.toString} refuses a naive datetime (one carrying no
        timezone information) with L{ValueError}.
        """
        serialize = amp.DateTime().toString
        naive = datetime.datetime(2010, 12, 25, 0, 0, 0)
        self.assertRaises(ValueError, serialize, naive)

    def test_fromString(self):
        """
        L{amp.DateTime.fromString} builds a C{datetime.datetime} whose every
        field, timezone included, comes from the given string.
        """
        parsed = amp.DateTime().fromString(self.string)
        self.assertEqual(parsed, self.object)

    def test_toString(self):
        """
        L{amp.DateTime.toString} renders a C{datetime.datetime} in the wire
        format, keeping all of its fields including the UTC offset.
        """
        serialized = amp.DateTime().toString(self.object)
        self.assertEqual(serialized, self.string)
3320
3321
class UTCTests(TestCase):
    """
    Tests for L{amp.utc}.
    """

    def test_tzname(self):
        """
        The name L{amp.utc.tzname} reports is C{"+00:00"}.
        """
        name = amp.utc.tzname(None)
        self.assertEqual(name, "+00:00")

    def test_dst(self):
        """
        L{amp.utc.dst} reports no daylight-saving adjustment: a zero
        timedelta.
        """
        noShift = datetime.timedelta(0)
        self.assertEqual(amp.utc.dst(None), noShift)

    def test_utcoffset(self):
        """
        L{amp.utc.utcoffset} reports a zero offset from UTC.
        """
        noShift = datetime.timedelta(0)
        self.assertEqual(amp.utc.utcoffset(None), noShift)

    def test_badSign(self):
        """
        L{amp._FixedOffsetTZInfo.fromSignHoursMinutes} rejects, with
        L{ValueError}, any sign other than C{'+'} or C{'-'}.
        """
        self.assertRaises(ValueError, tz, "?", 0, 0)
3351
3352
class RemoteAmpErrorTests(TestCase):
    """
    Tests for L{amp.RemoteAmpError}.
    """

    def test_stringMessage(self):
        """
        The native-string form of an L{amp.RemoteAmpError} combines its
        C{errorCode} (C{bytes}) with its C{description}.
        """
        remoteError = amp.RemoteAmpError(b"BROKEN", "Something has broken")
        expected = "Code<BROKEN>: Something has broken"
        self.assertEqual(expected, str(remoteError))

    def test_stringMessageReplacesNonAsciiText(self):
        """
        Non-ASCII bytes in the C{errorCode} of an L{amp.RemoteAmpError} are
        rendered as backslash-escape sequences.
        """
        remoteError = amp.RemoteAmpError(b"BROKEN-\xff", "Something has broken")
        expected = "Code<BROKEN-\\xff>: Something has broken"
        self.assertEqual(expected, str(remoteError))

    def test_stringMessageWithLocalFailure(self):
        """
        An L{amp.RemoteAmpError} carrying a local failure is rendered with a
        "(local)" marker followed by a short traceback.
        """
        localFailure = Failure(Exception("Something came loose"))
        remoteError = amp.RemoteAmpError(
            b"BROKEN", "Something has broken", local=localFailure
        )
        pattern = (
            "^Code<BROKEN> [(]local[)]: Something has broken\n"
            "Traceback [(]failure with no frames[)]: "
            "<.+Exception.>: Something came loose\n"
        )
        self.assertRegex(str(remoteError), pattern)
3389