# -*- test-case-name: vertex.test.test_ptcp -*-
"""
Tests for L{vertex.ptcp}, run over real loopback UDP ports, optionally with
simulated datagram loss or truncation.
"""

import os
import random

from twisted.internet import reactor, protocol, defer, error
from twisted.trial import unittest

from vertex import ptcp


def reallyLossy(method):
    """
    Wrap C{method} so that roughly one call in three is silently dropped.

    The drop decision comes from a private L{random.Random} seeded with 42,
    so each wrapped method drops the same calls on every run.  The wrapper
    always returns C{None}.
    """
    r = random.Random()
    r.seed(42)
    def worseMethod(*a, **kw):
        if r.choice([True, True, False]):
            method(*a, **kw)
    return worseMethod


def insufficientTransmitter(method, mtu):
    """
    Wrap a C{(bytes, addr)} datagram-sending C{method} so that every datagram
    is truncated to its first C{mtu} bytes before being passed on.
    """
    def worseMethod(bytes, addr):
        method(bytes[:mtu], addr)
    return worseMethod


class TestProtocol(protocol.Protocol):
    """
    Protocol that records every chunk of received data and lets a test wait
    for the accumulated data to equal an expected string.

    @ivar onConnect: fired with C{None} from L{connectionMade}.
    @ivar onDisconn: fired with C{None} from L{connectionLost}.
    @ivar buffer: list of data chunks received so far; tests reset it to
        C{[]} before waiting for a new payload.
    """
    buffer = None

    def __init__(self):
        self.onConnect = defer.Deferred()
        self.onDisconn = defer.Deferred()
        # Either None or a (Deferred, expectedBytes) pair set by gotBytes.
        self._waiting = None
        self.buffer = []

    def connectionMade(self):
        self.onConnect.callback(None)

    def connectionLost(self, reason):
        self.onDisconn.callback(None)

    def gotBytes(self, bytes):
        """
        Wait for the concatenation of L{buffer} to equal C{bytes}.

        Only one wait may be outstanding at a time.

        @return: a Deferred that fires with C{None} once the received data
            equals C{bytes}, or fails with L{unittest.FailTest} as soon as
            the received data stops being a prefix of C{bytes}.
        """
        assert self._waiting is None
        d = defer.Deferred()
        self._waiting = (d, bytes)
        # The buffer may already hold the full payload (or a bad one).
        self._checkWaiting()
        return d

    def dataReceived(self, bytes):
        self.buffer.append(bytes)
        self._checkWaiting()

    def _checkWaiting(self):
        """
        Resolve the outstanding L{gotBytes} wait, if any: fire it when the
        received data matches, fail it once the data can no longer match.
        """
        if self._waiting is None:
            return
        d, expected = self._waiting
        received = ''.join(self.buffer)
        if received == expected:
            # Clear before firing so callbacks may start a new wait.
            self._waiting = None
            d.callback(None)
        elif not expected.startswith(received):
            # dataReceived only appends, so once the received data is not a
            # prefix of the expected data it can never become equal; fail
            # now instead of hanging until the test times out.
            self._waiting = None
            offset = len(os.path.commonprefix([received, expected]))
            d.errback(unittest.FailTest(
                'received data diverges from expected data at offset %d: '
                'got %r, expected %r' % (
                    offset,
                    received[offset:offset + 100],
                    expected[offset:offset + 100])))


class Django(protocol.ClientFactory):
    """
    Client factory whose C{onConnect} Deferred fires with the protocol
    instance it builds, or fails with the reason passed to
    L{clientConnectionFailed}.
    """
    def __init__(self):
        self.onConnect = defer.Deferred()

    def buildProtocol(self, addr):
        p = protocol.ClientFactory.buildProtocol(self, addr)
        self.onConnect.callback(p)
        return p

    def clientConnectionFailed(self, conn, err):
        self.onConnect.errback(err)


class ConnectedPTCPMixin:
    """
    Mixin for L{unittest.TestCase} subclasses that builds a PTCP server and
    client, each on its own ephemeral UDP port, and waits for their
    connections to close on teardown.
    """
    serverPort = None
    clientPort = None
    serverTransport = None
    clientTransport = None

    def setUpForATest(self,
                      ServerProtocol=TestProtocol, ClientProtocol=TestProtocol):
        """
        Create one server and one client protocol instance, a PTCP endpoint
        for each, and listen on an ephemeral UDP port with each endpoint.
        This method does not itself call C{connect}.

        @return: C{(serverProto, clientProto, serverFactory, clientFactory,
            serverTransport, clientTransport, serverPort, clientPort)}
        """
        serverProto = ServerProtocol()
        clientProto = ClientProtocol()

        self.serverProto = serverProto
        self.clientProto = clientProto

        # Both factories hand out the single pre-built protocol instance.
        sf = protocol.ServerFactory()
        sf.protocol = lambda: serverProto

        cf = Django()
        cf.protocol = lambda: clientProto

        serverTransport = ptcp.PTCP(sf)
        clientTransport = ptcp.PTCP(None)

        self.serverTransport = serverTransport
        self.clientTransport = clientTransport

        serverPort = reactor.listenUDP(0, serverTransport)
        clientPort = reactor.listenUDP(0, clientTransport)

        self.clientPort = clientPort
        self.serverPort = serverPort

        return (
            serverProto, clientProto,
            sf, cf,
            serverTransport, clientTransport,
            serverPort, clientPort
        )

    def tearDown(self):
        """
        Return a Deferred that fires once C{waitForAllConnectionsToClose}
        has fired for every PTCP endpoint created by L{setUpForATest}.
        Endpoints that were never created are skipped.
        """
        waits = [
            transport.waitForAllConnectionsToClose()
            for transport in (self.serverTransport, self.clientTransport)
            if transport is not None
        ]
        return defer.DeferredList(waits)


class TestProducerProtocol(protocol.Protocol):
    """
    Protocol that registers itself as a non-streaming producer on its
    transport and writes C{NUM_WRITES} chunks of C{WRITE_SIZE} bytes (chunk
    C{n} is C{chr(n)} repeated), dropping the connection right after the
    last write.

    @ivar onPaused: fired with C{None} the first time the transport calls
        L{pauseProducing}; set to C{None} afterwards.
    """
    NUM_WRITES = 32
    WRITE_SIZE = 32

    def __init__(self):
        self.onConnect = defer.Deferred()
        self.onPaused = defer.Deferred()

    def connectionMade(self):
        self.onConnect.callback(None)
        self.count = -1
        self.transport.registerProducer(self, False)

    def pauseProducing(self):
        if self.onPaused is not None:
            self.onPaused.callback(None)
            self.onPaused = None

    def resumeProducing(self):
        self.count += 1
        if self.count < self.NUM_WRITES:
            bytes = chr(self.count) * self.WRITE_SIZE
            self.transport.write(bytes)
            if self.count == self.NUM_WRITES - 1:
                # Last time through, intentionally drop the connection before
                # the buffer is empty to ensure we handle this case properly.
                self.transport.loseConnection()
        else:
            self.transport.unregisterProducer()


class PTCPTransportTestCase(ConnectedPTCPMixin, unittest.TestCase):
    def setUp(self):
        """
        Divide C{PTCPConnection._retransmitTimeout} by 10 so the tests run
        faster, and multiply C{PTCPPacket.retransmitCount} by 10.  The
        product of the two settings is unchanged, presumably so a packet is
        still abandoned after roughly the same total time.
        NOTE(review): confirm against ptcp's retransmission logic.
        """
        self.patch(
            ptcp.PTCPConnection, '_retransmitTimeout',
            ptcp.PTCPConnection._retransmitTimeout / 10)
        self.patch(
            ptcp.PTCPPacket, 'retransmitCount',
            ptcp.PTCPPacket.retransmitCount * 10)

    def xtestWhoAmI(self):
        """
        Check that C{whoami()} on each side's transport reports that side's
        local UDP port.

        Disabled: the C{x} prefix keeps trial from collecting this test
        (reason not recorded).
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        def gotAddress(results):
            (serverSuccess, serverAddress), (clientSuccess, clientAddress) = results
            self.failUnless(serverSuccess)
            self.failUnless(clientSuccess)

            self.assertEquals(serverAddress[1], serverPort.getHost().port)
            self.assertEquals(clientAddress[1], clientPort.getHost().port)

        def connectionsMade(ignored):
            return defer.DeferredList([
                serverProto.transport.whoami(),
                clientProto.transport.whoami(),
            ]).addCallback(gotAddress)

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        return defer.DeferredList(
            [serverProto.onConnect, clientProto.onConnect]
        ).addCallback(connectionsMade)

    def testVerySimpleConnection(self):
        """
        Alternately send an 18000-byte payload client-to-server and
        server-to-client ten times, then close both sides.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        def sendSomeBytes(ignored, n=10, server=False):
            if n:
                bytes = 'not a lot of bytes' * 1000
                if server:
                    serverProto.transport.write(bytes)
                else:
                    clientProto.transport.write(bytes)
                if server:
                    clientProto.buffer = []
                    d = clientProto.gotBytes(bytes)
                else:
                    serverProto.buffer = []
                    d = serverProto.gotBytes(bytes)
                return d.addCallback(sendSomeBytes, n - 1, not server)

        def loseConnections(ignored):
            serverProto.transport.loseConnection()
            clientProto.transport.loseConnection()
            return defer.DeferredList([
                serverProto.onDisconn,
                clientProto.onDisconn
            ])

        dl = defer.DeferredList([serverProto.onConnect, clientProto.onConnect])
        dl.addCallback(sendSomeBytes)
        dl.addCallback(loseConnections)
        return dl

    def testProducerConsumer(self):
        """
        Everything a producer-driven server writes reaches the client, even
        though the server drops the connection right after its last write.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest(
            ServerProtocol=TestProducerProtocol)

        def disconnected(ignored):
            self.assertEquals(
                ''.join(clientProto.buffer),
                ''.join([chr(n) * serverProto.WRITE_SIZE
                         for n in range(serverProto.NUM_WRITES)]))

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        return clientProto.onDisconn.addCallback(disconnected)

    def testTransportProducer(self):
        """
        Data written to a paused client transport is delivered only after
        the client transport is resumed (two seconds later).
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        resumed = []
        def resumeProducing():
            resumed.append(True)
            clientProto.transport.resumeProducing()

        def cbBytes(ignored):
            self.failUnless(resumed)
            clientProto.transport.loseConnection()

        def cbConnect(ignored):
            BYTES = 'Here are bytes'
            clientProto.transport.pauseProducing()
            serverProto.transport.write(BYTES)
            delayed = reactor.callLater(2, resumeProducing)
            # If the test ends before the delayed call fires, cancel it so
            # trial does not also report a dirty reactor.
            self.addCleanup(lambda: delayed.active() and delayed.cancel())
            return clientProto.gotBytes(BYTES).addCallback(cbBytes)

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        connD = defer.DeferredList([clientProto.onConnect, serverProto.onConnect])
        connD.addCallback(cbConnect)
        return connD

    def testTransportProducerProtocolProducer(self):
        """
        With the client transport paused, a producer-driven server is paused
        once it has written enough; resuming the client then lets all of the
        server's data through.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest(
            ServerProtocol=TestProducerProtocol)

        paused = []
        def cbPaused(ignored):
            paused.append(True)
            clientProto.transport.resumeProducing()
        serverProto.onPaused.addCallback(cbPaused)

        def cbBytes(ignored):
            self.assertEquals(
                ''.join(clientProto.buffer),
                ''.join([chr(n) * serverProto.WRITE_SIZE
                         for n in range(serverProto.NUM_WRITES)]))

        def cbConnect(ignored):
            # The server must write enough to completely fill the outgoing buffer,
            # since our peer isn't ACKing /anything/ and our server waits for
            # writes to be acked before proceeding.
            serverProto.WRITE_SIZE = serverProto.transport.sendWindow * 5

            clientProto.transport.pauseProducing()
            return clientProto.onDisconn.addCallback(cbBytes)

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        connD = defer.DeferredList([clientProto.onConnect, serverProto.onConnect])
        connD.addCallback(cbConnect)
        return connD


class LossyTransportTestCase(PTCPTransportTestCase):
    """
    Re-run every L{PTCPTransportTestCase} test with roughly a third of the
    datagrams written by each UDP port dropped.
    """
    def setUpForATest(self, *a, **kw):
        results = PTCPTransportTestCase.setUpForATest(self, *a, **kw)
        serverPort, clientPort = results[-2:]
        serverPort.write = reallyLossy(serverPort.write)
        clientPort.write = reallyLossy(clientPort.write)
        return results


class SmallMTUTransportTestCase(PTCPTransportTestCase):
    """
    Re-run every L{PTCPTransportTestCase} test with each datagram written
    by either UDP port truncated to 128 bytes.
    """
    def setUpForATest(self, *a, **kw):
        results = PTCPTransportTestCase.setUpForATest(self, *a, **kw)
        serverPort, clientPort = results[-2:]
        serverPort.write = insufficientTransmitter(serverPort.write, 128)
        clientPort.write = insufficientTransmitter(clientPort.write, 128)
        return results


class TimeoutTestCase(ConnectedPTCPMixin, unittest.TestCase):
    def setUp(self):
        """
        Shorten the retransmit timeout so that tests finish more quickly.
        """
        self.patch(
            ptcp.PTCPConnection, '_retransmitTimeout',
            ptcp.PTCPConnection._retransmitTimeout / 10)

    def testConnectTimeout(self):
        """
        A connection attempt whose packets are never sent fails the client
        factory's C{onConnect} with L{error.TimeoutError}.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        clientTransport.sendPacket = lambda *a, **kw: None
        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)
        return self.assertFailure(cf.onConnect, error.TimeoutError)

    def testDataTimeout(self):
        """
        Once the server can no longer send packets, the client's connection
        is eventually lost.
        """
        (serverProto, clientProto,
         sf, cf,
         serverTransport, clientTransport,
         serverPort, clientPort) = self.setUpForATest()

        def cbConnected(ignored):
            serverProto.transport.ptcp.sendPacket = lambda *a, **kw: None
            clientProto.transport.write('Receive this data.')
            # The server has to send data too, or it will never time out:
            # there is no SO_KEEPALIVE-style option to make it notice.
            serverProto.transport.write('Send this data.')
            return clientProto.onDisconn

        clientTransport.connect(cf, '127.0.0.1', serverPort.getHost().port)

        d = defer.DeferredList([serverProto.onConnect, clientProto.onConnect])
        d.addCallback(cbConnected)
        return d