1# Copyright 2005 Divmod, Inc.  See LICENSE file for details
2
3import six
4
5from epsilon import juice
6from epsilon.test import iosim
7from twisted.trial import unittest
8from twisted.internet import protocol, defer
9
class TestProto(protocol.Protocol):
    """Trivial protocol used as the target of a protocol switch.

    On connection it writes C{dataToSend}.  It records each chunk it
    receives and drops the connection after the first one.  When the
    connection is lost it fires C{onConnLost} with the list of received
    chunks.
    """

    def __init__(self, onConnLost, dataToSend):
        """
        @param onConnLost: Deferred fired (via C{callback}) with the list of
            received chunks when the connection is lost.
        @param dataToSend: bytes written to the transport on connect.
        """
        self.onConnLost = onConnLost
        self.dataToSend = dataToSend
        # Also initialised here (not only in connectionMade) so that
        # connectionLost can never hit an AttributeError.
        self.data = []

    def connectionMade(self):
        self.data = []
        self.transport.write(self.dataToSend)

    def dataReceived(self, data):
        # Parameter is named ``data`` to match IProtocol.dataReceived and to
        # avoid shadowing the ``bytes`` builtin.
        self.data.append(data)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.onConnLost.callback(self.data)
25
class SimpleSymmetricProtocol(juice.Juice):
    """Juice peer that speaks the raw-box 'hello' and 'goodbye' commands.

    The same class runs on both ends, so either side can send or answer.
    """

    def sendHello(self, text):
        """Send a 'hello' command carrying C{text}; return sendCommand's result."""
        result = self.sendCommand("hello", hello=text)
        return result

    def sendGoodbye(self):
        """Send a 'goodbye' command; return sendCommand's result."""
        result = self.sendCommand("goodbye")
        return result

    def juice_HELLO(self, box):
        # Echo the incoming 'hello' value back in the reply box.
        greeting = box['hello']
        return juice.Box(hello=greeting)

    def juice_GOODBYE(self, box):
        # Answer with a juice.QuitBox carrying a fixed farewell value.
        return juice.QuitBox(goodbye='world')
40
class UnfriendlyGreeting(Exception):
    """Raised by the HELLO responder when a greeting is not polite enough."""
44
class UnknownProtocol(Exception):
    """Raised when a protocol-switch request names an unsupported protocol."""
48
class Hello(juice.Command):
    """Echo command: sends a 'hello' string and expects a 'hello' string back.

    UnfriendlyGreeting raised by the responder is mapped to the
    'UNFRIENDLY' error key.
    """
    commandName = 'hello'
    arguments = [
        ('hello', juice.String()),
    ]
    response = [
        ('hello', juice.String()),
    ]
    errors = {
        UnfriendlyGreeting: 'UNFRIENDLY',
    }
55
class Goodbye(juice.Command):
    """Argument-less command whose answer is a juice.QuitBox."""
    responseType = juice.QuitBox
    commandName = 'goodbye'
59
class GetList(juice.Command):
    """Ask the peer for C{length} records, each carrying an integer 'x'."""
    commandName = 'getlist'
    arguments = [
        ('length', juice.Integer()),
    ]
    response = [
        ('body', juice.JuiceList([
            ('x', juice.Integer()),
        ])),
    ]
64
class TestSwitchProto(juice.ProtocolSwitchCommand):
    """Protocol-switch command carrying the name of the target protocol.

    UnknownProtocol raised by the responder is mapped to the 'UNKNOWN'
    error key.
    """
    commandName = 'Switch-Proto'
    arguments = [('name', juice.String())]
    errors = {UnknownProtocol: 'UNKNOWN'}
72
class SingleUseFactory(protocol.ClientFactory):
    """Client factory that hands out one pre-built protocol instance.

    The first buildProtocol call returns the instance given to __init__;
    every later call returns None.
    """

    def __init__(self, proto):
        self.proto = proto

    def buildProtocol(self, addr):
        built = self.proto
        self.proto = None
        return built
80
class SimpleSymmetricCommandProtocol(juice.Juice):
    """Juice peer implementing Hello, Goodbye, GetList and TestSwitchProto
    through C{command_*} responders tagged with their Command class.

    @ivar maybeLater: None here; a subclass that answers a protocol switch
        asynchronously stores a Deferred in it.
    """
    maybeLater = None

    def __init__(self, issueGreeting, onConnLost=None):
        """
        @param issueGreeting: passed straight through to juice.Juice.
        @param onConnLost: Deferred handed to the TestProto created during a
            protocol switch; that TestProto fires it on connection loss.
        """
        juice.Juice.__init__(self, issueGreeting)
        self.onConnLost = onConnLost

    def sendHello(self, text):
        return Hello(hello=text).do(self)

    def sendGoodbye(self):
        return Goodbye().do(self)

    def command_HELLO(self, hello):
        # Rude greetings are rejected; Hello.errors maps this exception to
        # the 'UNFRIENDLY' error key.
        if hello.startswith('fuck'):
            raise UnfriendlyGreeting("Don't be a dick.")
        return dict(hello=hello)

    def command_GETLIST(self, length):
        # Build a distinct dict for every element: ``[dict(x=1)] * length``
        # would make all elements the very same mutable object.
        return {'body': [dict(x=1) for _ in range(length)]}

    def command_GOODBYE(self):
        return dict(goodbye='world')

    command_HELLO.command = Hello
    command_GOODBYE.command = Goodbye
    command_GETLIST.command = GetList

    def switchToTestProtocol(self):
        """Ask the peer to switch both ends to 'test-proto'.

        @return: the Deferred from TestSwitchProto.do, with a callback added
            so it fires with the local TestProto instance.
        """
        p = TestProto(self.onConnLost, SWITCH_CLIENT_DATA)
        d = TestSwitchProto(SingleUseFactory(p), name='test-proto').do(self)
        return d.addCallback(lambda ign: p)

    def command_SWITCH_PROTO(self, name):
        # Return the protocol to switch to; any other name is refused.
        if name == 'test-proto':
            return TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        raise UnknownProtocol(name)

    command_SWITCH_PROTO.command = TestSwitchProto
113
class DeferredSymmetricCommandProtocol(SimpleSymmetricCommandProtocol):
    """Variant whose SWITCH_PROTO responder answers with an unfired Deferred.

    The protocol to switch to is parked on C{maybeLaterProto}, and the
    returned Deferred is kept on C{maybeLater} so the caller can fire it
    later.
    """

    def command_SWITCH_PROTO(self, name):
        if name != 'test-proto':
            raise UnknownProtocol(name)
        self.maybeLaterProto = TestProto(self.onConnLost, SWITCH_SERVER_DATA)
        self.maybeLater = defer.Deferred()
        return self.maybeLater

    command_SWITCH_PROTO.command = TestSwitchProto
123
124
class SSPF:
    """Mixin setting SimpleSymmetricProtocol as the factory's protocol."""
    protocol = SimpleSymmetricProtocol


class SSSF(SSPF, protocol.ServerFactory):
    """Server factory whose protocol is SimpleSymmetricProtocol."""


class SSCF(SSPF, protocol.ClientFactory):
    """Client factory whose protocol is SimpleSymmetricProtocol."""
128
def connectedServerAndClient(ServerClass=lambda: SimpleSymmetricProtocol(True),
                             ClientClass=lambda: SimpleSymmetricProtocol(False),
                             *a, **kw):
    """Connect a server and a client protocol via epsilon.test.iosim.

    @param ServerClass: zero-argument callable producing the server protocol.
    @param ClientClass: zero-argument callable producing the client protocol.

    Any further positional or keyword arguments are forwarded unchanged to
    iosim.connectedServerAndClient.

    @return: a 3-tuple (client, server, pump).
    """
    factories = (ServerClass, ClientClass)
    return iosim.connectedServerAndClient(*(factories + a), **kw)
137
class TotallyDumbProtocol(protocol.Protocol):
    """Protocol that simply accumulates every received byte in C{buf}."""

    buf = b''

    def dataReceived(self, data):
        self.buf = self.buf + data
142
class LiteralJuice(juice.Juice):
    """Juice peer whose juiceBoxReceived only records boxes in C{boxes}."""

    def __init__(self, issueGreeting):
        juice.Juice.__init__(self, issueGreeting)
        self.boxes = []

    def juiceBoxReceived(self, box):
        # Keep the box for later inspection; nothing else happens.
        self.boxes.append(box)
151
class LiteralParsingTest(unittest.TestCase):
    """Wire-level parsing and serialization tests for juice boxes."""

    def testBasicRequestResponse(self):
        """A hand-written HELLO request gets an answer with the expected headers."""
        c, s, p = connectedServerAndClient(ClientClass=TotallyDumbProtocol)
        HELLO = b'abcdefg'
        ASKTOK = b'hand-crafted-ask'
        # The command name is deliberately mixed-case; the 'World' header is
        # an extra one the request says is ignored.
        c.transport.write((b"""-Command: HeLlO
-Ask: %s
Hello: %s
World: this header is ignored

""" % (ASKTOK, HELLO,)).replace(b'\n',b'\r\n'))
        p.flush()
        asserts = {'hello': HELLO,
                   '-answer': ASKTOK}
        # The trailing blank line leaves two empty strings after the split.
        hdrs = [j.split(b': ') for j in c.buf.split(b'\r\n')[:-2]]
        self.assertEqual(len(asserts), len(hdrs))
        for hdr in hdrs:
            k, v = hdr
            self.assertEqual(v, asserts[six.ensure_str(k).lower()])

    def testParsingRoundTrip(self):
        """Boxes with awkward values survive serialization unchanged."""
        c, s, p = connectedServerAndClient(ClientClass=lambda: LiteralJuice(False),
                                           ServerClass=lambda: LiteralJuice(True))

        SIMPLE = ('simple', 'test')
        CE = ('ceq', ': ')
        CR = ('crtest', 'test\r')
        LF = ('lftest', 'hello\n')
        NEWLINE = ('newline', 'test\r\none\r\ntwo')
        NEWLINE2 = ('newline2', 'test\r\none\r\n two')
        # NOTE(review): BLANKLINE is defined but never exercised below.
        BLANKLINE = ('newline3', 'test\r\n\r\nblank\r\n\r\nline')
        BODYTEST = (juice.BODY, 'blah\r\n\r\ntesttest')

        def onetest(test):
            # Send a box built from the given key/value pairs and check the
            # server recorded an equal box.
            jb = juice.Box()
            jb.update(dict(test))
            jb.sendTo(c)
            p.flush()
            self.assertEqual(s.boxes[-1], jb)

        onetest([SIMPLE])
        onetest([SIMPLE, BODYTEST])
        onetest([SIMPLE, CE])
        onetest([SIMPLE, CR])
        onetest([SIMPLE, CE, CR, LF])
        onetest([CE, CR, LF])
        onetest([SIMPLE, NEWLINE, CE, NEWLINE2])
        onetest([BODYTEST, SIMPLE, NEWLINE])
200
201
# Payloads written by TestProto after a protocol switch: the client side
# sends SWITCH_CLIENT_DATA and the server side sends SWITCH_SERVER_DATA.
SWITCH_CLIENT_DATA = b'Success!'
SWITCH_SERVER_DATA = b'No, really.  Success.'
204
class AppLevelTest(unittest.TestCase):
    """Round-trip tests of juice commands between two connected peers."""

    def testHelloWorld(self):
        """A raw-box 'hello' is echoed back by SimpleSymmetricProtocol."""
        c, s, p = connectedServerAndClient()
        L = []
        HELLO = 'world'
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]['hello'], HELLO)

    def testHelloWorldCommand(self):
        """The Hello Command is echoed back by the command_* responder."""
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
            ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        L = []
        HELLO = 'world'
        c.sendHello(HELLO).addCallback(L.append)
        p.flush()
        self.assertEqual(L[0]['hello'], HELLO)

    def testHelloErrorHandling(self):
        """An UnfriendlyGreeting raised by the responder reaches the errback."""
        L = []
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
            ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        HELLO = 'fuck you'
        c.sendHello(HELLO).addErrback(L.append)
        p.flush()
        # trap() re-raises the failure if it is not an UnfriendlyGreeting.
        L[0].trap(UnfriendlyGreeting)
        self.assertEqual(str(L[0].value), "Don't be a dick.")

    def testJuiceListCommand(self):
        """A JuiceList response decodes to a list of dicts."""
        c, s, p = connectedServerAndClient(
            ServerClass=lambda: SimpleSymmetricCommandProtocol(True),
            ClientClass=lambda: SimpleSymmetricCommandProtocol(False))
        L = []
        GetList(length=10).do(c).addCallback(L.append)
        p.flush()
        values = L.pop().get('body')
        self.assertEqual(values, [{'x': 1}] * 10)

    def testFailEarlyOnArgSending(self):
        """Constructing a Command without a required argument raises."""
        # Supplying the argument must not raise.
        # NOTE(review): the keyword is spelled 'Hello' while Hello declares
        # 'hello'; this presumably relies on juice matching argument names
        # case-insensitively -- confirm.
        Hello(Hello="What?")
        self.assertRaises(RuntimeError, Hello)

    def testSupportsVersion1(self):
        """Renegotiating to version 1 updates both peers."""
        c, s, p = connectedServerAndClient(ServerClass=lambda: juice.Juice(True),
                                           ClientClass=lambda: juice.Juice(False))
        negotiatedVersion = []
        s.renegotiateVersion(1).addCallback(negotiatedVersion.append)
        p.flush()
        self.assertEqual(negotiatedVersion[0], 1)
        self.assertEqual(c.protocolVersion, 1)
        self.assertEqual(s.protocolVersion, 1)

    def testProtocolSwitch(self, switcher=SimpleSymmetricCommandProtocol):
        """After a protocol switch each side's TestProto gets the other's data.

        @param switcher: the Juice subclass instantiated for both peers.
        """
        self.testSucceeded = False

        serverDeferred = defer.Deferred()
        serverProto = switcher(True, serverDeferred)
        clientDeferred = defer.Deferred()
        clientProto = switcher(False, clientDeferred)
        c, s, p = connectedServerAndClient(ServerClass=lambda: serverProto,
                                           ClientClass=lambda: clientProto)

        switchDeferred = c.switchToTestProtocol()

        def cbConnsLost(results):
            (serverSuccess, serverData), (clientSuccess, clientData) = results
            self.assertTrue(serverSuccess)
            self.assertTrue(clientSuccess)
            self.assertEqual(b''.join(serverData), SWITCH_CLIENT_DATA)
            self.assertEqual(b''.join(clientData), SWITCH_SERVER_DATA)
            self.testSucceeded = True

        def cbSwitch(proto):
            return defer.DeferredList([serverDeferred, clientDeferred]).addCallback(cbConnsLost)

        switchDeferred.addCallback(cbSwitch)
        p.flush()
        # A switcher that answers asynchronously leaves an unfired Deferred
        # on maybeLater; fire it with the parked protocol and pump again.
        if serverProto.maybeLater is not None:
            serverProto.maybeLater.callback(serverProto.maybeLaterProto)
            p.flush()
        self.assertTrue(self.testSucceeded)

    def testProtocolSwitchDeferred(self):
        """Protocol switching also works when the responder returns a Deferred."""
        return self.testProtocolSwitch(switcher=DeferredSymmetricCommandProtocol)
289