1# Copyright (c) 2005 Divmod, Inc.
2# Copyright (c) Twisted Matrix Laboratories.
3# See LICENSE for details.
4
5"""
6Tests for L{twisted.protocols.amp}.
7"""
8
9import datetime
10import decimal
11
12from zope.interface import implements
13from zope.interface.verify import verifyClass, verifyObject
14
15from twisted.python import filepath
16from twisted.python.failure import Failure
17from twisted.protocols import amp
18from twisted.trial import unittest
19from twisted.internet import protocol, defer, error, reactor, interfaces
20from twisted.test import iosim
21from twisted.test.proto_helpers import StringTransport
22
try:
    from twisted.internet import ssl
except ImportError:
    # Without an importable ssl module every SSL-dependent test is skipped.
    ssl = None
else:
    # The module can import cleanly yet still report that SSL is unusable.
    if not ssl.supported:
        ssl = None

# Reason string for skipping SSL tests, or None when SSL can be used.
skipSSL = "SSL not available" if ssl is None else None
36
37
class TestProto(protocol.Protocol):
    """
    A trivial protocol for use in testing where a L{Protocol} is expected.

    When connected it writes C{dataToSend}; every chunk it receives is
    appended to C{data}; when the connection is lost it fires C{onConnLost}
    with that list.

    @ivar instanceId: the id of this instance
    @ivar onConnLost: deferred that will be fired with C{data} when the
        connection is lost
    @ivar dataToSend: data to send on the protocol
    @ivar data: the chunks of data received so far, in order; created in
        C{connectionMade}
    """

    # Class-wide counter handing out a distinct instanceId to each instance.
    instanceCount = 0

    def __init__(self, onConnLost, dataToSend):
        self.onConnLost = onConnLost
        self.dataToSend = dataToSend
        self.instanceId = TestProto.instanceCount
        TestProto.instanceCount += 1


    def connectionMade(self):
        """
        Start a fresh receive buffer and send C{dataToSend}.
        """
        self.data = []
        self.transport.write(self.dataToSend)


    def dataReceived(self, bytes):
        """
        Remember each chunk of received data.
        """
        self.data.append(bytes)


    def connectionLost(self, reason):
        """
        Fire C{onConnLost} with everything received on this connection.
        """
        self.onConnLost.callback(self.data)


    def __repr__(self):
        """
        Custom repr for testing to avoid coupling amp tests with repr from
        L{Protocol}

        Returns a string which contains a unique identifier that can be looked
        up using the instanceId property::

            <TestProto #3>
        """
        return "<TestProto #%d>" % (self.instanceId,)
80
81
82
class SimpleSymmetricProtocol(amp.AMP):
    """
    An L{amp.AMP} subclass which sends I{hello} with C{callRemoteString} and
    handles I{hello} and I{howdoyoudo} through raw C{amp_*} box methods rather
    than L{amp.Command} declarations.
    """

    def sendHello(self, text):
        """
        Send a raw I{hello} request carrying C{text}.

        @return: whatever C{callRemoteString} returns.
        """
        return self.callRemoteString("hello", hello=text)

    def amp_HELLO(self, box):
        """
        Answer with a box echoing the C{hello} value of C{box}.
        """
        echoed = box['hello']
        return amp.Box(hello=echoed)

    def amp_HOWDOYOUDO(self, box):
        """
        Answer any I{howdoyoudo} box with an L{amp.QuitBox}.
        """
        answer = amp.QuitBox(howdoyoudo='world')
        return answer
95
96
97
class UnfriendlyGreeting(Exception):
    """
    Greeting was insufficiently kind.

    Declared by L{Hello} and L{FutureHello} under the error code
    C{'UNFRIENDLY'}.
    """
101
class DeathThreat(Exception):
    """
    A greeting (or request) was threatening enough to be treated as fatal.

    Declared in L{Hello}'s C{fatalErrors} under the code C{'DEAD'}.  Also
    raised when a L{SecuredPing} arrives over a transport that is not SSL.
    """
105
class UnknownProtocol(Exception):
    """
    Asked to switch to the wrong protocol.

    Raised with the unrecognised protocol name and declared by
    L{TestSwitchProto} under the error code C{'UNKNOWN'}.
    """
109
110
class TransportPeer(amp.Argument):
    """
    An argument which contributes nothing to outgoing boxes and, when parsed,
    resolves to the peer address of the protocol's transport.

    This serves as some informal documentation for how to get variables from
    the protocol or your environment and pass them to methods as arguments.
    """

    def retrieve(self, d, name, proto):
        # Nothing is looked up in the box for this argument.
        return ''

    def fromStringProto(self, notAString, proto):
        # Ignore the serialized value and ask the transport instead.
        peer = proto.transport.getPeer()
        return peer

    def toBox(self, name, strings, objects, proto):
        # Nothing is written into the outgoing box.
        return None
122
123
124
class Hello(amp.Command):
    """
    The main test command: a required C{hello} string plus several optional
    arguments, answered with a C{hello} string and an optional C{print}
    unicode string.

    The optional argument names include Python keywords (C{print}, C{from}),
    mixed case, a dash and an underscore; the responder
    C{SimpleSymmetricCommandProtocol.cmdHello} receives them as C{Print},
    C{From}, C{mixedCase}, C{dash_arg} and C{underscore_arg}.
    """

    commandName = 'hello'

    arguments = [('hello', amp.String()),
                 ('optional', amp.Boolean(optional=True)),
                 ('print', amp.Unicode(optional=True)),
                 ('from', TransportPeer(optional=True)),
                 ('mixedCase', amp.String(optional=True)),
                 ('dash-arg', amp.String(optional=True)),
                 ('underscore_arg', amp.String(optional=True))]

    response = [('hello', amp.String()),
                ('print', amp.Unicode(optional=True))]

    # UnfriendlyGreeting raised by a responder maps to the code 'UNFRIENDLY'.
    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}

    # DeathThreat is declared as a fatal error with the code 'DEAD'.
    fatalErrors = {DeathThreat: 'DEAD'}
143
class NoAnswerHello(Hello):
    """
    L{Hello} with the same wire name, but declared as not requiring an
    answer.
    """
    commandName = Hello.commandName
    requiresAnswer = False
147
class FutureHello(amp.Command):
    """
    A variant of L{Hello} with the same command name and an extra optional
    C{bonus} argument, presumably standing in for a newer peer's version of
    the command.  Unlike L{Hello} it lacks the C{mixedCase}, C{dash-arg} and
    C{underscore_arg} arguments and declares no C{fatalErrors}.
    """
    commandName = 'hello'

    arguments = [('hello', amp.String()),
                 ('optional', amp.Boolean(optional=True)),
                 ('print', amp.Unicode(optional=True)),
                 ('from', TransportPeer(optional=True)),
                 ('bonus', amp.String(optional=True)), # addt'l arguments
                                                       # should generally be
                                                       # added at the end, and
                                                       # be optional...
                 ]

    response = [('hello', amp.String()),
                ('print', amp.Unicode(optional=True))]

    errors = {UnfriendlyGreeting: 'UNFRIENDLY'}
165
class WTF(amp.Command):
    """
    An example of an invalid command.

    It declares no commandName, arguments or response, and no responder for
    it is registered in the visible part of this module.
    """
170
171
class BrokenReturn(amp.Command):
    """
    An example of a perfectly good command, but the handler is going to return
    None...

    The responder is C{SimpleSymmetricCommandProtocol.donothing}, which
    returns C{None} instead of a response dictionary.
    """

    commandName = 'broken_return'
178
class Goodbye(amp.Command):
    """
    A command answered with a C{goodbye} string, using L{amp.QuitBox} as its
    response type.  Its responder is C{SimpleSymmetricCommandProtocol.saybye}.
    """
    # commandName left blank on purpose: this tests implicit command names.
    response = [('goodbye', amp.String())]
    responseType = amp.QuitBox
183
class Howdoyoudo(amp.Command):
    """
    A command with no declared arguments or response.  Its responder is
    C{SimpleSymmetricCommandProtocol.howdo}.
    """
    commandName = 'howdoyoudo'
    # NOTE(review): the QuitBox response type below is commented out, so this
    # command keeps amp.Command's default responseType -- confirm whether the
    # line should be removed or restored.
    # responseType = amp.QuitBox
187
class WaitForever(amp.Command):
    """
    A command whose responder (C{SimpleSymmetricCommandProtocol.waitforit})
    returns a Deferred that nothing in this part of the module fires.
    """
    commandName = 'wait_forever'
190
class GetList(amp.Command):
    """
    Ask for C{length} items; the answer is an L{amp.AmpList} whose entries
    each carry an integer C{x}.
    """
    commandName = 'getlist'
    arguments = [('length', amp.Integer())]
    response = [('body', amp.AmpList([('x', amp.Integer())]))]
195
class DontRejectMe(amp.Command):
    """
    A command with a required unicode C{magicWord} and an optional
    L{amp.AmpList} argument named C{list}, answered with a unicode
    C{response}.
    """
    commandName = 'dontrejectme'
    arguments = [
            ('magicWord', amp.Unicode()),
            ('list', amp.AmpList([('name', amp.Unicode())], optional=True)),
            ]
    response = [('response', amp.Unicode())]
203
class SecuredPing(amp.Command):
    """
    A ping whose responder (C{FactoryNotifier.emitpong}) raises
    L{DeathThreat} unless the transport provides C{ISSLTransport}.
    """
    # XXX TODO: actually make this refuse to send over an insecure connection
    response = [('pinged', amp.Boolean())]
207
class TestSwitchProto(amp.ProtocolSwitchCommand):
    """
    A protocol-switch command naming the protocol to switch to; an unknown
    name is reported with L{UnknownProtocol} under the code C{'UNKNOWN'}.
    """
    commandName = 'Switch-Proto'

    arguments = [
        ('name', amp.String()),
        ]
    errors = {UnknownProtocol: 'UNKNOWN'}
215
class SingleUseFactory(protocol.ClientFactory):
    """
    A client factory which hands out one pre-built protocol instance.

    The first C{buildProtocol} call returns the protocol given to the
    constructor; later calls return C{None}.  A failed connection attempt
    stores its reason in C{reasonFailed}.
    """

    # Stays None unless clientConnectionFailed is called.
    reasonFailed = None

    def __init__(self, proto):
        self.proto = proto
        proto.factory = self

    def buildProtocol(self, addr):
        built = self.proto
        self.proto = None
        return built

    def clientConnectionFailed(self, connector, reason):
        self.reasonFailed = reason
230
# Greeting for which cmdHello raises ThingIDontUnderstandError, an exception
# that Hello declares in neither errors nor fatalErrors.
THING_I_DONT_UNDERSTAND = 'gwebol nargo'
class ThingIDontUnderstandError(Exception):
    """
    Raised by C{cmdHello} for the C{THING_I_DONT_UNDERSTAND} greeting.  It is
    not declared in L{Hello}'s errors or fatalErrors.
    """
234
class FactoryNotifier(amp.AMP):
    """
    An L{amp.AMP} subclass which reports itself to its factory on connection
    and answers L{SecuredPing} only over SSL transports.
    """
    # When left as None, connectionMade does nothing.
    factory = None
    def connectionMade(self):
        """
        Store this protocol as C{self.factory.theProto} and, if the factory
        has an C{onMade} attribute, fire it with C{None}.
        """
        if self.factory is not None:
            self.factory.theProto = self
            if hasattr(self.factory, 'onMade'):
                self.factory.onMade.callback(None)

    def emitpong(self):
        """
        Responder for L{SecuredPing}.

        @raise DeathThreat: if the transport does not provide
            C{ISSLTransport}.
        @return: C{{'pinged': True}}
        """
        from twisted.internet.interfaces import ISSLTransport
        if not ISSLTransport.providedBy(self.transport):
            raise DeathThreat("only send secure pings over secure channels")
        return {'pinged': True}
    SecuredPing.responder(emitpong)
249
250
class SimpleSymmetricCommandProtocol(FactoryNotifier):
    """
    The main AMP test protocol.  It can send L{Hello}, answer L{Hello} and
    most of the other commands declared above, and take part in a
    L{TestSwitchProto} protocol switch to a L{TestProto}.
    """
    # Assigned by DeferredSymmetricCommandProtocol.switchit; this class never
    # sets it itself.
    maybeLater = None
    def __init__(self, onConnLost=None):
        """
        @param onConnLost: handed to every L{TestProto} created for a protocol
            switch; that TestProto fires it when its connection is lost.
        """
        amp.AMP.__init__(self)
        self.onConnLost = onConnLost

    def sendHello(self, text):
        """
        Call L{Hello} on the peer with C{hello=text}.

        @return: the result of C{callRemote}.
        """
        return self.callRemote(Hello, hello=text)

    def sendUnicodeHello(self, text, translation):
        """
        Call L{Hello} on the peer with C{hello=text} and the optional
        C{print} argument (passed as C{Print}) set to C{translation}.
        """
        return self.callRemote(Hello, hello=text, Print=translation)

    # Set to True by cmdHello once a greeting has been accepted.
    greeted = False

    def cmdHello(self, hello, From, optional=None, Print=None,
                 mixedCase=None, dash_arg=None, underscore_arg=None):
        """
        Responder for L{Hello}: echo C{hello} (and C{Print}, if given) back.

        @raise ThingIDontUnderstandError: for C{THING_I_DONT_UNDERSTAND},
            which L{Hello} does not declare.
        @raise UnfriendlyGreeting: for greetings starting with an obscenity.
        @raise DeathThreat: for the greeting C{'die'}.
        """
        # From is produced by TransportPeer from this protocol's transport.
        assert From == self.transport.getPeer()
        if hello == THING_I_DONT_UNDERSTAND:
            raise ThingIDontUnderstandError()
        if hello.startswith('fuck'):
            raise UnfriendlyGreeting("Don't be a dick.")
        if hello == 'die':
            raise DeathThreat("aieeeeeeeee")
        result = dict(hello=hello)
        if Print is not None:
            result.update(dict(Print=Print))
        self.greeted = True
        return result
    Hello.responder(cmdHello)

    def cmdGetlist(self, length):
        """
        Responder for L{GetList}: C{length} entries of C{{'x': 1}}.
        """
        # The list repeats one shared dict object `length` times.
        return {'body': [dict(x=1)] * length}
    GetList.responder(cmdGetlist)

    def okiwont(self, magicWord, list=None):
        """
        Responder for L{DontRejectMe}.

        @param list: the optional C{list} argument; the name (which shadows
            the builtin) is fixed by C{DontRejectMe.arguments}.
        @return: a C{response} naming the first entry's C{name}, or saying
            the list was omitted.
        """
        if list is None:
            response = u'list omitted'
        else:
            # NOTE(review): an empty list would raise IndexError here.
            response = u'%s accepted' % (list[0]['name'])
        return dict(response=response)
    DontRejectMe.responder(okiwont)

    def waitforit(self):
        """
        Responder for L{WaitForever}: return a new Deferred, kept as
        C{self.waiting}, which this class never fires.
        """
        self.waiting = defer.Deferred()
        return self.waiting
    WaitForever.responder(waitforit)

    def howdo(self):
        """
        Responder for L{Howdoyoudo}.
        """
        return dict(howdoyoudo='world')
    Howdoyoudo.responder(howdo)

    def saybye(self):
        """
        Responder for L{Goodbye}.
        """
        return dict(goodbye="everyone")
    Goodbye.responder(saybye)

    def switchToTestProtocol(self, fail=False):
        """
        Call L{TestSwitchProto} asking for I{test-proto} (or the unknown
        name I{no-proto} when C{fail} is true), passing a L{SingleUseFactory}
        that wraps a new L{TestProto} sending C{SWITCH_CLIENT_DATA}.

        @return: the result of C{callRemote}, with a callback that replaces
            the command's result by that local L{TestProto}.
        """
        if fail:
            name = 'no-proto'
        else:
            name = 'test-proto'
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
        return self.callRemote(
            TestSwitchProto,
            SingleUseFactory(p), name=name).addCallback(lambda ign: p)

    def switchit(self, name):
        """
        Responder for L{TestSwitchProto}: return a L{TestProto} sending
        C{SWITCH_SERVER_DATA} for I{test-proto}.

        @raise UnknownProtocol: for any other name.
        """
        if name == 'test-proto':
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        raise UnknownProtocol(name)
    TestSwitchProto.responder(switchit)

    def donothing(self):
        """
        Responder for L{BrokenReturn}: returns C{None} instead of a response
        dictionary, which is the broken behaviour that command exercises.
        """
        return None
    BrokenReturn.responder(donothing)
325
326
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    Like L{SimpleSymmetricCommandProtocol}, but answers L{TestSwitchProto}
    with an unfired Deferred (presumably fired later by the test).
    """
    def switchit(self, name):
        """
        Responder for L{TestSwitchProto}.

        For I{test-proto}, create the L{TestProto} to switch to as
        C{self.maybeLaterProto} and return a new Deferred stored as
        C{self.maybeLater}.

        @raise UnknownProtocol: for any other name.
        """
        if name == 'test-proto':
            self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
            self.maybeLater = defer.Deferred()
            return self.maybeLater
        raise UnknownProtocol(name)
    TestSwitchProto.responder(switchit)
335
class BadNoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A protocol whose L{NoAnswerHello} responder returns C{None}.
    """
    def badResponder(self, hello, From, optional=None, Print=None,
                     mixedCase=None, dash_arg=None, underscore_arg=None):
        """
        This responder does nothing and forgets to return a dictionary.
        """
    NoAnswerHello.responder(badResponder)
343
class NoAnswerCommandProtocol(SimpleSymmetricCommandProtocol):
    """
    A protocol whose L{NoAnswerHello} responder returns a proper dictionary.
    """
    def goodNoAnswerResponder(self, hello, From, optional=None, Print=None,
                              mixedCase=None, dash_arg=None, underscore_arg=None):
        """
        Return C{hello} with C{"-noanswer"} appended.
        """
        return dict(hello=hello+"-noanswer")
    NoAnswerHello.responder(goodNoAnswerResponder)
349
def connectedServerAndClient(ServerClass=SimpleSymmetricProtocol,
                             ClientClass=SimpleSymmetricProtocol,
                             *a, **kw):
    """
    Connect a C{ServerClass} instance to a C{ClientClass} instance through
    L{iosim.connectedServerAndClient}, forwarding any extra positional and
    keyword arguments unchanged.

    @return: a 3-tuple: (client, server, pump)
    """
    connect = iosim.connectedServerAndClient
    return connect(ServerClass, ClientClass, *a, **kw)
358
class TotallyDumbProtocol(protocol.Protocol):
    """
    A protocol which appends everything it receives to C{buf}.
    """
    buf = ''

    def dataReceived(self, data):
        self.buf = self.buf + data
363
class LiteralAmp(amp.AMP):
    """
    An L{amp.AMP} whose C{ampBoxReceived} just appends each incoming box to
    C{boxes}.

    NOTE(review): C{__init__} does not upcall C{amp.AMP.__init__}.  This is
    presumably deliberate (amp appears to tolerate it); confirm before adding
    an upcall, since that would change what the round-trip tests exercise.
    """

    def __init__(self):
        self.boxes = []

    def ampBoxReceived(self, box):
        self.boxes.append(box)
371
372
373
class AmpBoxTests(unittest.TestCase):
    """
    Serialization properties of L{amp.AmpBox}: C{str} keys and values
    serialize to a C{str}, while C{unicode} keys or values are rejected.
    """

    def test_serializeStr(self):
        """
        A box with C{str} keys and values serializes to a C{str}.
        """
        serialized = amp.AmpBox(key='value').serialize()
        self.assertEqual(type(serialized), str)

    def test_serializeUnicodeKeyRaises(self):
        """
        Serializing a box with a C{unicode} key raises L{TypeError}.
        """
        box = amp.AmpBox(**{u'key': 'value'})
        self.assertRaises(TypeError, box.serialize)

    def test_serializeUnicodeValueRaises(self):
        """
        Serializing a box with a C{unicode} value raises L{TypeError}.
        """
        box = amp.AmpBox(key=u'value')
        self.assertRaises(TypeError, box.serialize)
401
402
403
class ParsingTest(unittest.TestCase):
    """
    Tests for parsing and emitting individual argument types, and for boxes
    surviving an encode/parse round trip between two connected protocols.
    """

    def test_booleanValues(self):
        """
        Verify that the Boolean parser parses 'True' and 'False', but nothing
        else.
        """
        b = amp.Boolean()
        self.assertEqual(b.fromString("True"), True)
        self.assertEqual(b.fromString("False"), False)
        # Parsing is case-sensitive; anything else is a TypeError.
        self.assertRaises(TypeError, b.fromString, "ninja")
        self.assertRaises(TypeError, b.fromString, "true")
        self.assertRaises(TypeError, b.fromString, "TRUE")
        self.assertEqual(b.toString(True), 'True')
        self.assertEqual(b.toString(False), 'False')

    def test_pathValueRoundTrip(self):
        """
        Verify the 'Path' argument can parse and emit a file path.
        """
        fp = filepath.FilePath(self.mktemp())
        p = amp.Path()
        s = p.toString(fp)
        v = p.fromString(s)
        self.assertNotIdentical(fp, v) # sanity check
        self.assertEqual(fp, v)


    def test_sillyEmptyThing(self):
        """
        Test that empty boxes raise an error; they aren't supposed to be sent
        on purpose.
        """
        a = amp.AMP()
        # Synchronous test: don't hand the caught exception back to trial
        # as if it were the test's result.
        self.assertRaises(amp.NoEmptyBoxes, a.ampBoxReceived, amp.Box())


    def test_ParsingRoundTrip(self):
        """
        Verify that various kinds of data make it through the encode/parse
        round-trip unharmed.
        """
        c, s, p = connectedServerAndClient(ClientClass=LiteralAmp,
                                           ServerClass=LiteralAmp)

        # Values chosen to contain the characters most likely to confuse a
        # line- or delimiter-based parser.
        SIMPLE = ('simple', 'test')
        CE = ('ceq', ': ')
        CR = ('crtest', 'test\r')
        LF = ('lftest', 'hello\n')
        NEWLINE = ('newline', 'test\r\none\r\ntwo')
        NEWLINE2 = ('newline2', 'test\r\none\r\n two')
        BODYTEST = ('body', 'blah\r\n\r\ntesttest')

        testData = [
            [SIMPLE],
            [SIMPLE, BODYTEST],
            [SIMPLE, CE],
            [SIMPLE, CR],
            [SIMPLE, CE, CR, LF],
            [CE, CR, LF],
            [SIMPLE, NEWLINE, CE, NEWLINE2],
            [BODYTEST, SIMPLE, NEWLINE]
            ]

        for test in testData:
            jb = amp.Box()
            jb.update(dict(test))
            jb._sendTo(c)
            # Deliver the bytes written by the client to the server.
            p.flush()
            self.assertEqual(s.boxes[-1], jb)
474
475
476
class FakeLocator(object):
    """
    This is a fake implementation of the interface implied by
    L{CommandLocator}: responders are kept in the plain dictionary
    C{commands}, keyed by command name.
    """
    def __init__(self, **commands):
        """
        Remember the given keyword arguments as a set of responders.

        @param commands: initial responders, keyed by command name.  Tests may
            also add entries to C{self.commands} directly afterwards.
        """
        # Copy so later edits to self.commands never touch the caller's dict.
        self.commands = dict(commands)


    def locateResponder(self, commandName):
        """
        Look up and return the responder registered under C{commandName},
        either passed to the constructor or added to C{commands} later.

        @raise KeyError: if no responder is registered under that name.
        """
        return self.commands[commandName]
495
496
class FakeSender:
    """
    This is a fake implementation of the 'box sender' interface implied by
    L{AMP}.  It records the boxes it is asked to send and the unhandled
    errors it is given.

    @ivar sentBoxes: boxes passed to C{sendBox}, in order.
    @ivar unhandledErrors: failures passed to C{unhandledError} while an
        error was expected.
    @ivar expectedErrors: how many more unhandled errors are expected; it
        goes negative once an unexpected error arrives.
    """
    def __init__(self):
        """
        Create a fake sender with no sent boxes, no recorded unhandled errors
        and no expected errors.
        """
        self.sentBoxes = []
        self.unhandledErrors = []
        self.expectedErrors = 0


    def expectError(self):
        """
        Expect one more error, so that the test doesn't fail when it arrives.
        """
        self.expectedErrors += 1


    def sendBox(self, box):
        """
        Record C{box} in C{sentBoxes} instead of sending it anywhere.
        """
        self.sentBoxes.append(box)


    def unhandledError(self, failure):
        """
        Record an expected failure in C{unhandledErrors}, or re-raise an
        unexpected one immediately for easier debugging.

        @param failure: the L{Failure} to handle.
        @raise Exception: whatever C{failure} wraps, if no error was expected.
        """
        self.expectedErrors -= 1
        if self.expectedErrors < 0:
            failure.raiseException()
        else:
            self.unhandledErrors.append(failure)
535
536
537
class CommandDispatchTests(unittest.TestCase):
    """
    L{amp.BoxDispatcher} routes incoming '_command' boxes to responders found
    through its locator, sends outgoing calls as '_ask' boxes, and turns the
    matching '_answer' or '_error' boxes back into results.

    Note: Originally, AMP's factoring was such that many tests for this
    functionality are now implemented as full round-trip tests in L{AMPTest}.
    Future tests should be written at this level instead, to ensure API
    compatibility and to provide more granular, readable units of test
    coverage.
    """

    def setUp(self):
        """
        Create a dispatcher to use, backed by a L{FakeLocator} for responder
        lookup and a L{FakeSender} that records outgoing boxes and errors.
        """
        self.locator = FakeLocator()
        self.sender = FakeSender()
        self.dispatcher = amp.BoxDispatcher(self.locator)
        self.dispatcher.startReceivingBoxes(self.sender)


    def test_receivedAsk(self):
        """
        L{amp.BoxDispatcher.ampBoxReceived} should locate the appropriate
        responder through its locator, based on the '_command' key, and call
        it with the received box.
        """
        received = []
        def thunk(box):
            received.append(box)
            return amp.Box({"hello": "goodbye"})
        request = amp.Box(_command="hello",
                          _ask="test-command-id",
                          hello="world")
        self.locator.commands['hello'] = thunk
        self.dispatcher.ampBoxReceived(request)
        self.assertEqual(received, [request])


    def test_sendUnhandledError(self):
        """
        L{amp.BoxDispatcher} should relay its unhandled errors in responding
        to boxes to its boxSender.
        """
        err = RuntimeError("something went wrong, oh no")
        self.sender.expectError()
        self.dispatcher.unhandledError(Failure(err))
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)


    def test_unhandledSerializationError(self):
        """
        Errors during serialization ought to be relayed to the sender's
        unhandledError method.
        """
        err = RuntimeError("something undefined went wrong")
        def thunk(box):
            # The answer box blows up when the dispatcher tries to send it.
            class BrokenBox(amp.Box):
                def _sendTo(self, proto):
                    raise err
            return BrokenBox()
        self.locator.commands['hello'] = thunk
        request = amp.Box(_command="hello",
                          _ask="test-command-id",
                          hello="world")
        self.sender.expectError()
        self.dispatcher.ampBoxReceived(request)
        self.assertEqual(len(self.sender.unhandledErrors), 1)
        self.assertEqual(self.sender.unhandledErrors[0].value, err)


    def test_callRemote(self):
        """
        L{amp.BoxDispatcher.callRemote} should emit a properly formatted
        '_ask' box to its boxSender and record an outstanding L{Deferred}.
        When a corresponding '_answer' packet is received, the L{Deferred}
        should be fired, and the results translated via the given
        L{Command}'s response de-serialization.
        """
        callResult = self.dispatcher.callRemote(Hello, hello='world')
        # The first request made by a fresh dispatcher is tagged '1'.
        self.assertEqual(self.sender.sentBoxes,
                          [amp.AmpBox(_command="hello",
                                      _ask="1",
                                      hello="world")])
        answers = []
        callResult.addCallback(answers.append)
        self.assertEqual(answers, [])
        self.dispatcher.ampBoxReceived(amp.AmpBox({'hello': "yay",
                                                   'print': "ignored",
                                                   '_answer': "1"}))
        self.assertEqual(answers, [dict(hello="yay",
                                         Print=u"ignored")])


    def _localCallbackErrorLoggingTest(self, callResult):
        """
        Verify that C{callResult} completes with a C{None} result and that an
        unhandled error has been logged.
        """
        finalResult = []
        callResult.addBoth(finalResult.append)

        self.assertEqual(1, len(self.sender.unhandledErrors))
        self.assertIsInstance(
            self.sender.unhandledErrors[0].value, ZeroDivisionError)

        self.assertEqual([None], finalResult)


    def test_callRemoteSuccessLocalCallbackErrorLogging(self):
        """
        If the last callback on the L{Deferred} returned by C{callRemote} (added
        by application code calling C{callRemote}) fails, the failure is passed
        to the sender's C{unhandledError} method.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello='world')
        callResult.addCallback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(amp.AmpBox({
                    'hello': "yay", 'print': "ignored", '_answer': "1"}))

        self._localCallbackErrorLoggingTest(callResult)


    def test_callRemoteErrorLocalCallbackErrorLogging(self):
        """
        Like L{test_callRemoteSuccessLocalCallbackErrorLogging}, but for the
        case where the L{Deferred} returned by C{callRemote} fails.
        """
        self.sender.expectError()

        callResult = self.dispatcher.callRemote(Hello, hello='world')
        callResult.addErrback(lambda result: 1 // 0)

        self.dispatcher.ampBoxReceived(amp.AmpBox({
                    '_error': '1', '_error_code': 'bugs',
                    '_error_description': 'stuff'}))

        self._localCallbackErrorLoggingTest(callResult)
680
681
682
class SimpleGreeting(amp.Command):
    """
    A very simple greeting command that uses a few basic argument types.

    The responders in this module answer with C{cookieplus} computed from
    C{cookie}.
    """
    commandName = 'simple'
    arguments = [('greeting', amp.Unicode()),
                 ('cookie', amp.Integer())]
    response = [('cookieplus', amp.Integer())]
691
692
693
class TestLocator(amp.CommandLocator):
    """
    A locator which implements a responder to the 'simple' command.
    """
    def __init__(self):
        # (greeting, cookie) pairs seen by greetingResponder, in call order.
        self.greetings = []


    def greetingResponder(self, greeting, cookie):
        """
        Responder for L{SimpleGreeting}: record the call and answer with
        C{cookie + 3}.
        """
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 3)
    # Rebind the name to whatever the responder decorator returns.
    greetingResponder = SimpleGreeting.responder(greetingResponder)
706
707
708
class OverridingLocator(TestLocator):
    """
    A locator which overrides the responder to the 'simple' command.
    """

    def greetingResponder(self, greeting, cookie):
        """
        Return a different cookieplus than L{TestLocator.greetingResponder}:
        C{cookie + 4} instead of C{cookie + 3}.
        """
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 4)
    # Re-registering under the same name replaces the inherited responder.
    greetingResponder = SimpleGreeting.responder(greetingResponder)
721
722
723
class InheritingLocator(OverridingLocator):
    """
    This locator should inherit the responder from L{OverridingLocator}.

    It defines nothing of its own, so a 'simple' command is expected to be
    answered with C{cookie + 4}.
    """
728
729
730
class OverrideLocatorAMP(amp.AMP):
    """
    An L{amp.AMP} subclass which overrides the deprecated C{lookupFunction}
    hook to return C{customResponder} for the command name C{"custom"}, and
    also answers L{SimpleGreeting} with C{cookie + 3}.
    """
    def __init__(self):
        amp.AMP.__init__(self)
        # Sentinel object returned for the "custom" command name.
        self.customResponder = object()
        self.expectations = {"custom": self.customResponder}
        self.greetings = []


    def lookupFunction(self, name):
        """
        Override the deprecated lookupFunction function.

        Names in C{self.expectations} are answered from there; everything else
        is delegated to the base class implementation.
        """
        if name in self.expectations:
            result = self.expectations[name]
            return result
        else:
            return super(OverrideLocatorAMP, self).lookupFunction(name)


    def greetingResponder(self, greeting, cookie):
        """
        Responder for L{SimpleGreeting}: record the call and answer with
        C{cookie + 3}.
        """
        self.greetings.append((greeting, cookie))
        return dict(cookieplus=cookie + 3)
    greetingResponder = SimpleGreeting.responder(greetingResponder)
754
755
756
757
class CommandLocatorTests(unittest.TestCase):
    """
    The CommandLocator should enable users to specify responders to commands as
    functions that take structured objects, annotated with metadata.
    """

    def _checkSimpleGreeting(self, locatorClass, expected):
        """
        Check that a locator of type C{locatorClass} finds a responder
        for command named I{simple} and that the found responder answers
        with the C{expected} result to a C{SimpleGreeting<"ni hao", 5>}
        command.

        @return: a Deferred that fails if the answer does not match.
        """
        locator = locatorClass()
        responderCallable = locator.locateResponder("simple")
        # The cookie is given as the string "5", as it would arrive in a box.
        result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
        def done(values):
            # The answer is compared in its box form, hence str(expected).
            self.assertEqual(values, amp.AmpBox(cookieplus=str(expected)))
        return result.addCallback(done)


    def test_responderDecorator(self):
        """
        A method on a L{CommandLocator} subclass decorated with a L{Command}
        subclass's L{responder} decorator should be returned from
        locateResponder, wrapped in logic to serialize and deserialize its
        arguments.
        """
        return self._checkSimpleGreeting(TestLocator, 8)


    def test_responderOverriding(self):
        """
        L{CommandLocator} subclasses can override a responder inherited from
        a base class by using the L{Command.responder} decorator to register
        a new responder method.
        """
        return self._checkSimpleGreeting(OverridingLocator, 9)


    def test_responderInheritance(self):
        """
        Responder lookup follows the same rules as normal method lookup
        rules, particularly with respect to inheritance.
        """
        return self._checkSimpleGreeting(InheritingLocator, 9)


    def test_lookupFunctionDeprecatedOverride(self):
        """
        Subclasses which override locateResponder under its old name,
        lookupFunction, should have the override invoked instead.  (This tests
        an AMP subclass, because in the version of the code that could invoke
        this deprecated code path, there was no L{CommandLocator}.)
        """
        locator = OverrideLocatorAMP()
        # The warning is expected to be attributed to this test module, so
        # the call must stay inside a callable defined in this file.
        customResponderObject = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__, lambda : locator.locateResponder("custom"))
        self.assertEqual(locator.customResponder, customResponderObject)
        # Make sure upcalling works too
        normalResponderObject = self.assertWarns(
            PendingDeprecationWarning,
            "Override locateResponder, not lookupFunction.",
            __file__, lambda : locator.locateResponder("simple"))
        result = normalResponderObject(amp.Box(greeting="ni hao", cookie="5"))
        def done(values):
            self.assertEqual(values, amp.AmpBox(cookieplus='8'))
        return result.addCallback(done)


    def test_lookupFunctionDeprecatedInvoke(self):
        """
        Invoking locateResponder under its old name, lookupFunction, should
        emit a deprecation warning, but do the same thing.
        """
        locator = TestLocator()
        responderCallable = self.assertWarns(
            PendingDeprecationWarning,
            "Call locateResponder, not lookupFunction.", __file__,
            lambda : locator.lookupFunction("simple"))
        result = responderCallable(amp.Box(greeting="ni hao", cookie="5"))
        def done(values):
            self.assertEqual(values, amp.AmpBox(cookieplus='8'))
        return result.addCallback(done)
844
845
846
# Data written by the TestProto created on each side of a protocol switch:
# SWITCH_CLIENT_DATA by switchToTestProtocol, SWITCH_SERVER_DATA by switchit.
SWITCH_CLIENT_DATA = 'Success!'
SWITCH_SERVER_DATA = 'No, really.  Success.'
849
850
851class BinaryProtocolTests(unittest.TestCase):
852    """
853    Tests for L{amp.BinaryBoxProtocol}.
854
855    @ivar _boxSender: After C{startReceivingBoxes} is called, the L{IBoxSender}
856        which was passed to it.
857    """
858
859    def setUp(self):
860        """
861        Keep track of all boxes received by this test in its capacity as an
862        L{IBoxReceiver} implementor.
863        """
864        self.boxes = []
865        self.data = []
866
867
868    def startReceivingBoxes(self, sender):
869        """
870        Implement L{IBoxReceiver.startReceivingBoxes} to just remember the
871        value passed in.
872        """
873        self._boxSender = sender
874
875
876    def ampBoxReceived(self, box):
877        """
878        A box was received by the protocol.
879        """
880        self.boxes.append(box)
881
882    stopReason = None
883    def stopReceivingBoxes(self, reason):
884        """
885        Record the reason that we stopped receiving boxes.
886        """
887        self.stopReason = reason
888
889
890    # fake ITransport
891    def getPeer(self):
892        return 'no peer'
893
894
895    def getHost(self):
896        return 'no host'
897
898
899    def write(self, data):
900        self.data.append(data)
901
902
903    def test_startReceivingBoxes(self):
904        """
905        When L{amp.BinaryBoxProtocol} is connected to a transport, it calls
906        C{startReceivingBoxes} on its L{IBoxReceiver} with itself as the
907        L{IBoxSender} parameter.
908        """
909        protocol = amp.BinaryBoxProtocol(self)
910        protocol.makeConnection(None)
911        self.assertIdentical(self._boxSender, protocol)
912
913
914    def test_sendBoxInStartReceivingBoxes(self):
915        """
916        The L{IBoxReceiver} which is started when L{amp.BinaryBoxProtocol} is
917        connected to a transport can call C{sendBox} on the L{IBoxSender}
918        passed to it before C{startReceivingBoxes} returns and have that box
919        sent.
920        """
921        class SynchronouslySendingReceiver:
922            def startReceivingBoxes(self, sender):
923                sender.sendBox(amp.Box({'foo': 'bar'}))
924
925        transport = StringTransport()
926        protocol = amp.BinaryBoxProtocol(SynchronouslySendingReceiver())
927        protocol.makeConnection(transport)
928        self.assertEqual(
929            transport.value(),
930            '\x00\x03foo\x00\x03bar\x00\x00')
931
932
933    def test_receiveBoxStateMachine(self):
934        """
935        When a binary box protocol receives:
936            * a key
937            * a value
938            * an empty string
939        it should emit a box and send it to its boxReceiver.
940        """
941        a = amp.BinaryBoxProtocol(self)
942        a.stringReceived("hello")
943        a.stringReceived("world")
944        a.stringReceived("")
945        self.assertEqual(self.boxes, [amp.AmpBox(hello="world")])
946
947
948    def test_firstBoxFirstKeyExcessiveLength(self):
949        """
950        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
951        the first a key it receives is larger than 255.
952        """
953        transport = StringTransport()
954        protocol = amp.BinaryBoxProtocol(self)
955        protocol.makeConnection(transport)
956        protocol.dataReceived('\x01\x00')
957        self.assertTrue(transport.disconnecting)
958
959
960    def test_firstBoxSubsequentKeyExcessiveLength(self):
961        """
962        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
963        a subsequent key in the first box it receives is larger than 255.
964        """
965        transport = StringTransport()
966        protocol = amp.BinaryBoxProtocol(self)
967        protocol.makeConnection(transport)
968        protocol.dataReceived('\x00\x01k\x00\x01v')
969        self.assertFalse(transport.disconnecting)
970        protocol.dataReceived('\x01\x00')
971        self.assertTrue(transport.disconnecting)
972
973
974    def test_subsequentBoxFirstKeyExcessiveLength(self):
975        """
976        L{amp.BinaryBoxProtocol} drops its connection if the length prefix for
977        the first key in a subsequent box it receives is larger than 255.
978        """
979        transport = StringTransport()
980        protocol = amp.BinaryBoxProtocol(self)
981        protocol.makeConnection(transport)
982        protocol.dataReceived('\x00\x01k\x00\x01v\x00\x00')
983        self.assertFalse(transport.disconnecting)
984        protocol.dataReceived('\x01\x00')
985        self.assertTrue(transport.disconnecting)
986
987
988    def test_excessiveKeyFailure(self):
989        """
990        If L{amp.BinaryBoxProtocol} disconnects because it received a key
991        length prefix which was too large, the L{IBoxReceiver}'s
992        C{stopReceivingBoxes} method is called with a L{TooLong} failure.
993        """
994        protocol = amp.BinaryBoxProtocol(self)
995        protocol.makeConnection(StringTransport())
996        protocol.dataReceived('\x01\x00')
997        protocol.connectionLost(
998            Failure(error.ConnectionDone("simulated connection done")))
999        self.stopReason.trap(amp.TooLong)
1000        self.assertTrue(self.stopReason.value.isKey)
1001        self.assertFalse(self.stopReason.value.isLocal)
1002        self.assertIdentical(self.stopReason.value.value, None)
1003        self.assertIdentical(self.stopReason.value.keyName, None)
1004
1005
1006    def test_unhandledErrorWithTransport(self):
1007        """
1008        L{amp.BinaryBoxProtocol.unhandledError} logs the failure passed to it
1009        and disconnects its transport.
1010        """
1011        transport = StringTransport()
1012        protocol = amp.BinaryBoxProtocol(self)
1013        protocol.makeConnection(transport)
1014        protocol.unhandledError(Failure(RuntimeError("Fake error")))
1015        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
1016        self.assertTrue(transport.disconnecting)
1017
1018
1019    def test_unhandledErrorWithoutTransport(self):
1020        """
1021        L{amp.BinaryBoxProtocol.unhandledError} completes without error when
1022        there is no associated transport.
1023        """
1024        protocol = amp.BinaryBoxProtocol(self)
1025        protocol.makeConnection(StringTransport())
1026        protocol.connectionLost(Failure(Exception("Simulated")))
1027        protocol.unhandledError(Failure(RuntimeError("Fake error")))
1028        self.assertEqual(1, len(self.flushLoggedErrors(RuntimeError)))
1029
1030
1031    def test_receiveBoxData(self):
1032        """
1033        When a binary box protocol receives the serialized form of an AMP box,
1034        it should emit a similar box to its boxReceiver.
1035        """
1036        a = amp.BinaryBoxProtocol(self)
1037        a.dataReceived(amp.Box({"testKey": "valueTest",
1038                                "anotherKey": "anotherValue"}).serialize())
1039        self.assertEqual(self.boxes,
1040                          [amp.Box({"testKey": "valueTest",
1041                                    "anotherKey": "anotherValue"})])
1042
1043
1044    def test_receiveLongerBoxData(self):
1045        """
1046        An L{amp.BinaryBoxProtocol} can receive serialized AMP boxes with
1047        values of up to (2 ** 16 - 1) bytes.
1048        """
1049        length = (2 ** 16 - 1)
1050        value = 'x' * length
1051        transport = StringTransport()
1052        protocol = amp.BinaryBoxProtocol(self)
1053        protocol.makeConnection(transport)
1054        protocol.dataReceived(amp.Box({'k': value}).serialize())
1055        self.assertEqual(self.boxes, [amp.Box({'k': value})])
1056        self.assertFalse(transport.disconnecting)
1057
1058
1059    def test_sendBox(self):
1060        """
1061        When a binary box protocol sends a box, it should emit the serialized
1062        bytes of that box to its transport.
1063        """
1064        a = amp.BinaryBoxProtocol(self)
1065        a.makeConnection(self)
1066        aBox = amp.Box({"testKey": "valueTest",
1067                        "someData": "hello"})
1068        a.makeConnection(self)
1069        a.sendBox(aBox)
1070        self.assertEqual(''.join(self.data), aBox.serialize())
1071
1072
1073    def test_connectionLostStopSendingBoxes(self):
1074        """
1075        When a binary box protocol loses its connection, it should notify its
1076        box receiver that it has stopped receiving boxes.
1077        """
1078        a = amp.BinaryBoxProtocol(self)
1079        a.makeConnection(self)
1080        connectionFailure = Failure(RuntimeError())
1081        a.connectionLost(connectionFailure)
1082        self.assertIdentical(self.stopReason, connectionFailure)
1083
1084
    def test_protocolSwitch(self):
        """
        L{BinaryBoxProtocol} has the capacity to switch to a different protocol
        on a box boundary.  When a protocol is in the process of switching, it
        cannot receive traffic.

        The switch happens inside C{ampBoxReceived}.  The bytes that follow
        the first box in the same C{dataReceived} call must therefore reach
        the new protocol unchanged, not be parsed as further AMP boxes.
        After the switch, later C{dataReceived} calls are forwarded to the
        new protocol, and C{sendBox} raises L{amp.ProtocolSwitched}.
        """
        otherProto = TestProto(None, "outgoing data")
        # Lets the nested receiver class make assertions on this test.
        test = self
        class SwitchyReceiver:
            # Box receiver that switches protocols on its first box and fails
            # the test if a second box is ever delivered to it.
            switched = False
            def startReceivingBoxes(self, sender):
                pass
            def ampBoxReceived(self, box):
                test.assertFalse(self.switched,
                                 "Should only receive one box!")
                self.switched = True
                # `a` is looked up when this method runs, which is after it
                # is assigned below.
                a._lockForSwitch()
                a._switchTo(otherProto)
        a = amp.BinaryBoxProtocol(SwitchyReceiver())
        anyOldBox = amp.Box({"include": "lots",
                             "of": "data"})
        # This test case doubles as the transport (see its write method).
        a.makeConnection(self)
        # Include a 0-length box at the beginning of the next protocol's data,
        # to make sure that AMP doesn't eat the data or try to deliver extra
        # boxes either...
        moreThanOneBox = anyOldBox.serialize() + "\x00\x00Hello, world!"
        a.dataReceived(moreThanOneBox)
        # The new protocol takes over this transport.  TestProto's
        # connectionMade writes its dataToSend, and the test expects that to
        # be the only write.
        self.assertIdentical(otherProto.transport, self)
        self.assertEqual("".join(otherProto.data), "\x00\x00Hello, world!")
        self.assertEqual(self.data, ["outgoing data"])
        a.dataReceived("more data")
        self.assertEqual("".join(otherProto.data),
                          "\x00\x00Hello, world!more data")
        self.assertRaises(amp.ProtocolSwitched, a.sendBox, anyOldBox)
1119
1120
1121    def test_protocolSwitchEmptyBuffer(self):
1122        """
1123        After switching to a different protocol, if no extra bytes beyond
1124        the switch box were delivered, an empty string is not passed to the
1125        switched protocol's C{dataReceived} method.
1126        """
1127        a = amp.BinaryBoxProtocol(self)
1128        a.makeConnection(self)
1129        otherProto = TestProto(None, "")
1130        a._switchTo(otherProto)
1131        self.assertEqual(otherProto.data, [])
1132
1133
1134    def test_protocolSwitchInvalidStates(self):
1135        """
1136        In order to make sure the protocol never gets any invalid data sent
1137        into the middle of a box, it must be locked for switching before it is
1138        switched.  It can only be unlocked if the switch failed, and attempting
1139        to send a box while it is locked should raise an exception.
1140        """
1141        a = amp.BinaryBoxProtocol(self)
1142        a.makeConnection(self)
1143        sampleBox = amp.Box({"some": "data"})
1144        a._lockForSwitch()
1145        self.assertRaises(amp.ProtocolSwitched, a.sendBox, sampleBox)
1146        a._unlockFromSwitch()
1147        a.sendBox(sampleBox)
1148        self.assertEqual(''.join(self.data), sampleBox.serialize())
1149        a._lockForSwitch()
1150        otherProto = TestProto(None, "outgoing data")
1151        a._switchTo(otherProto)
1152        self.assertRaises(amp.ProtocolSwitched, a._unlockFromSwitch)
1153
1154
1155    def test_protocolSwitchLoseConnection(self):
1156        """
1157        When the protocol is switched, it should notify its nested protocol of
1158        disconnection.
1159        """
1160        class Loser(protocol.Protocol):
1161            reason = None
1162            def connectionLost(self, reason):
1163                self.reason = reason
1164        connectionLoser = Loser()
1165        a = amp.BinaryBoxProtocol(self)
1166        a.makeConnection(self)
1167        a._lockForSwitch()
1168        a._switchTo(connectionLoser)
1169        connectionFailure = Failure(RuntimeError())
1170        a.connectionLost(connectionFailure)
1171        self.assertEqual(connectionLoser.reason, connectionFailure)
1172
1173
1174    def test_protocolSwitchLoseClientConnection(self):
1175        """
1176        When the protocol is switched, it should notify its nested client
1177        protocol factory of disconnection.
1178        """
1179        class ClientLoser:
1180            reason = None
1181            def clientConnectionLost(self, connector, reason):
1182                self.reason = reason
1183        a = amp.BinaryBoxProtocol(self)
1184        connectionLoser = protocol.Protocol()
1185        clientLoser = ClientLoser()
1186        a.makeConnection(self)
1187        a._lockForSwitch()
1188        a._switchTo(connectionLoser, clientLoser)
1189        connectionFailure = Failure(RuntimeError())
1190        a.connectionLost(connectionFailure)
1191        self.assertEqual(clientLoser.reason, connectionFailure)
1192
1193
1194
1195class AMPTest(unittest.TestCase):
1196
1197    def test_interfaceDeclarations(self):
1198        """
1199        The classes in the amp module ought to implement the interfaces that
1200        are declared for their benefit.
1201        """
1202        for interface, implementation in [(amp.IBoxSender, amp.BinaryBoxProtocol),
1203                                          (amp.IBoxReceiver, amp.BoxDispatcher),
1204                                          (amp.IResponderLocator, amp.CommandLocator),
1205                                          (amp.IResponderLocator, amp.SimpleStringLocator),
1206                                          (amp.IBoxSender, amp.AMP),
1207                                          (amp.IBoxReceiver, amp.AMP),
1208                                          (amp.IResponderLocator, amp.AMP)]:
1209            self.failUnless(interface.implementedBy(implementation),
1210                            "%s does not implements(%s)" % (implementation, interface))
1211
1212
1213    def test_helloWorld(self):
1214        """
1215        Verify that a simple command can be sent and its response received with
1216        the simple low-level string-based API.
1217        """
1218        c, s, p = connectedServerAndClient()
1219        L = []
1220        HELLO = 'world'
1221        c.sendHello(HELLO).addCallback(L.append)
1222        p.flush()
1223        self.assertEqual(L[0]['hello'], HELLO)
1224
1225
1226    def test_wireFormatRoundTrip(self):
1227        """
1228        Verify that mixed-case, underscored and dashed arguments are mapped to
1229        their python names properly.
1230        """
1231        c, s, p = connectedServerAndClient()
1232        L = []
1233        HELLO = 'world'
1234        c.sendHello(HELLO).addCallback(L.append)
1235        p.flush()
1236        self.assertEqual(L[0]['hello'], HELLO)
1237
1238
1239    def test_helloWorldUnicode(self):
1240        """
1241        Verify that unicode arguments can be encoded and decoded.
1242        """
1243        c, s, p = connectedServerAndClient(
1244            ServerClass=SimpleSymmetricCommandProtocol,
1245            ClientClass=SimpleSymmetricCommandProtocol)
1246        L = []
1247        HELLO = 'world'
1248        HELLO_UNICODE = 'wor\u1234ld'
1249        c.sendUnicodeHello(HELLO, HELLO_UNICODE).addCallback(L.append)
1250        p.flush()
1251        self.assertEqual(L[0]['hello'], HELLO)
1252        self.assertEqual(L[0]['Print'], HELLO_UNICODE)
1253
1254
1255    def test_callRemoteStringRequiresAnswerFalse(self):
1256        """
1257        L{BoxDispatcher.callRemoteString} returns C{None} if C{requiresAnswer}
1258        is C{False}.
1259        """
1260        c, s, p = connectedServerAndClient()
1261        ret = c.callRemoteString("WTF", requiresAnswer=False)
1262        self.assertIdentical(ret, None)
1263
1264
1265    def test_unknownCommandLow(self):
1266        """
1267        Verify that unknown commands using low-level APIs will be rejected with an
1268        error, but will NOT terminate the connection.
1269        """
1270        c, s, p = connectedServerAndClient()
1271        L = []
1272        def clearAndAdd(e):
1273            """
1274            You can't propagate the error...
1275            """
1276            e.trap(amp.UnhandledCommand)
1277            return "OK"
1278        c.callRemoteString("WTF").addErrback(clearAndAdd).addCallback(L.append)
1279        p.flush()
1280        self.assertEqual(L.pop(), "OK")
1281        HELLO = 'world'
1282        c.sendHello(HELLO).addCallback(L.append)
1283        p.flush()
1284        self.assertEqual(L[0]['hello'], HELLO)
1285
1286
1287    def test_unknownCommandHigh(self):
1288        """
1289        Verify that unknown commands using high-level APIs will be rejected with an
1290        error, but will NOT terminate the connection.
1291        """
1292        c, s, p = connectedServerAndClient()
1293        L = []
1294        def clearAndAdd(e):
1295            """
1296            You can't propagate the error...
1297            """
1298            e.trap(amp.UnhandledCommand)
1299            return "OK"
1300        c.callRemote(WTF).addErrback(clearAndAdd).addCallback(L.append)
1301        p.flush()
1302        self.assertEqual(L.pop(), "OK")
1303        HELLO = 'world'
1304        c.sendHello(HELLO).addCallback(L.append)
1305        p.flush()
1306        self.assertEqual(L[0]['hello'], HELLO)
1307
1308
1309    def test_brokenReturnValue(self):
1310        """
1311        It can be very confusing if you write some code which responds to a
1312        command, but gets the return value wrong.  Most commonly you end up
1313        returning None instead of a dictionary.
1314
1315        Verify that if that happens, the framework logs a useful error.
1316        """
1317        L = []
1318        SimpleSymmetricCommandProtocol().dispatchCommand(
1319            amp.AmpBox(_command=BrokenReturn.commandName)).addErrback(L.append)
1320        L[0].trap(amp.BadLocalReturn)
1321        self.failUnlessIn('None', repr(L[0].value))
1322
1323
1324    def test_unknownArgument(self):
1325        """
1326        Verify that unknown arguments are ignored, and not passed to a Python
1327        function which can't accept them.
1328        """
1329        c, s, p = connectedServerAndClient(
1330            ServerClass=SimpleSymmetricCommandProtocol,
1331            ClientClass=SimpleSymmetricCommandProtocol)
1332        L = []
1333        HELLO = 'world'
1334        # c.sendHello(HELLO).addCallback(L.append)
1335        c.callRemote(FutureHello,
1336                     hello=HELLO,
1337                     bonus="I'm not in the book!").addCallback(
1338            L.append)
1339        p.flush()
1340        self.assertEqual(L[0]['hello'], HELLO)
1341
1342
1343    def test_simpleReprs(self):
1344        """
1345        Verify that the various Box objects repr properly, for debugging.
1346        """
1347        self.assertEqual(type(repr(amp._SwitchBox('a'))), str)
1348        self.assertEqual(type(repr(amp.QuitBox())), str)
1349        self.assertEqual(type(repr(amp.AmpBox())), str)
1350        self.failUnless("AmpBox" in repr(amp.AmpBox()))
1351
1352
1353    def test_innerProtocolInRepr(self):
1354        """
1355        Verify that L{AMP} objects output their innerProtocol when set.
1356        """
1357        otherProto = TestProto(None, "outgoing data")
1358        a = amp.AMP()
1359        a.innerProtocol = otherProto
1360
1361        self.assertEqual(
1362            repr(a), "<AMP inner <TestProto #%d> at 0x%x>" % (
1363                otherProto.instanceId, id(a)))
1364
1365
1366    def test_innerProtocolNotInRepr(self):
1367        """
1368        Verify that L{AMP} objects do not output 'inner' when no innerProtocol
1369        is set.
1370        """
1371        a = amp.AMP()
1372        self.assertEqual(repr(a), "<AMP at 0x%x>" % (id(a),))
1373
1374
1375    def test_simpleSSLRepr(self):
1376        """
1377        L{amp._TLSBox.__repr__} returns a string.
1378        """
1379        self.assertEqual(type(repr(amp._TLSBox())), str)
1380
1381    test_simpleSSLRepr.skip = skipSSL
1382
1383
1384    def test_keyTooLong(self):
1385        """
1386        Verify that a key that is too long will immediately raise a synchronous
1387        exception.
1388        """
1389        c, s, p = connectedServerAndClient()
1390        x = "H" * (0xff+1)
1391        tl = self.assertRaises(amp.TooLong,
1392                               c.callRemoteString, "Hello",
1393                               **{x: "hi"})
1394        self.assertTrue(tl.isKey)
1395        self.assertTrue(tl.isLocal)
1396        self.assertIdentical(tl.keyName, None)
1397        self.assertEqual(tl.value, x)
1398        self.assertIn(str(len(x)), repr(tl))
1399        self.assertIn("key", repr(tl))
1400
1401
1402    def test_valueTooLong(self):
1403        """
1404        Verify that attempting to send value longer than 64k will immediately
1405        raise an exception.
1406        """
1407        c, s, p = connectedServerAndClient()
1408        x = "H" * (0xffff+1)
1409        tl = self.assertRaises(amp.TooLong, c.sendHello, x)
1410        p.flush()
1411        self.failIf(tl.isKey)
1412        self.failUnless(tl.isLocal)
1413        self.assertEqual(tl.keyName, 'hello')
1414        self.failUnlessIdentical(tl.value, x)
1415        self.failUnless(str(len(x)) in repr(tl))
1416        self.failUnless("value" in repr(tl))
1417        self.failUnless('hello' in repr(tl))
1418
1419
1420    def test_helloWorldCommand(self):
1421        """
1422        Verify that a simple command can be sent and its response received with
1423        the high-level value parsing API.
1424        """
1425        c, s, p = connectedServerAndClient(
1426            ServerClass=SimpleSymmetricCommandProtocol,
1427            ClientClass=SimpleSymmetricCommandProtocol)
1428        L = []
1429        HELLO = 'world'
1430        c.sendHello(HELLO).addCallback(L.append)
1431        p.flush()
1432        self.assertEqual(L[0]['hello'], HELLO)
1433
1434
1435    def test_helloErrorHandling(self):
1436        """
1437        Verify that if a known error type is raised and handled, it will be
1438        properly relayed to the other end of the connection and translated into
1439        an exception, and no error will be logged.
1440        """
1441        L=[]
1442        c, s, p = connectedServerAndClient(
1443            ServerClass=SimpleSymmetricCommandProtocol,
1444            ClientClass=SimpleSymmetricCommandProtocol)
1445        HELLO = 'fuck you'
1446        c.sendHello(HELLO).addErrback(L.append)
1447        p.flush()
1448        L[0].trap(UnfriendlyGreeting)
1449        self.assertEqual(str(L[0].value), "Don't be a dick.")
1450
1451
1452    def test_helloFatalErrorHandling(self):
1453        """
1454        Verify that if a known, fatal error type is raised and handled, it will
1455        be properly relayed to the other end of the connection and translated
1456        into an exception, no error will be logged, and the connection will be
1457        terminated.
1458        """
1459        L=[]
1460        c, s, p = connectedServerAndClient(
1461            ServerClass=SimpleSymmetricCommandProtocol,
1462            ClientClass=SimpleSymmetricCommandProtocol)
1463        HELLO = 'die'
1464        c.sendHello(HELLO).addErrback(L.append)
1465        p.flush()
1466        L.pop().trap(DeathThreat)
1467        c.sendHello(HELLO).addErrback(L.append)
1468        p.flush()
1469        L.pop().trap(error.ConnectionDone)
1470
1471
1472
1473    def test_helloNoErrorHandling(self):
1474        """
1475        Verify that if an unknown error type is raised, it will be relayed to
1476        the other end of the connection and translated into an exception, it
1477        will be logged, and then the connection will be dropped.
1478        """
1479        L=[]
1480        c, s, p = connectedServerAndClient(
1481            ServerClass=SimpleSymmetricCommandProtocol,
1482            ClientClass=SimpleSymmetricCommandProtocol)
1483        HELLO = THING_I_DONT_UNDERSTAND
1484        c.sendHello(HELLO).addErrback(L.append)
1485        p.flush()
1486        ure = L.pop()
1487        ure.trap(amp.UnknownRemoteError)
1488        c.sendHello(HELLO).addErrback(L.append)
1489        cl = L.pop()
1490        cl.trap(error.ConnectionDone)
1491        # The exception should have been logged.
1492        self.failUnless(self.flushLoggedErrors(ThingIDontUnderstandError))
1493
1494
1495
1496    def test_lateAnswer(self):
1497        """
1498        Verify that a command that does not get answered until after the
1499        connection terminates will not cause any errors.
1500        """
1501        c, s, p = connectedServerAndClient(
1502            ServerClass=SimpleSymmetricCommandProtocol,
1503            ClientClass=SimpleSymmetricCommandProtocol)
1504        L = []
1505        c.callRemote(WaitForever).addErrback(L.append)
1506        p.flush()
1507        self.assertEqual(L, [])
1508        s.transport.loseConnection()
1509        p.flush()
1510        L.pop().trap(error.ConnectionDone)
1511        # Just make sure that it doesn't error...
1512        s.waiting.callback({})
1513        return s.waiting
1514
1515
1516    def test_requiresNoAnswer(self):
1517        """
1518        Verify that a command that requires no answer is run.
1519        """
1520        c, s, p = connectedServerAndClient(
1521            ServerClass=SimpleSymmetricCommandProtocol,
1522            ClientClass=SimpleSymmetricCommandProtocol)
1523        HELLO = 'world'
1524        c.callRemote(NoAnswerHello, hello=HELLO)
1525        p.flush()
1526        self.failUnless(s.greeted)
1527
1528
1529    def test_requiresNoAnswerFail(self):
1530        """
1531        Verify that commands sent after a failed no-answer request do not complete.
1532        """
1533        L=[]
1534        c, s, p = connectedServerAndClient(
1535            ServerClass=SimpleSymmetricCommandProtocol,
1536            ClientClass=SimpleSymmetricCommandProtocol)
1537        HELLO = 'fuck you'
1538        c.callRemote(NoAnswerHello, hello=HELLO)
1539        p.flush()
1540        # This should be logged locally.
1541        self.failUnless(self.flushLoggedErrors(amp.RemoteAmpError))
1542        HELLO = 'world'
1543        c.callRemote(Hello, hello=HELLO).addErrback(L.append)
1544        p.flush()
1545        L.pop().trap(error.ConnectionDone)
1546        self.failIf(s.greeted)
1547
1548
1549    def test_noAnswerResponderBadAnswer(self):
1550        """
1551        Verify that responders of requiresAnswer=False commands have to return
1552        a dictionary anyway.
1553
1554        (requiresAnswer is a hint from the _client_ - the server may be called
1555        upon to answer commands in any case, if the client wants to know when
1556        they complete.)
1557        """
1558        c, s, p = connectedServerAndClient(
1559            ServerClass=BadNoAnswerCommandProtocol,
1560            ClientClass=SimpleSymmetricCommandProtocol)
1561        c.callRemote(NoAnswerHello, hello="hello")
1562        p.flush()
1563        le = self.flushLoggedErrors(amp.BadLocalReturn)
1564        self.assertEqual(len(le), 1)
1565
1566
1567    def test_noAnswerResponderAskedForAnswer(self):
1568        """
1569        Verify that responders with requiresAnswer=False will actually respond
1570        if the client sets requiresAnswer=True.  In other words, verify that
1571        requiresAnswer is a hint honored only by the client.
1572        """
1573        c, s, p = connectedServerAndClient(
1574            ServerClass=NoAnswerCommandProtocol,
1575            ClientClass=SimpleSymmetricCommandProtocol)
1576        L = []
1577        c.callRemote(Hello, hello="Hello!").addCallback(L.append)
1578        p.flush()
1579        self.assertEqual(len(L), 1)
1580        self.assertEqual(L, [dict(hello="Hello!-noanswer",
1581                                   Print=None)])  # Optional response argument
1582
1583
1584    def test_ampListCommand(self):
1585        """
1586        Test encoding of an argument that uses the AmpList encoding.
1587        """
1588        c, s, p = connectedServerAndClient(
1589            ServerClass=SimpleSymmetricCommandProtocol,
1590            ClientClass=SimpleSymmetricCommandProtocol)
1591        L = []
1592        c.callRemote(GetList, length=10).addCallback(L.append)
1593        p.flush()
1594        values = L.pop().get('body')
1595        self.assertEqual(values, [{'x': 1}] * 10)
1596
1597
1598    def test_optionalAmpListOmitted(self):
1599        """
1600        Sending a command with an omitted AmpList argument that is
1601        designated as optional does not raise an InvalidSignature error.
1602        """
1603        c, s, p = connectedServerAndClient(
1604            ServerClass=SimpleSymmetricCommandProtocol,
1605            ClientClass=SimpleSymmetricCommandProtocol)
1606        L = []
1607        c.callRemote(DontRejectMe, magicWord=u'please').addCallback(L.append)
1608        p.flush()
1609        response = L.pop().get('response')
1610        self.assertEqual(response, 'list omitted')
1611
1612
1613    def test_optionalAmpListPresent(self):
1614        """
1615        Sanity check that optional AmpList arguments are processed normally.
1616        """
1617        c, s, p = connectedServerAndClient(
1618            ServerClass=SimpleSymmetricCommandProtocol,
1619            ClientClass=SimpleSymmetricCommandProtocol)
1620        L = []
1621        c.callRemote(DontRejectMe, magicWord=u'please',
1622                list=[{'name': 'foo'}]).addCallback(L.append)
1623        p.flush()
1624        response = L.pop().get('response')
1625        self.assertEqual(response, 'foo accepted')
1626
1627
1628    def test_failEarlyOnArgSending(self):
1629        """
1630        Verify that if we pass an invalid argument list (omitting an argument),
1631        an exception will be raised.
1632        """
1633        self.assertRaises(amp.InvalidSignature, Hello)
1634
1635
1636    def test_doubleProtocolSwitch(self):
1637        """
1638        As a debugging aid, a protocol system should raise a
1639        L{ProtocolSwitched} exception when asked to switch a protocol that is
1640        already switched.
1641        """
1642        serverDeferred = defer.Deferred()
1643        serverProto = SimpleSymmetricCommandProtocol(serverDeferred)
1644        clientDeferred = defer.Deferred()
1645        clientProto = SimpleSymmetricCommandProtocol(clientDeferred)
1646        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1647                                           ClientClass=lambda: clientProto)
1648        def switched(result):
1649            self.assertRaises(amp.ProtocolSwitched, c.switchToTestProtocol)
1650            self.testSucceeded = True
1651        c.switchToTestProtocol().addCallback(switched)
1652        p.flush()
1653        self.failUnless(self.testSucceeded)
1654
1655
    def test_protocolSwitch(self, switcher=SimpleSymmetricCommandProtocol,
                            spuriousTraffic=False,
                            spuriousError=False):
        """
        Verify that it is possible to switch to another protocol mid-connection and
        send data to it successfully.

        Other tests in this class also call this method as a parameterized
        helper.

        @param switcher: factory called with a L{Deferred} to build both the
            server and the client protocol.  NOTE(review): presumably that
            Deferred fires with the data the switched-to protocol received
            once its connection is lost -- confirm against the switcher
            classes.
        @param spuriousTraffic: if true, start a C{WaitForever} command before
            switching, and check that sending C{Hello} after the switch was
            requested raises L{amp.ProtocolSwitched}.  The pending command is
            answered (via C{s.waiting}) only after the switch has finished.
        @param spuriousError: only used with C{spuriousTraffic}.  If true, the
            pending command is answered with an L{amp.RemoteAmpError} instead
            of a successful result.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(clientDeferred)
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
                                           ClientClass=lambda: clientProto)

        if spuriousTraffic:
            # Collects the errback result of the outstanding WaitForever
            # call; it is gathered but never asserted on.
            wfdr = []           # remote
            c.callRemote(WaitForever).addErrback(wfdr.append)
        switchDeferred = c.switchToTestProtocol()
        if spuriousTraffic:
            # After a switch has been requested, new AMP commands are refused.
            self.assertRaises(amp.ProtocolSwitched, c.sendHello, 'world')

        # DeferredList fires with (success, result) pairs, in the order
        # [serverDeferred, clientDeferred].  Each side's switched-to protocol
        # must have received exactly what the other side sent.
        def cbConnsLost(((serverSuccess, serverData),
                         (clientSuccess, clientData))):
            self.failUnless(serverSuccess)
            self.failUnless(clientSuccess)
            self.assertEqual(''.join(serverData), SWITCH_CLIENT_DATA)
            self.assertEqual(''.join(clientData), SWITCH_SERVER_DATA)
            self.testSucceeded = True

        def cbSwitch(proto):
            return defer.DeferredList(
                [serverDeferred, clientDeferred]).addCallback(cbConnsLost)

        switchDeferred.addCallback(cbSwitch)
        p.flush()
        # Some switchers complete the switch later through maybeLater.
        # NOTE(review): presumably this is the Deferred-based switcher --
        # confirm.  Fire it here so the switch can finish.
        if serverProto.maybeLater is not None:
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
            p.flush()
        if spuriousTraffic:
            # switch is done here; do this here to make sure that if we're
            # going to corrupt the connection, we do it before it's closed.
            if spuriousError:
                s.waiting.errback(amp.RemoteAmpError(
                        "SPURIOUS",
                        "Here's some traffic in the form of an error."))
            else:
                s.waiting.callback({})
            p.flush()
        c.transport.loseConnection() # close it
        p.flush()
        self.failUnless(self.testSucceeded)
1709
1710
1711    def test_protocolSwitchDeferred(self):
1712        """
1713        Verify that protocol-switching even works if the value returned from
1714        the command that does the switch is deferred.
1715        """
1716        return self.test_protocolSwitch(switcher=DeferredSymmetricCommandProtocol)
1717
1718
1719    def test_protocolSwitchFail(self, switcher=SimpleSymmetricCommandProtocol):
1720        """
1721        Verify that if we try to switch protocols and it fails, the connection
1722        stays up and we can go back to speaking AMP.
1723        """
1724        self.testSucceeded = False
1725
1726        serverDeferred = defer.Deferred()
1727        serverProto = switcher(serverDeferred)
1728        clientDeferred = defer.Deferred()
1729        clientProto = switcher(clientDeferred)
1730        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
1731                                           ClientClass=lambda: clientProto)
1732        L = []
1733        c.switchToTestProtocol(fail=True).addErrback(L.append)
1734        p.flush()
1735        L.pop().trap(UnknownProtocol)
1736        self.failIf(self.testSucceeded)
1737        # It's a known error, so let's send a "hello" on the same connection;
1738        # it should work.
1739        c.sendHello('world').addCallback(L.append)
1740        p.flush()
1741        self.assertEqual(L.pop()['hello'], 'world')
1742
1743
1744    def test_trafficAfterSwitch(self):
1745        """
1746        Verify that attempts to send traffic after a switch will not corrupt
1747        the nested protocol.
1748        """
1749        return self.test_protocolSwitch(spuriousTraffic=True)
1750
1751
1752    def test_errorAfterSwitch(self):
1753        """
1754        Returning an error after a protocol switch should record the underlying
1755        error.
1756        """
1757        return self.test_protocolSwitch(spuriousTraffic=True,
1758                                        spuriousError=True)
1759
1760
1761    def test_quitBoxQuits(self):
1762        """
1763        Verify that commands with a responseType of QuitBox will in fact
1764        terminate the connection.
1765        """
1766        c, s, p = connectedServerAndClient(
1767            ServerClass=SimpleSymmetricCommandProtocol,
1768            ClientClass=SimpleSymmetricCommandProtocol)
1769
1770        L = []
1771        HELLO = 'world'
1772        GOODBYE = 'everyone'
1773        c.sendHello(HELLO).addCallback(L.append)
1774        p.flush()
1775        self.assertEqual(L.pop()['hello'], HELLO)
1776        c.callRemote(Goodbye).addCallback(L.append)
1777        p.flush()
1778        self.assertEqual(L.pop()['goodbye'], GOODBYE)
1779        c.sendHello(HELLO).addErrback(L.append)
1780        L.pop().trap(error.ConnectionDone)
1781
1782
1783    def test_basicLiteralEmit(self):
1784        """
1785        Verify that the command dictionaries for a callRemoteN look correct
1786        after being serialized and parsed.
1787        """
1788        c, s, p = connectedServerAndClient()
1789        L = []
1790        s.ampBoxReceived = L.append
1791        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1792                     dash_arg='x', underscore_arg='y')
1793        p.flush()
1794        self.assertEqual(len(L), 1)
1795        for k, v in [('_command', Hello.commandName),
1796                     ('hello', 'hello test'),
1797                     ('mixedCase', 'mixed case arg test'),
1798                     ('dash-arg', 'x'),
1799                     ('underscore_arg', 'y')]:
1800            self.assertEqual(L[-1].pop(k), v)
1801        L[-1].pop('_ask')
1802        self.assertEqual(L[-1], {})
1803
1804
1805    def test_basicStructuredEmit(self):
1806        """
1807        Verify that a call similar to basicLiteralEmit's is handled properly with
1808        high-level quoting and passing to Python methods, and that argument
1809        names are correctly handled.
1810        """
1811        L = []
1812        class StructuredHello(amp.AMP):
1813            def h(self, *a, **k):
1814                L.append((a, k))
1815                return dict(hello='aaa')
1816            Hello.responder(h)
1817        c, s, p = connectedServerAndClient(ServerClass=StructuredHello)
1818        c.callRemote(Hello, hello='hello test', mixedCase='mixed case arg test',
1819                     dash_arg='x', underscore_arg='y').addCallback(L.append)
1820        p.flush()
1821        self.assertEqual(len(L), 2)
1822        self.assertEqual(L[0],
1823                          ((), dict(
1824                    hello='hello test',
1825                    mixedCase='mixed case arg test',
1826                    dash_arg='x',
1827                    underscore_arg='y',
1828                    From=s.transport.getPeer(),
1829
1830                    # XXX - should optional arguments just not be passed?
1831                    # passing None seems a little odd, looking at the way it
1832                    # turns out here... -glyph
1833                    Print=None,
1834                    optional=None,
1835                    )))
1836        self.assertEqual(L[1], dict(Print=None, hello='aaa'))
1837
class PretendRemoteCertificateAuthority:
    """
    A fake certificate authority. L{OKCert.options} checks for it by
    calling L{checkIsPretendRemote}.
    """

    def checkIsPretendRemote(self):
        """
        @return: C{True}, unconditionally.
        """
        answer = True
        return answer
1841
class IOSimCert:
    """
    A fake certificate object for use with L{iosim}'s simulated TLS.

    @ivar verifyCount: how many times L{iosimVerify} has been invoked on this
        instance (the class-level default is 0).
    """
    verifyCount = 0

    def options(self, *ign):
        """
        Act as its own options object; any arguments are ignored.

        @return: this certificate.
        """
        return self

    def iosimVerify(self, otherCert):
        """
        This isn't a real certificate, and wouldn't work on a real socket, but
        iosim specifies a different API so that we don't have to do any crypto
        math to demonstrate that the right functions get called in the right
        places.

        Verification only succeeds against this very object; every
        successful call bumps C{verifyCount}.
        """
        assert otherCert is self
        self.verifyCount = self.verifyCount + 1
        return True
1858
class OKCert(IOSimCert):
    """
    An L{IOSimCert} whose C{options} insists on being given an object that
    answers C{checkIsPretendRemote()} truthfully (such as
    L{PretendRemoteCertificateAuthority}).
    """

    def options(self, x):
        """
        @param x: the authority object to check.
        @return: this certificate.
        """
        assert x.checkIsPretendRemote()
        return self
1863
class GrumpyCert(IOSimCert):
    """
    An L{IOSimCert} that counts verification attempts but rejects every one.
    """

    def iosimVerify(self, otherCert):
        """
        Record the attempt and refuse verification.

        @return: C{False}, always.
        """
        self.verifyCount = self.verifyCount + 1
        return False
1868
class DroppyCert(IOSimCert):
    """
    An L{IOSimCert} that drops a given connection as a side effect of every
    verification.

    @ivar toDrop: the object whose C{loseConnection} is called during
        verification (a transport, in the tests below).
    """

    def __init__(self, toDrop):
        self.toDrop = toDrop

    def iosimVerify(self, otherCert):
        """
        Count the attempt, close C{toDrop}, and then report success anyway.
        """
        self.verifyCount = self.verifyCount + 1
        self.toDrop.loseConnection()
        return True
1877
class SecurableProto(FactoryNotifier):
    """
    A protocol that responds to L{amp.StartTLS}. The local certificate comes
    from C{self.certFactory()}, which tests must assign. The trusted
    authorities come from L{verifyFactory}.
    """

    factory = None

    def verifyFactory(self):
        """
        @return: a one-element list holding a fresh
            L{PretendRemoteCertificateAuthority}.
        """
        return [PretendRemoteCertificateAuthority()]

    def getTLSVars(self):
        """
        Build the StartTLS response from C{certFactory} and
        L{verifyFactory}, called in that order.
        """
        return {
            'tls_localCertificate': self.certFactory(),
            'tls_verifyAuthorities': self.verifyFactory(),
        }
    amp.StartTLS.responder(getTLSVars)
1892
1893
1894
class TLSTest(unittest.TestCase):
    """
    Tests for starting TLS on an AMP connection. They run over L{iosim}
    with the fake certificates defined above instead of a real socket.
    """

    def test_startingTLS(self):
        """
        Verify that starting TLS and succeeding at handshaking sends all the
        notifications to all the right places.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)

        okc = OKCert()
        svr.certFactory = lambda : okc

        cli.callRemote(
            amp.StartTLS, tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])

        # let's buffer something to be delivered securely
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        # once for client once for server
        self.assertEqual(okc.verifyCount, 2)
        # NOTE: the buffered ping's result is not checked; only a second,
        # post-handshake ping's result is asserted below.
        L = []
        cli.callRemote(SecuredPing).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0], {'pinged': True})


    def test_startTooManyTimes(self):
        """
        Verify that the protocol will complain if we attempt to renegotiate TLS,
        which we don't support.
        """
        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)

        okc = OKCert()
        svr.certFactory = lambda : okc

        cli.callRemote(amp.StartTLS,
                       tls_localCertificate=okc,
                       tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
        p.flush()
        cli.noPeerCertificate = True # this is totally fake
        self.assertRaises(
            amp.OnlyOneTLS,
            cli.callRemote,
            amp.StartTLS,
            tls_localCertificate=okc,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])


    def test_negotiationFailed(self):
        """
        Verify that starting TLS and failing on both sides at handshaking sends
        notifications to all the right places and terminates the connection.
        """

        badCert = GrumpyCert()

        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)
        svr.certFactory = lambda : badCert

        cli.callRemote(amp.StartTLS,
                       tls_localCertificate=badCert)

        p.flush()
        # once for client once for server - but both fail
        self.assertEqual(badCert.verifyCount, 2)
        d = cli.callRemote(SecuredPing)
        p.flush()
        # Return the assertion Deferred so trial checks its outcome. If it
        # were discarded, a wrong (or missing) failure would never be
        # reported as a test failure.
        return self.assertFailure(d, iosim.NativeOpenSSLError)


    def test_negotiationFailedByClosing(self):
        """
        Verify that starting TLS and failing by way of a lost connection
        notices that it is probably an SSL problem.
        """

        cli, svr, p = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)
        droppyCert = DroppyCert(svr.transport)
        svr.certFactory = lambda : droppyCert

        cli.callRemote(amp.StartTLS, tls_localCertificate=droppyCert)

        p.flush()

        self.assertEqual(droppyCert.verifyCount, 2)

        d = cli.callRemote(SecuredPing)
        p.flush()

        # it might be a good idea to move this exception somewhere more
        # reasonable.
        # Returned so that trial actually verifies the failure type.
        return self.assertFailure(d, error.PeerVerifyError)

    skip = skipSSL
1999
2000
2001
class TLSNotAvailableTest(unittest.TestCase):
    """
    Tests what happens when SSL support is missing from the installation,
    simulated by setting C{amp.ssl} to C{None}.
    """

    def setUp(self):
        """
        Stash the real C{amp.ssl} and replace it with C{None}.
        """
        self.ssl, amp.ssl = amp.ssl, None


    def tearDown(self):
        """
        Put the stashed C{amp.ssl} back.
        """
        amp.ssl = self.ssl


    def test_callRemoteError(self):
        """
        Without SSL, calling L{amp.StartTLS} remotely must fail with
        L{RuntimeError}.
        """
        client, server, _pump = connectedServerAndClient(
            ServerClass=SecurableProto,
            ClientClass=SecurableProto)

        cert = OKCert()
        server.certFactory = lambda : cert

        pending = client.callRemote(
            amp.StartTLS, tls_localCertificate=cert,
            tls_verifyAuthorities=[PretendRemoteCertificateAuthority()])
        return self.assertFailure(pending, RuntimeError)


    def test_messageReceivedError(self):
        """
        A server without SSL that receives a StartTLS request must answer
        with a C{TLS_ERROR} error box instead of attempting TLS.
        """
        server = SecurableProto()
        cert = OKCert()
        server.certFactory = lambda : cert

        request = amp.Box()
        request['_command'] = 'StartTLS'
        request['_ask'] = '1'

        sent = []
        server.sendBox = sent.append
        server.makeConnection(StringTransport())
        server.ampBoxReceived(request)

        expected = {'_error_code': 'TLS_ERROR',
                    '_error': '1',
                    '_error_description': 'TLS not available'}
        self.assertEqual(sent, [expected])
2059
2060
2061
class InheritedError(Exception):
    """
    An error declared on L{BaseCommand}. The tests use it to check that
    subclass commands inherit their base's error mapping.
    """
2066
2067
2068
class OtherInheritedError(Exception):
    """
    A second, unrelated error. L{AddErrorsCommand} adds it on top of the
    errors it inherits, so the tests can check that subclass-specific
    errors also work.
    """
2073
2074
2075
class BaseCommand(amp.Command):
    """
    A command meant to be subclassed. It maps L{InheritedError} to the
    wire code C{INHERITED_ERROR}.
    """
    errors = {
        InheritedError: 'INHERITED_ERROR',
    }
2081
2082
2083
class InheritedCommand(BaseCommand):
    """
    A subclass of L{BaseCommand} that overrides nothing. Everything,
    including the error mapping, comes from the base class.
    """
2089
2090
2091
class AddErrorsCommand(BaseCommand):
    """
    A subclass of L{BaseCommand} that takes a boolean C{other} argument and
    declares an extra error, L{OtherInheritedError}, alongside the inherited
    ones.
    """
    arguments = [
        ('other', amp.Boolean()),
    ]
    errors = {
        OtherInheritedError: 'OTHER_INHERITED_ERROR',
    }
2099
2100
2101
class NormalCommandProtocol(amp.AMP):
    """
    Responds to L{BaseCommand} by raising L{InheritedError}. This shows that
    having subclasses does not disturb the base command's own error handling.
    """

    def resp(self):
        """
        Always fail with L{InheritedError}.
        """
        raise InheritedError()
    BaseCommand.responder(resp)
2110
2111
2112
class InheritedCommandProtocol(amp.AMP):
    """
    Responds to L{InheritedCommand} by raising L{InheritedError}. This shows
    that a subclass command without errors of its own picks up its base's
    error mapping.
    """

    def resp(self):
        """
        Always fail with L{InheritedError}.
        """
        raise InheritedError()
    InheritedCommand.responder(resp)
2122
2123
2124
class AddedCommandProtocol(amp.AMP):
    """
    Responds to L{AddErrorsCommand}. It raises the subclass-only error when
    C{other} is true and the inherited error otherwise. This covers both the
    newly declared error and the inherited one.
    """

    def resp(self, other):
        """
        @param other: when true, raise L{OtherInheritedError}; otherwise
            raise L{InheritedError}.
        """
        errorClass = OtherInheritedError if other else InheritedError
        raise errorClass()
    AddErrorsCommand.responder(resp)
2137
2138
2139
class CommandInheritanceTests(unittest.TestCase):
    """
    Check that commands inherit their error declarations correctly.
    """

    def errorCheck(self, err, proto, cmd, **kw):
        """
        Connect two C{proto} instances and send C{cmd} (with keyword
        arguments C{kw}) from client to server.

        @return: a Deferred that succeeds only if the call fails with C{err}.
        """
        client, _server, pump = connectedServerAndClient(ServerClass=proto,
                                                         ClientClass=proto)
        checked = self.assertFailure(client.callRemote(cmd, **kw), err)
        pump.flush()
        return checked


    def test_basicErrorPropagation(self):
        """
        A superclass command's declared errors still work normally, even
        though the command has subclasses.
        """
        return self.errorCheck(InheritedError,
                               NormalCommandProtocol,
                               BaseCommand)


    def test_inheritedErrorPropagation(self):
        """
        Errors declared on a superclass command propagate to its subclasses.
        """
        return self.errorCheck(InheritedError,
                               InheritedCommandProtocol,
                               InheritedCommand)


    def test_inheritedErrorAddition(self):
        """
        A subclass command can declare new errors of its own even when its
        superclass already declares some.
        """
        return self.errorCheck(OtherInheritedError,
                               AddedCommandProtocol,
                               AddErrorsCommand,
                               other=True)


    def test_additionWithOriginalError(self):
        """
        A subclass command that declares new errors still honors the errors
        declared on its superclass.
        """
        return self.errorCheck(InheritedError,
                               AddedCommandProtocol,
                               AddErrorsCommand,
                               other=False)
2192
2193
def _loseAndPass(err, proto):
    """
    An errback that forwards a connection-lost failure to C{proto}'s own
    C{connectionLost}.

    Only L{error.ConnectionLost} and L{error.ConnectionDone} are accepted;
    C{trap} re-raises any other failure. The C{connectionLost} attribute
    set on the instance is then deleted, so the following call reaches
    whatever C{connectionLost} is left (presumably the class's; see
    L{LiveFireBase.tearDown}, which installs the override).
    """
    # be specific, pass on the error to the client.
    err.trap(error.ConnectionLost, error.ConnectionDone)
    delattr(proto, 'connectionLost')
    proto.connectionLost(err)
2199
2200
class LiveFireBase:
    """
    Utility for connected reactor-using tests.

    Mix this into a trial C{TestCase} and set C{serverProto} and
    C{clientProto} to protocol classes. After L{setUp}, C{self.cli} and
    C{self.svr} hold the connected protocol instances. (They are read from
    each factory's C{theProto} once its C{onMade} Deferred fires;
    presumably the protocol classes set these. Confirm against
    FactoryNotifier.)
    """

    def setUp(self):
        """
        Create an amp server and connect a client to it.
        """
        from twisted.internet import reactor

        self.serverFactory = protocol.ServerFactory()
        self.serverFactory.protocol = self.serverProto
        self.clientFactory = protocol.ClientFactory()
        self.clientFactory.protocol = self.clientProto
        self.clientFactory.onMade = defer.Deferred()
        self.serverFactory.onMade = defer.Deferred()

        self.serverPort = reactor.listenTCP(0, self.serverFactory)
        self.addCleanup(self.serverPort.stopListening)
        portNumber = self.serverPort.getHost().port
        self.clientConn = reactor.connectTCP('127.0.0.1', portNumber,
                                             self.clientFactory)
        self.addCleanup(self.clientConn.disconnect)

        def getProtos(ignored):
            self.cli = self.clientFactory.theProto
            self.svr = self.serverFactory.theProto

        bothMade = defer.DeferredList([self.clientFactory.onMade,
                                       self.serverFactory.onMade])
        return bothMade.addCallback(getProtos)

    def tearDown(self):
        """
        Drop both connections and wait until each one reports a clean
        C{connectionLost}. If any side fails, the first failure is unwrapped
        and returned.
        """
        pending = []
        for conn in (self.cli, self.svr):
            if conn.transport is None:
                continue
            # depend on amp's function connection-dropping behavior
            lost = defer.Deferred().addErrback(_loseAndPass, conn)
            conn.connectionLost = lost.errback
            conn.transport.loseConnection()
            pending.append(lost)
        gathered = defer.gatherResults(pending)
        return gathered.addErrback(lambda first: first.value.subFailure)
2245
2246
def show(x):
    """
    Debugging helper: write C{x} and a newline to stdout, then flush.
    """
    import sys
    out = sys.stdout
    out.write(x + '\n')
    out.flush()
2251
2252
def tempSelfSigned():
    """
    Generate a fresh key pair and a self-signed certificate for
    C{CN=shared}, signed with serial number 1234567.

    @return: the new certificate.
    """
    from twisted.internet import ssl

    distinguishedName = ssl.DN(CN='shared')
    keyPair = ssl.KeyPair.generate()
    request = keyPair.certificateRequest(distinguishedName)
    signedRequest = keyPair.signCertificateRequest(
        distinguishedName, request, lambda dn: True, 1234567)
    return keyPair.newCertificate(signedRequest)
2263
# Generate one self-signed certificate at import time for the live-fire TLS
# tests. When SSL is unavailable, `tempcert` is never defined; the classes
# that use it set `skip = skipSSL`, so they are skipped in that case.
if ssl is not None:
    tempcert = tempSelfSigned()
2266
2267
class LiveFireTLSTestCase(LiveFireBase, unittest.TestCase):
    """
    Negotiates TLS for real over a reactor TCP connection, with
    L{SecurableProto} on both ends.
    """
    clientProto = SecurableProto
    serverProto = SecurableProto

    def test_liveFireCustomTLS(self):
        """
        Using real, live TLS, actually negotiate a connection.

        This also looks at the 'peerCertificate' attribute's correctness, since
        that's actually loaded using OpenSSL calls, but the main purpose is to
        make sure that we didn't miss anything obvious in iosim about TLS
        negotiations.
        """
        cert = tempcert

        # only needed on the server, we specify the client below.
        self.svr.verifyFactory = lambda : [cert]
        self.svr.certFactory = lambda : cert

        def secured(ignored):
            expectedDigest = cert.digest()

            def pinged(ignoredAgain):
                # Interesting.  OpenSSL won't even _tell_ us about the peer
                # cert until we negotiate.  we should be able to do this in
                # 'secured' instead, but it looks like we can't.  I think this
                # is a bug somewhere far deeper than here.
                self.assertEqual(expectedDigest,
                                 self.cli.hostCertificate.digest())
                self.assertEqual(expectedDigest,
                                 self.cli.peerCertificate.digest())
                self.assertEqual(expectedDigest,
                                 self.svr.hostCertificate.digest())
                self.assertEqual(expectedDigest,
                                 self.svr.peerCertificate.digest())

            return self.cli.callRemote(SecuredPing).addCallback(pinged)

        started = self.cli.callRemote(amp.StartTLS,
                                      tls_localCertificate=cert,
                                      tls_verifyAuthorities=[cert])
        return started.addCallback(secured)

    skip = skipSSL
2304
2305
2306
class SlightlySmartTLS(SimpleSymmetricCommandProtocol):
    """
    A server-side protocol that answers StartTLS with the module-level
    C{tempcert} as its local certificate and no verification authorities.
    """

    def getTLSVars(self):
        """
        @return: the global C{tempcert} certificate as local certificate.
        """
        return {'tls_localCertificate': tempcert}
    amp.StartTLS.responder(getTLSVars)
2318
2319
class PlainVanillaLiveFire(LiveFireBase, unittest.TestCase):
    """
    Live TLS between two stock L{SimpleSymmetricCommandProtocol} instances;
    neither side supplies a certificate.
    """

    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SimpleSymmetricCommandProtocol

    def test_liveFireDefaultTLS(self):
        """
        Out of the box, TLS can be started to at least encrypt the
        connection, even with no certificates configured.
        """
        started = self.cli.callRemote(amp.StartTLS)
        return started.addCallback(
            lambda ignored: self.cli.callRemote(SecuredPing))

    skip = skipSSL
2335
2336
2337
class WithServerTLSVerification(LiveFireBase, unittest.TestCase):
    """
    Live TLS where the server presents C{tempcert} and an anonymous client
    verifies it.
    """
    clientProto = SimpleSymmetricCommandProtocol
    serverProto = SlightlySmartTLS

    def test_anonymousVerifyingClient(self):
        """
        A client with no certificate of its own can still verify the
        server's certificate.
        """
        started = self.cli.callRemote(amp.StartTLS,
                                      tls_verifyAuthorities=[tempcert])
        return started.addCallback(
            lambda ignored: self.cli.callRemote(SecuredPing))

    skip = skipSSL
2353
2354
2355
class ProtocolIncludingArgument(amp.Argument):
    """
    An L{amp.Argument} whose parsed and serialized forms both record the
    arguments it was given, *including the protocol*. Tests can then check
    exactly what the schema machinery passed in.
    """

    def fromStringProto(self, string, protocol):
        """
        Perform no decoding; hand back everything that was supplied.

        @return: A two-tuple of the input string and the protocol.
        """
        return string, protocol

    def toStringProto(self, obj, protocol):
        """
        Serialize as C{"<id of obj>:<id of protocol>"} so the identities can
        be checked later.

        @type obj: L{object}
        @type protocol: L{amp.AMP}
        """
        return ":".join([str(id(obj)), str(id(protocol))])
2380
2381
2382
class ProtocolIncludingCommand(amp.Command):
    """
    A command with a single C{weird} field, typed as
    L{ProtocolIncludingArgument}, in both its argument and response schemas.
    """
    arguments = [
        ('weird', ProtocolIncludingArgument()),
    ]
    response = [
        ('weird', ProtocolIncludingArgument()),
    ]
2390
2391
2392
class MagicSchemaCommand(amp.Command):
    """
    A command which overrides L{parseResponse}, L{parseArguments}, and
    L{makeArguments}. Each override does no real work: it records what it
    was called with as an attribute on the protocol and passes its input
    through unchanged.
    """
    def parseResponse(cls, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseResponseArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseResponseArguments = (strings, protocol)
        return strings
    parseResponse = classmethod(parseResponse)


    def parseArguments(cls, strings, protocol):
        """
        Don't do any parsing, just jam the input strings and protocol
        onto the C{protocol.parseArgumentsArguments} attribute as a
        two-tuple. Return the original strings.
        """
        protocol.parseArgumentsArguments = (strings, protocol)
        return strings
    parseArguments = classmethod(parseArguments)


    def makeArguments(cls, objects, protocol):
        """
        Don't do any serializing, just jam the input objects and protocol
        onto the C{protocol.makeArgumentsArguments} attribute as a
        two-tuple. Return the original objects.
        """
        protocol.makeArgumentsArguments = (objects, protocol)
        return objects
    makeArguments = classmethod(makeArguments)
2429
2430
2431
class NoNetworkProtocol(amp.AMP):
    """
    An L{amp.AMP} subclass that bypasses the network. C{_sendBoxCommand} is
    overridden to echo the outgoing strings back as the answer. It also
    registers a do-nothing responder for L{MagicSchemaCommand}, so tests
    can examine how L{amp.Command} and L{amp.AMP} interact.

    @ivar parseArgumentsArguments: Arguments that have been passed to any
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.

    @ivar parseResponseArguments: Responses that have been returned from a
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.

    @ivar makeArgumentsArguments: Arguments that have been serialized by any
    L{MagicSchemaCommand}, if L{MagicSchemaCommand} has been handled by
    this protocol.
    """

    def _sendBoxCommand(self, commandName, strings, requiresAnswer):
        """
        Skip the wire entirely: answer immediately with the very strings
        that would have been sent.
        """
        answer = defer.succeed(strings)
        return answer

    MagicSchemaCommand.responder(lambda s, weird: {})
2458
2459
2460
class MyBox(dict):
    """
    A dict subclass with no behavior of its own. It exists only as a
    distinct type, so tests can check which box type a command builds.
    """
2465
2466
2467
class ProtocolIncludingCommandWithDifferentCommandType(
    ProtocolIncludingCommand):
    """
    The same schema as L{ProtocolIncludingCommand}, but boxes are built
    with the custom type L{MyBox}.
    """
    commandType = MyBox
2474
2475
2476
class CommandTestCase(unittest.TestCase):
    """
    Tests for L{amp.Argument} and L{amp.Command}.
    """
    def test_argumentInterface(self):
        """
        L{amp.Argument} instances provide L{amp.IArgumentType}.
        """
        self.assertTrue(verifyObject(amp.IArgumentType, amp.Argument()))


    def test_parseResponse(self):
        """
        There should be a class method of Command which accepts a
        mapping of argument names to serialized forms and returns a
        similar mapping whose values have been parsed via the
        Command's response schema.
        """
        protocol = object()
        result = 'whatever'
        strings = {'weird': result}
        self.assertEqual(
            ProtocolIncludingCommand.parseResponse(strings, protocol),
            {'weird': (result, protocol)})


    def test_callRemoteCallsParseResponse(self):
        """
        Making a remote call on a L{amp.Command} subclass which
        overrides the C{parseResponse} method should call that
        C{parseResponse} method to get the response.
        """
        client = NoNetworkProtocol()
        thingy = "weeoo"
        response = client.callRemote(MagicSchemaCommand, weird=thingy)
        def gotResponse(ign):
            self.assertEqual(client.parseResponseArguments,
                              ({"weird": thingy}, client))
        response.addCallback(gotResponse)
        return response


    def test_parseArguments(self):
        """
        There should be a class method of L{amp.Command} which accepts
        a mapping of argument names to serialized forms and returns a
        similar mapping whose values have been parsed via the
        command's argument schema.
        """
        protocol = object()
        result = 'whatever'
        strings = {'weird': result}
        self.assertEqual(
            ProtocolIncludingCommand.parseArguments(strings, protocol),
            {'weird': (result, protocol)})


    def test_responderCallsParseArguments(self):
        """
        Invoking the responder for a L{amp.Command} subclass which
        overrides the C{parseArguments} method should call that
        C{parseArguments} method to get the arguments.
        """
        protocol = NoNetworkProtocol()
        responder = protocol.locateResponder(MagicSchemaCommand.commandName)
        argument = object()
        response = responder(dict(weird=argument))
        response.addCallback(
            lambda ign: self.assertEqual(protocol.parseArgumentsArguments,
                                         ({"weird": argument}, protocol)))
        return response


    def test_makeArguments(self):
        """
        There should be a class method of L{amp.Command} which accepts
        a mapping of argument names to objects and returns a similar
        mapping whose values have been serialized via the command's
        argument schema.
        """
        protocol = object()
        argument = object()
        objects = {'weird': argument}
        self.assertEqual(
            ProtocolIncludingCommand.makeArguments(objects, protocol),
            {'weird': "%d:%d" % (id(argument), id(protocol))})


    def test_makeArgumentsUsesCommandType(self):
        """
        L{amp.Command.makeArguments}'s return type should be the type
        of the result of L{amp.Command.commandType}.
        """
        protocol = object()
        objects = {"weird": "whatever"}

        result = ProtocolIncludingCommandWithDifferentCommandType.makeArguments(
            objects, protocol)
        self.assertIdentical(type(result), MyBox)


    def test_callRemoteCallsMakeArguments(self):
        """
        Making a remote call on a L{amp.Command} subclass which
        overrides the C{makeArguments} method should call that
        C{makeArguments} method to serialize the arguments.
        """
        client = NoNetworkProtocol()
        argument = object()
        response = client.callRemote(MagicSchemaCommand, weird=argument)
        def gotResponse(ign):
            self.assertEqual(client.makeArgumentsArguments,
                             ({"weird": argument}, client))
        response.addCallback(gotResponse)
        return response


    def test_extraArgumentsDisallowed(self):
        """
        L{Command.makeArguments} raises L{amp.InvalidSignature} if the objects
        dictionary passed to it includes a key which does not correspond to the
        Python identifier for a defined argument.
        """
        self.assertRaises(
            amp.InvalidSignature,
            Hello.makeArguments,
            dict(hello="hello", bogusArgument=object()), None)


    def test_wireSpellingDisallowed(self):
        """
        If a command argument conflicts with a Python keyword, the
        untransformed argument name is not allowed as a key in the dictionary
        passed to L{Command.makeArguments}.  If it is supplied,
        L{amp.InvalidSignature} is raised.

        This may be a pointless implementation restriction which may be lifted.
        The current behavior is tested to verify that such arguments are not
        silently dropped on the floor (the previous behavior).
        """
        self.assertRaises(
            amp.InvalidSignature,
            Hello.makeArguments,
            dict(hello="required", **{"print": "print value"}),
            None)
2622
2623
class ListOfTestsMixin:
    """
    Base class for testing L{ListOf}, a parameterized zero-or-more argument
    type.

    @ivar elementType: Subclasses should set this to an L{Argument}
        instance.  The tests will make a L{ListOf} using this.

    @ivar strings: Subclasses should set this to a dictionary mapping some
        number of keys to the correct serialized form for some example
        values.  These should agree with what L{elementType}
        produces/accepts.

    @ivar objects: Subclasses should set this to a dictionary with the same
        keys as C{strings} and with values which are the lists which should
        serialize to the values in the C{strings} dictionary.
    """
    def test_toBox(self):
        """
        For each key, L{ListOf.toBox} takes the list stored under that key
        in the C{objects} dictionary and serializes every element with the
        wrapped L{Argument}. The combined result is stored under the same
        key in the C{strings} box.
        """
        listArgument = amp.ListOf(self.elementType)
        box = amp.AmpBox()
        for name in self.objects:
            # Each call gets its own copy of the objects mapping.
            listArgument.toBox(name, box, self.objects.copy(), None)
        self.assertEqual(box, self.strings)


    def test_fromBox(self):
        """
        L{ListOf.fromBox} undoes L{ListOf.toBox}: parsing the serialized
        strings recovers the original lists.
        """
        listArgument = amp.ListOf(self.elementType)
        parsed = {}
        for name in self.strings:
            listArgument.fromBox(name, self.strings.copy(), parsed, None)
        self.assertEqual(parsed, self.objects)
2666
2667
2668
class ListOfStringsTests(unittest.TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} with L{amp.String} elements.
    """
    elementType = amp.String()

    # Each element is a two-byte length prefix (high byte first) followed by
    # the raw bytes.
    strings = dict(
        empty="",
        single="\x00\x03foo",
        multiple="\x00\x03bar\x00\x03baz\x00\x04quux")

    objects = dict(
        empty=[],
        single=["foo"],
        multiple=["bar", "baz", "quux"])
2684
2685
class ListOfIntegersTests(unittest.TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} with L{amp.Integer} elements.
    """
    elementType = amp.Integer()

    # A value far outside the range of any machine word, so that arbitrarily
    # large integers are covered.  The \x74 (116) prefix used for it below is
    # the length of its decimal representation.
    huge = (
        9999999999999999999999999999999999999999999999999999999999 *
        9999999999999999999999999999999999999999999999999999999999)

    strings = dict(
        empty="",
        single="\x00\x0210",
        multiple="\x00\x011\x00\x0220\x00\x03500",
        huge="\x00\x74%d" % (huge,),
        negative="\x00\x02-1")

    objects = dict(
        empty=[],
        single=[10],
        multiple=[1, 20, 500],
        huge=[huge],
        negative=[-1])
2709
2710
2711
class ListOfUnicodeTests(unittest.TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} with L{amp.Unicode} elements.
    """
    elementType = amp.Unicode()

    # "\xe2\x98\x83" is the three-byte UTF-8 encoding of SNOWMAN, which is
    # why its length prefix is \x03.
    strings = dict(
        empty="",
        single="\x00\x03foo",
        multiple="\x00\x03\xe2\x98\x83\x00\x05Hello\x00\x05world")

    objects = dict(
        empty=[],
        single=[u"foo"],
        multiple=[u"\N{SNOWMAN}", u"Hello", u"world"])
2727
2728
2729
class ListOfDecimalTests(unittest.TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} with L{amp.Decimal} elements.
    """
    elementType = amp.Decimal()

    strings = dict(
        empty="",
        single="\x00\x031.1",
        extreme="\x00\x08Infinity\x00\x09-Infinity",
        scientist=(
            "\x00\x083.141E+5\x00\x0a0.00003141\x00\x083.141E-7"
            "\x00\x09-3.141E+5\x00\x0b-0.00003141\x00\x09-3.141E-7"),
        # The expected text comes from Decimal.to_eng_string itself, so the
        # \x04 and \x06 prefixes must match the lengths of those strings.
        engineer="\x00\x04%s\x00\x06%s" % (
            decimal.Decimal("0e6").to_eng_string(),
            decimal.Decimal("1.5E-9").to_eng_string()))

    objects = dict(
        empty=[],
        single=[decimal.Decimal("1.1")],
        extreme=[
            decimal.Decimal("Infinity"),
            decimal.Decimal("-Infinity"),
        ],
        # exarkun objected to AMP supporting engineering notation because
        # it was redundant, until we realised that 1E6 has less precision
        # than 1000000 and is represented differently.  But they compare
        # and even hash equally.  There were tears.
        scientist=[
            decimal.Decimal("3.141E5"),
            decimal.Decimal("3.141e-5"),
            decimal.Decimal("3.141E-7"),
            decimal.Decimal("-3.141e5"),
            decimal.Decimal("-3.141E-5"),
            decimal.Decimal("-3.141e-7"),
        ],
        engineer=[
            decimal.Decimal("0e6"),
            decimal.Decimal("1.5E-9"),
        ])
2771
2772
2773
class ListOfDecimalNanTests(unittest.TestCase, ListOfTestsMixin):
    """
    Tests for L{ListOf} combined with L{amp.Decimal} for not-a-number values.
    """
    elementType = amp.Decimal()

    strings = {
        "nan": "\x00\x03NaN\x00\x04-NaN\x00\x04sNaN\x00\x05-sNaN",
    }

    objects = {
        "nan": [
            decimal.Decimal("NaN"),
            decimal.Decimal("-NaN"),
            decimal.Decimal("sNaN"),
            decimal.Decimal("-sNaN"),
        ]
    }

    def test_fromBox(self):
        """
        L{ListOf.fromBox} reverses the operation performed by L{ListOf.toBox}.

        NaN values have unusual equality semantics (a quiet NaN is not even
        equal to itself), so this override checks each decoded value by its
        textual form (quiet or signaling, and its sign) rather than with
        C{==}.  It also checks the set of keys produced and the length of
        each decoded list.  The mixin's whole-dictionary comparison would
        otherwise cover those, and checking them catches extra or missing
        elements.
        """
        # Helpers.  Decimal.is_{qnan,snan,signed}() are new in 2.6 (or 2.5.2,
        # but who's counting).  The parameter is not named "decimal" so that
        # it does not shadow the module of that name.
        def is_qnan(value):
            text = str(value)
            return 'NaN' in text and 'sNaN' not in text

        def is_snan(value):
            return 'sNaN' in str(value)

        def is_signed(value):
            return '-' in str(value)

        stringList = amp.ListOf(self.elementType)
        objects = {}
        for key in self.strings:
            stringList.fromBox(key, self.strings.copy(), objects, None)

        # The parts of the mixin's dictionary equality check that do not
        # depend on comparing NaNs: same keys, same number of elements.
        self.assertEqual(sorted(objects), sorted(self.objects))
        for key in self.objects:
            self.assertEqual(len(objects[key]), len(self.objects[key]))

        n = objects["nan"]
        self.assertTrue(is_qnan(n[0]) and not is_signed(n[0]))
        self.assertTrue(is_qnan(n[1]) and is_signed(n[1]))
        self.assertTrue(is_snan(n[2]) and not is_signed(n[2]))
        self.assertTrue(is_snan(n[3]) and is_signed(n[3]))
2820
2821
2822
class DecimalTests(unittest.TestCase):
    """
    Tests for L{amp.Decimal}.
    """
    def test_nonDecimal(self):
        """
        L{amp.Decimal.toString} accepts only C{decimal.Decimal} instances:
        a C{str}, a C{float} and an C{int} each make it raise L{ValueError}.
        """
        argument = amp.Decimal()
        for notADecimal in ("1.234", 1.234, 1234):
            self.assertRaises(ValueError, argument.toString, notADecimal)
2836
2837
2838
class ListOfDateTimeTests(unittest.TestCase, ListOfTestsMixin):
    """
    Exercise L{ListOf} with L{amp.DateTime} elements, using one calendar date
    expressed at a variety of fixed UTC offsets.
    """
    elementType = amp.DateTime()

    # Every element carries a \x20 (32) length prefix, the length of
    # "YYYY-MM-DDTHH:MM:SS.ffffff+HH:MM".  A zero offset is expected on the
    # wire as "-00:00", whichever way its tzinfo was constructed.
    strings = {
        "christmas": (
            "\x00\x202010-12-25T00:00:00.000000-00:00"
            "\x00\x202010-12-25T00:00:00.000000-00:00"),
        "christmas in eu": "\x00\x202010-12-25T00:00:00.000000+01:00",
        "christmas in iran": "\x00\x202010-12-25T00:00:00.000000+03:30",
        "christmas in nyc": "\x00\x202010-12-25T00:00:00.000000-05:00",
        "previous tests": (
            "\x00\x202010-12-25T00:00:00.000000+03:19"
            "\x00\x202010-12-25T00:00:00.000000-06:59"),
    }

    # Hour, minute and second are left at their datetime default of zero.
    objects = {
        "christmas": [
            datetime.datetime(2010, 12, 25, tzinfo=amp.utc),
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('+', 0, 0)),
        ],
        "christmas in eu": [
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('+', 1, 0)),
        ],
        "christmas in iran": [
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('+', 3, 30)),
        ],
        "christmas in nyc": [
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('-', 5, 0)),
        ],
        "previous tests": [
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('+', 3, 19)),
            datetime.datetime(
                2010, 12, 25, tzinfo=amp._FixedOffsetTZInfo('-', 6, 59)),
        ],
    }
2881
2882
2883
class ListOfOptionalTests(unittest.TestCase):
    """
    Tests to ensure L{ListOf} AMP arguments can be omitted from AMP commands
    via the 'optional' flag.
    """
    def test_requiredArgumentWithNoneValueRaisesTypeError(self):
        """
        L{ListOf.toBox} raises C{TypeError} when passed a value of C{None}
        for a required (non-optional) argument.
        """
        integerList = amp.ListOf(amp.Integer())
        self.assertRaises(
            TypeError, integerList.toBox, 'omitted', amp.AmpBox(),
            {'omitted': None}, None)


    def test_optionalArgumentWithNoneValueOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument with a
        value of C{None} that is designated as optional for the protocol.
        """
        integerList = amp.ListOf(amp.Integer(), optional=True)
        strings = amp.AmpBox()
        integerList.toBox('omitted', strings, {'omitted': None}, None)
        self.assertEqual(strings, {})


    def test_requiredArgumentWithKeyMissingRaisesKeyError(self):
        """
        L{ListOf.toBox} raises C{KeyError} if the argument's key is not
        present in the objects dictionary.
        """
        integerList = amp.ListOf(amp.Integer())
        self.assertRaises(
            KeyError, integerList.toBox, 'ommited', amp.AmpBox(),
            {'someOtherKey': 0}, None)


    def test_optionalArgumentWithKeyMissingOmitted(self):
        """
        L{ListOf.toBox} silently omits serializing any argument designated
        as optional whose key is not present in the objects dictionary: no
        exception is raised and nothing is written to the box.
        """
        integerList = amp.ListOf(amp.Integer(), optional=True)
        strings = amp.AmpBox()
        integerList.toBox('ommited', strings, {'someOtherKey': 0}, None)
        # Check the box too, not merely that no exception escaped.
        self.assertEqual(strings, {})


    def test_omittedOptionalArgumentDeserializesAsNone(self):
        """
        L{ListOf.fromBox} correctly reverses the operation performed by
        L{ListOf.toBox} for optional arguments: an absent key decodes as
        C{None}.
        """
        integerList = amp.ListOf(amp.Integer(), optional=True)
        objects = {}
        integerList.fromBox('omitted', {}, objects, None)
        self.assertEqual(objects, {'omitted': None})
2940
2941
2942
class UNIXStringTransport(object):
    """
    An in-memory implementation of L{interfaces.IUNIXTransport} which collects
    all data given to it for later inspection.

    @ivar _queue: A C{list} of the data which has been given to this transport,
        eg via C{write} or C{sendFileDescriptor}.  Elements are two-tuples of a
        string (identifying the destination of the data) and the data itself.
    """
    implements(interfaces.IUNIXTransport)

    def __init__(self, descriptorFuzz):
        """
        @param descriptorFuzz: An offset to apply to descriptors.
        @type descriptorFuzz: C{int}
        """
        self._fuzz = descriptorFuzz
        self._queue = []


    def sendFileDescriptor(self, descriptor):
        """
        Record a C{fileDescriptorReceived} event for C{descriptor}, shifted by
        the offset given to the initializer.
        """
        self._queue.append(
            ('fileDescriptorReceived', descriptor + self._fuzz))


    def write(self, data):
        """
        Record a C{dataReceived} event carrying C{data}.
        """
        self._queue.append(('dataReceived', data))


    def writeSequence(self, seq):
        """
        Record one C{dataReceived} event per element of C{seq}, in order.
        """
        for data in seq:
            self.write(data)


    def loseConnection(self):
        """
        Record a C{connectionLost} event whose reason wraps
        L{error.ConnectionLost}.
        """
        # ConnectionLost must be reached through the imported
        # twisted.internet.error module; the bare name is not imported here.
        self._queue.append(
            ('connectionLost', Failure(error.ConnectionLost())))


    def getHost(self):
        """
        @return: A fixed L{UNIXAddress} standing in for the local address.
        """
        # Imported locally because this module does not otherwise import
        # twisted.internet.address.
        from twisted.internet.address import UNIXAddress
        return UNIXAddress('/tmp/some-path')


    def getPeer(self):
        """
        @return: A fixed L{UNIXAddress} standing in for the remote address.
        """
        from twisted.internet.address import UNIXAddress
        return UNIXAddress('/tmp/another-path')


# Minimal evidence that we got the signatures right
verifyClass(interfaces.ITransport, UNIXStringTransport)
verifyClass(interfaces.IUNIXTransport, UNIXStringTransport)
2991
2992
class DescriptorTests(unittest.TestCase):
    """
    Tests for L{amp.Descriptor}, the argument type which carries a file
    descriptor across an AMP connection running over a UNIX domain socket.
    """
    def setUp(self):
        """
        Connect a fresh L{amp.BinaryBoxProtocol} to a L{UNIXStringTransport}
        which shifts every sent descriptor by C{self.fuzz}.
        """
        self.fuzz = 3
        self.transport = UNIXStringTransport(descriptorFuzz=self.fuzz)
        dispatcher = amp.BoxDispatcher(amp.CommandLocator())
        self.protocol = amp.BinaryBoxProtocol(dispatcher)
        self.protocol.makeConnection(self.transport)


    def test_fromStringProto(self):
        """
        L{Descriptor.fromStringProto} builds a file descriptor value by taking,
        from the L{_DescriptorExchanger} state of the given protocol, the
        previously received descriptor whose index matches the argument's
        wire value.

        This is a whitebox test which involves direct L{_DescriptorExchanger}
        state inspection.
        """
        argument = amp.Descriptor()
        received = [5, 3, 1]
        for descriptor in received:
            self.protocol.fileDescriptorReceived(descriptor)
        for index, descriptor in enumerate(received):
            self.assertEqual(
                descriptor,
                argument.fromStringProto(str(index), self.protocol))
        self.assertEqual({}, self.protocol._descriptors)


    def test_toStringProto(self):
        """
        L{Descriptor.toStringProto} copies a file descriptor by calling
        L{IUNIXTransport.sendFileDescriptor} on the transport of the given
        protocol.  Successive descriptors sent over one AMP connection are
        numbered 0, 1, 2, ..., and the decimal string form of that number is
        the argument's byte encoding.

        This is a whitebox test which involves direct L{_DescriptorExchanger}
        state inspection and mutation.
        """
        argument = amp.Descriptor()
        for wireValue, descriptor in [("0", 2), ("1", 4), ("2", 6)]:
            self.assertEqual(
                wireValue, argument.toStringProto(descriptor, self.protocol))
            self.assertEqual(
                ("fileDescriptorReceived", descriptor + self.fuzz),
                self.transport._queue.pop(0))
        self.assertEqual({}, self.protocol._descriptors)


    def test_roundTrip(self):
        """
        An L{amp.AmpBox} produced by L{amp.Descriptor.toBox} can be decoded by
        L{amp.Descriptor.fromBox} back into a file descriptor value.
        """
        name = "alpha"
        descriptor = 17
        argument = amp.Descriptor()

        strings = {}
        argument.toBox(name, strings, {name: descriptor}, self.protocol)

        # Replay every event the sending side pushed into the transport onto
        # a separate, never-connected receiving protocol.
        receiver = amp.BinaryBoxProtocol(
            amp.BoxDispatcher(amp.CommandLocator()))
        for event in self.transport._queue:
            handler = getattr(receiver, event[0])
            handler(*event[1:])

        receiveObjects = {}
        argument.fromBox(name, dict(strings), receiveObjects, receiver)

        # Make sure we got the descriptor.  Adjust by fuzz to be more convincing
        # of having gone through L{IUNIXTransport.sendFileDescriptor}, not just
        # converted to a string and then parsed back into an integer.
        self.assertEqual(descriptor + self.fuzz, receiveObjects[name])
3079
3080
3081
class DateTimeTests(unittest.TestCase):
    """
    Tests for L{amp.DateTime}, L{amp._FixedOffsetTZInfo}, and L{amp.utc}.
    """
    # A matching pair: the wire form and the timezone-aware value it encodes.
    string = '9876-01-23T12:34:56.054321-01:23'
    tzinfo = amp._FixedOffsetTZInfo('-', 1, 23)
    object = datetime.datetime(9876, 1, 23, 12, 34, 56, 54321, tzinfo=tzinfo)

    def test_invalidString(self):
        """
        Given a string that is not a correctly formatted timestamp,
        L{amp.DateTime.fromString} raises L{ValueError}.
        """
        converter = amp.DateTime()
        self.assertRaises(ValueError, converter.fromString, 'abc')


    def test_invalidDatetime(self):
        """
        L{amp.DateTime.toString} refuses a naive datetime (one carrying no
        timezone information) by raising L{ValueError}.
        """
        converter = amp.DateTime()
        naive = datetime.datetime(2010, 12, 25, 0, 0, 0)
        self.assertRaises(ValueError, converter.toString, naive)


    def test_fromString(self):
        """
        Every field of the C{datetime.datetime} returned by
        L{amp.DateTime.fromString} is filled in from the string it was given.
        """
        parsed = amp.DateTime().fromString(self.string)
        self.assertEqual(parsed, self.object)


    def test_toString(self):
        """
        L{amp.DateTime.toString} renders a C{datetime.datetime} as a C{str}
        in the wire format, keeping all of its information, including its
        timezone offset.
        """
        rendered = amp.DateTime().toString(self.object)
        self.assertEqual(rendered, self.string)
3128
3129
3130
class FixedOffsetTZInfoTests(unittest.TestCase):
    """
    Tests for L{amp._FixedOffsetTZInfo} and L{amp.utc}.
    """

    def test_tzname(self):
        """
        The name reported by L{amp.utc.tzname} is C{"+00:00"}.
        """
        name = amp.utc.tzname(None)
        self.assertEqual(name, '+00:00')


    def test_dst(self):
        """
        L{amp.utc.dst} reports no daylight saving adjustment.
        """
        noAdjustment = datetime.timedelta(0)
        self.assertEqual(amp.utc.dst(None), noAdjustment)


    def test_utcoffset(self):
        """
        L{amp.utc.utcoffset} reports an offset of zero.
        """
        noOffset = datetime.timedelta(0)
        self.assertEqual(amp.utc.utcoffset(None), noOffset)


    def test_badSign(self):
        """
        Constructing L{amp._FixedOffsetTZInfo} with a sign that is neither
        C{'+'} nor C{'-'} raises L{ValueError}.
        """
        badSign = '?'
        self.assertRaises(ValueError, amp._FixedOffsetTZInfo, badSign, 0, 0)
3163
3164
3165
# The test cases listed below require SSL support in the reactor; when the
# installed reactor does not provide IReactorSSL, mark each of them skipped
# with the same message.
if not interfaces.IReactorSSL.providedBy(reactor):
    skipMsg = 'This test case requires SSL support in the reactor'
    TLSTest.skip = skipMsg
    LiveFireTLSTestCase.skip = skipMsg
    PlainVanillaLiveFire.skip = skipMsg
    WithServerTLSVerification.skip = skipMsg
3172