1# -*- test-case-name: vertex.test.test_ptcp -*-
2
3import random, os
4
5from twisted.internet import reactor, protocol, defer, error
6from twisted.trial import unittest
7
8from vertex import ptcp
9
def reallyLossy(method):
    """
    Wrap C{method} so that, on average, one call in three is silently
    dropped instead of being forwarded.

    The keep/drop decisions come from a private L{random.Random} seeded with
    a fixed value, so every wrapper produces the same sequence of decisions.

    @param method: the callable to wrap.

    @return: a callable accepting any arguments; it forwards them to
        C{method} when the call is kept and always returns C{None}.
    """
    outcomes = [True, True, False]
    rng = random.Random(42)

    def maybeDeliver(*args, **kwargs):
        keep = rng.choice(outcomes)
        if not keep:
            return
        method(*args, **kwargs)

    return maybeDeliver
17
def insufficientTransmitter(method,  mtu):
    """
    Wrap C{method} so that only the leading C{mtu} elements of each payload
    are passed on.

    @param method: a callable invoked as C{method(payload, addr)}.
    @param mtu: the slice bound applied to every payload (C{payload[:mtu]}).

    @return: a callable taking C{(bytes, addr)} which forwards the sliced
        payload together with the unchanged address and returns C{None}.
    """
    def truncatingMethod(bytes, addr):
        head = bytes[:mtu]
        method(head, addr)

    return truncatingMethod
22
23
class TestProtocol(protocol.Protocol):
    """
    A protocol which records everything it receives and lets a test wait for
    connection events and for a specific run of received data.

    @ivar onConnect: a Deferred fired with C{None} by L{connectionMade}.
    @ivar onDisconn: a Deferred fired with C{None} by L{connectionLost}.
    @ivar buffer: a list of the chunks handed to L{dataReceived}, in order.
    @ivar _waiting: C{None}, or the C{(Deferred, expected)} pair created by
        the outstanding L{gotBytes} call.
    """
    buffer = None

    def __init__(self):
        self.onConnect = defer.Deferred()
        self.onDisconn = defer.Deferred()
        self._waiting = None
        self.buffer = []

    def connectionMade(self):
        """
        Announce the connection by firing C{onConnect}.
        """
        self.onConnect.callback(None)

    def connectionLost(self, reason):
        """
        Announce the disconnection by firing C{onDisconn}; C{reason} is
        ignored.
        """
        self.onDisconn.callback(None)

    def gotBytes(self, bytes):
        """
        Wait until the accumulated buffer is exactly C{bytes}.

        Only one wait may be outstanding at a time; this is enforced with an
        assertion.

        @param bytes: the complete data the buffer is expected to hold.

        @return: a Deferred firing with C{None}; it has already fired if the
            buffer matches right now.
        """
        assert self._waiting is None
        if ''.join(self.buffer) == bytes:
            return defer.succeed(None)
        self._waiting = (defer.Deferred(), bytes)
        return self._waiting[0]

    def dataReceived(self, bytes):
        """
        Record C{bytes} and, if a L{gotBytes} wait is outstanding, fire it
        once everything received so far equals the expected data.

        When the received data stops being a prefix of the expected data, a
        diagnostic showing where the two diverge is printed.
        """
        self.buffer.append(bytes)
        if self._waiting is None:
            return
        waiter, expected = self._waiting
        received = ''.join(self.buffer)
        if not expected.startswith(received):
            x = len(os.path.commonprefix([received, expected]))
            # Single-argument print(...) produces the same output whether it
            # is parsed as a print statement or as a function call.
            print(x)
            print(' '.join(['it goes wrong starting with',
                            repr(received[x:x+100]),
                            repr(expected[x:x+100])]))
        if received == expected:
            # Clear the wait *before* firing: callbacks run synchronously
            # inside callback(), and one of them may call gotBytes() again
            # on this protocol, which asserts that no wait is outstanding.
            self._waiting = None
            waiter.callback(None)
56
class Django(protocol.ClientFactory):
    """
    A client factory which reports the outcome of a connection attempt
    through a Deferred.

    @ivar onConnect: called back with the protocol returned from
        L{buildProtocol}, or errbacked with the C{err} passed to
        L{clientConnectionFailed}.
    """
    def __init__(self):
        self.onConnect = defer.Deferred()

    def buildProtocol(self, addr):
        """
        Build the protocol exactly as L{protocol.ClientFactory} would, hand
        it to C{onConnect}, and return it.
        """
        built = protocol.ClientFactory.buildProtocol(self, addr)
        self.onConnect.callback(built)
        return built

    def clientConnectionFailed(self, conn, err):
        """
        Forward the failure C{err} to C{onConnect}; C{conn} is not used.
        """
        self.onConnect.errback(err)
68
class ConnectedPTCPMixin:
    """
    Test mixin which builds a server-side and a client-side PTCP transport,
    each listening on its own UDP port, and waits for their connections to
    close when the test is torn down.
    """
    serverPort = None

    def setUpForATest(self,
                      ServerProtocol=TestProtocol, ClientProtocol=TestProtocol):
        """
        Create the protocols, factories, PTCP transports and UDP ports used
        by a single test, remembering the protocols, transports and ports as
        attributes of C{self}.

        @param ServerProtocol: zero-argument callable producing the server
            protocol instance.
        @param ClientProtocol: zero-argument callable producing the client
            protocol instance.

        @return: the 8-tuple C{(serverProto, clientProto, serverFactory,
            clientFactory, serverTransport, clientTransport, serverPort,
            clientPort)}.
        """
        serverProto = ServerProtocol()
        clientProto = ClientProtocol()
        self.serverProto, self.clientProto = serverProto, clientProto

        # Both factories always hand out the one pre-built protocol
        # instance, so the test can hold a reference to it up front.
        serverFactory = protocol.ServerFactory()
        serverFactory.protocol = lambda: serverProto

        clientFactory = Django()
        clientFactory.protocol = lambda: clientProto

        serverTransport, clientTransport = (
            ptcp.PTCP(serverFactory), ptcp.PTCP(None))
        self.serverTransport, self.clientTransport = (
            serverTransport, clientTransport)

        # Port number 0 is passed for both listeners.
        serverPort = reactor.listenUDP(0, serverTransport)
        clientPort = reactor.listenUDP(0, clientTransport)
        self.clientPort, self.serverPort = clientPort, serverPort

        return (
            serverProto, clientProto,
            serverFactory, clientFactory,
            serverTransport, clientTransport,
            serverPort, clientPort,
            )

    def tearDown(self):
        """
        Return a Deferred which fires once both PTCP transports report that
        all of their connections have closed.
        """
        transports = (self.serverTransport, self.clientTransport)
        closing = [transport.waitForAllConnectionsToClose()
                   for transport in transports]
        return defer.DeferredList(closing)
113
114
class TestProducerProtocol(protocol.Protocol):
    """
    A protocol which registers itself as a producer on its transport and
    writes C{NUM_WRITES} chunks of C{WRITE_SIZE} characters, one chunk per
    L{resumeProducing} call.

    @ivar onConnect: a Deferred fired with C{None} by L{connectionMade}.
    @ivar onPaused: a Deferred fired with C{None} the first time
        L{pauseProducing} is called, then replaced by C{None}.
    @ivar count: index of the most recently written chunk; C{-1} before the
        first write.
    """
    NUM_WRITES = 32
    WRITE_SIZE = 32

    def __init__(self):
        self.onConnect = defer.Deferred()
        self.onPaused = defer.Deferred()

    def connectionMade(self):
        """
        Fire C{onConnect}, reset the chunk counter and register this
        protocol as the transport's producer (second argument C{False}).
        """
        self.onConnect.callback(None)
        self.count = -1
        self.transport.registerProducer(self, False)

    def pauseProducing(self):
        """
        Fire C{onPaused} on the first call; later calls do nothing.
        """
        pending = self.onPaused
        if pending is None:
            return
        pending.callback(None)
        self.onPaused = None

    def resumeProducing(self):
        """
        Write the next chunk, or unregister as a producer once every chunk
        has been written.

        Chunk number C{n} consists of C{chr(n)} repeated C{WRITE_SIZE}
        times.
        """
        self.count += 1
        if self.count >= self.NUM_WRITES:
            self.transport.unregisterProducer()
            return
        chunk = chr(self.count) * self.WRITE_SIZE
        self.transport.write(chunk)
        if self.count == self.NUM_WRITES - 1:
            # Last time through, intentionally drop the connection before
            # the buffer is empty to ensure we handle this case properly.
            self.transport.loseConnection()
147
class PTCPTransportTestCase(ConnectedPTCPMixin, unittest.TestCase):
    """
    Tests for data transfer and producer/consumer behaviour between two
    PTCP transports connected over UDP on 127.0.0.1.
    """
    def setUp(self):
        """
        Patch the PTCP retransmission parameters for the duration of a test:
        the connection's retransmit timeout is divided by 10 and the packet
        retransmit count is multiplied by 10.

        Original author's note: I have no idea why one of these values is
        divided by 10 and the other is multiplied by 10.  -exarkun
        """
        self.patch(
            ptcp.PTCPConnection, '_retransmitTimeout',
            ptcp.PTCPConnection._retransmitTimeout / 10)
        self.patch(
            ptcp.PTCPPacket, 'retransmitCount',
            ptcp.PTCPPacket.retransmitCount * 10)


    def xtestWhoAmI(self):
        """
        Check that each side's C{whoami()} reports the port its own UDP
        listener is bound to.

        The C{x} prefix keeps the test runner from collecting this method.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        def gotAddress(results):
            ((serverSuccess, serverAddress),
             (clientSuccess, clientAddress)) = results
            self.assertTrue(serverSuccess)
            self.assertTrue(clientSuccess)

            self.assertEqual(serverAddress[1], serverPort.getHost().port)
            self.assertEqual(clientAddress[1], clientPort.getHost().port)

        def connectionsMade(ignored):
            whoamis = defer.DeferredList([
                    serverProto.transport.whoami(),
                    clientProto.transport.whoami()])
            return whoamis.addCallback(gotAddress)

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        connected = defer.DeferredList(
            [serverProto.onConnect, clientProto.onConnect])
        return connected.addCallback(connectionsMade)

    #testWhoAmI.skip = 'arglebargle'

    def testVerySimpleConnection(self):
        """
        Bounce a large payload back and forth ten times, alternating the
        sending side, then close both ends and wait for both protocols to
        see the disconnection.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()


        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        def sendSomeBytes(ignored, n=10, server=False):
            # Each round: one side writes, the other side's buffer is reset
            # and waited on, then the roles swap for the next round.
            if n:
                payload = 'not a lot of bytes' * 1000
                if server:
                    sender, receiver = serverProto, clientProto
                else:
                    sender, receiver = clientProto, serverProto
                sender.transport.write(payload)
                receiver.buffer = []
                d = receiver.gotBytes(payload)
                return d.addCallback(sendSomeBytes, n - 1, not server)

        def loseConnections(ignored):
            serverProto.transport.loseConnection()
            clientProto.transport.loseConnection()
            return defer.DeferredList([
                    serverProto.onDisconn,
                    clientProto.onDisconn
                    ])

        dl = defer.DeferredList([serverProto.onConnect, clientProto.onConnect])
        dl.addCallback(sendSomeBytes)
        dl.addCallback(loseConnections)
        return dl


    def testProducerConsumer(self):
        """
        With a producing server protocol, the client must have received
        every chunk, in order, by the time it is disconnected.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest(
            ServerProtocol=TestProducerProtocol)

        def disconnected(ignored):
            self.assertEqual(
                ''.join(clientProto.buffer),
                ''.join([chr(n) * serverProto.WRITE_SIZE
                         for n in range(serverProto.NUM_WRITES)]))

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        return clientProto.onDisconn.addCallback(disconnected)


    def testTransportProducer(self):
        """
        Pause the client transport, have the server write, resume the
        client two seconds later, and check that the data is only seen
        after the resume.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        resumed = []
        def resumeProducing():
            resumed.append(True)
            clientProto.transport.resumeProducing()

        def cbBytes(ignored):
            self.assertTrue(resumed)
            clientProto.transport.loseConnection()

        def cbConnect(ignored):
            BYTES = 'Here are bytes'
            clientProto.transport.pauseProducing()
            serverProto.transport.write(BYTES)
            reactor.callLater(2, resumeProducing)
            return clientProto.gotBytes(BYTES).addCallback(cbBytes)


        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        connD = defer.DeferredList([clientProto.onConnect, serverProto.onConnect])
        connD.addCallback(cbConnect)
        return connD

    def testTransportProducerProtocolProducer(self):
        """
        Pause the client transport while a producing server protocol
        writes; when the server's producer is paused, resume the client and
        check that it still receives every chunk, in order, before it is
        disconnected.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest(
            ServerProtocol=TestProducerProtocol)

        paused = []
        def cbPaused(ignored):
            paused.append(True)
            clientProto.transport.resumeProducing()
        serverProto.onPaused.addCallback(cbPaused)

        def cbBytes(ignored):
            self.assertEqual(
                ''.join(clientProto.buffer),
                ''.join([chr(n) * serverProto.WRITE_SIZE
                         for n in range(serverProto.NUM_WRITES)]))

        def cbConnect(ignored):
            # The server must write enough to completely fill the outgoing buffer,
            # since our peer isn't ACKing /anything/ and our server waits for
            # writes to be acked before proceeding.
            serverProto.WRITE_SIZE = serverProto.transport.sendWindow * 5

            clientProto.transport.pauseProducing()
            return clientProto.onDisconn.addCallback(cbBytes)

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        connD = defer.DeferredList([clientProto.onConnect, serverProto.onConnect])
        connD.addCallback(cbConnect)
        return connD
305
306
class LossyTransportTestCase(PTCPTransportTestCase):
    """
    Run every L{PTCPTransportTestCase} test with both UDP ports' C{write}
    methods wrapped by L{reallyLossy}.
    """
    def setUpForATest(self, *a, **kw):
        """
        Build the usual fixture, then wrap the C{write} of its last two
        entries (the server and client ports) with L{reallyLossy}.

        @return: the fixture tuple produced by the base class.
        """
        fixture = PTCPTransportTestCase.setUpForATest(self, *a, **kw)
        for port in fixture[-2:]:
            port.write = reallyLossy(port.write)
        return fixture
313
314
class SmallMTUTransportTestCase(PTCPTransportTestCase):
    """
    Run every L{PTCPTransportTestCase} test with both UDP ports' C{write}
    methods wrapped by L{insufficientTransmitter} using a limit of 128.
    """
    def setUpForATest(self, *a, **kw):
        """
        Build the usual fixture, then wrap the C{write} of its last two
        entries (the server and client ports) so that each payload is
        sliced to its first 128 elements.

        @return: the fixture tuple produced by the base class.
        """
        fixture = PTCPTransportTestCase.setUpForATest(self, *a, **kw)
        for port in fixture[-2:]:
            port.write = insufficientTransmitter(port.write, 128)
        return fixture
321
322
323
class TimeoutTestCase(ConnectedPTCPMixin, unittest.TestCase):
    """
    Tests for what happens when one side's outgoing packets are discarded.
    """
    def setUp(self):
        """
        Shorten the retransmit timeout so that tests finish more quickly.
        """
        self.patch(
            ptcp.PTCPConnection, '_retransmitTimeout',
            ptcp.PTCPConnection._retransmitTimeout / 10)


    def testConnectTimeout(self):
        """
        When none of the client's packets are sent, the connection attempt
        must fail with L{error.TimeoutError}.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        # Replace the client's packet sender with a no-op.
        clientTransport.sendPacket = lambda *a, **kw: None
        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        # assertFailure also reports an unexpected *success* as a test
        # failure, rather than blowing up on a result with no trap().
        return self.assertFailure(cf.onConnect, error.TimeoutError)

    def testDataTimeout(self):
        """
        Once connected, replace the server side's packet sender with a
        no-op, write from both sides, and wait for the client protocol to
        be disconnected.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        def cbConnected(ignored):
            serverProto.transport.ptcp.sendPacket = lambda *a, **kw: None
            clientProto.transport.write('Receive this data.')
            # have to send data or the server will never time out: need a
            # SO_KEEPALIVE option somewhere
            serverProto.transport.write('Send this data.')
            return clientProto.onDisconn

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        d = defer.DeferredList([serverProto.onConnect, clientProto.onConnect])
        d.addCallback(cbConnected)
        return d
366