# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

"""
Test case for L{twisted.protocols.loopback}.
"""

from __future__ import division, absolute_import

from zope.interface import implementer

from twisted.python.compat import _PY3, intToBytes
from twisted.trial import unittest
from twisted.trial.util import suppress as SUPPRESS
from twisted.protocols import basic, loopback
from twisted.internet import defer
from twisted.internet.protocol import Protocol
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
from twisted.internet import reactor, interfaces



class SimpleProtocol(basic.LineReceiver):
    """
    A line-based protocol which records what happens to it.

    @ivar conn: A L{Deferred} fired with C{None} from C{connectionMade}.
    @ivar lines: Every line passed to C{lineReceived}, in order.
    @ivar connLost: Every reason passed to C{connectionLost}, in order.
    """
    def __init__(self):
        self.conn = defer.Deferred()
        self.lines = []
        self.connLost = []


    def connectionMade(self):
        self.conn.callback(None)


    def lineReceived(self, line):
        self.lines.append(line)


    def connectionLost(self, reason):
        self.connLost.append(reason)



class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} which answers each of the first three lines it
    receives with C{b"Hello N"} and closes the connection once the last
    line it has recorded is C{b"Hello 3"}.
    """
    i = 0

    def lineReceived(self, line):
        self.i += 1
        if self.i < 4:
            # by this point we should have connection closed,
            # but just in case we didn't we won't ever send 'Hello 4'
            self.sendLine(b"Hello " + intToBytes(self.i))
        SimpleProtocol.lineReceived(self, line)
        if self.lines[-1] == b"Hello 3":
            self.transport.loseConnection()



class LoopbackTestCaseMixin:
    """
    Tests shared by every loopback implementation.

    Subclasses set C{loopbackFunc} to the function under test. It is called
    with a server protocol and a client protocol, and may return a
    L{Deferred}.
    """
    def testRegularFunction(self):
        """
        A line sent by the server is delivered to the client, and each
        protocol sees exactly one connection loss.
        """
        s = SimpleProtocol()
        c = SimpleProtocol()

        def sendALine(result):
            s.sendLine(b"THIS IS LINE ONE!")
            s.transport.loseConnection()
        s.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(c.lines, [b"THIS IS LINE ONE!"])
            self.assertEqual(len(s.connLost), 1)
            self.assertEqual(len(c.connLost), 1)
        d = defer.maybeDeferred(self.loopbackFunc, s, c)
        d.addCallback(check)
        return d


    def testSneakyHiddenDoom(self):
        """
        Two L{DoomProtocol}s bounce lines off each other until one closes
        the connection from inside C{lineReceived}. Each side records the
        expected lines and sees exactly one connection loss.
        """
        s = DoomProtocol()
        c = DoomProtocol()

        def sendALine(result):
            s.sendLine(b"DOOM LINE")
        s.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(s.lines, [b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(
                c.lines, [b'DOOM LINE', b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(len(s.connLost), 1)
            self.assertEqual(len(c.connLost), 1)
        d = defer.maybeDeferred(self.loopbackFunc, s, c)
        d.addCallback(check)
        return d



class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    loopbackFunc = staticmethod(loopback.loopbackAsync)


    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """
        class TestProtocol(Protocol):
            transport = None
            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertNotEqual(client.transport, None)
        self.assertNotEqual(server.transport, None)


    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: The name of the transport method to test, C{"getHost"}
            or C{"getPeer"}.
        @param testServer: If true, check the server's transport; otherwise
            check the client's.
        """
        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            host = getattr(transport, get)()
            self.assertTrue(IAddress.providedBy(host))

        return d.addCallback(connected)


    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)


    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}
        """
        return self._hostpeertest("getPeer", True)


    def test_clientHost(self):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", False)


    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)


    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: The name of the method to test, C{"write"} or
            C{"writeSequence"}.
        @param testServer: If true, the server sends the greeting and the
            client receives it; otherwise the roles are reversed.
        @return: A L{Deferred} which fires once the greeting has arrived.
        """
        class GreeteeProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == b"bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                if write == "write":
                    self.transport.write(b"bytes")
                else:
                    self.transport.writeSequence([b"byt", b"es"])

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d


    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)


    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)


    def test_serverGreeting(self):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.
        """
        return self._greetingtest("write", True)


    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)


    def _producertest(self, producerClass):
        """
        Register a producer against a loopback transport and wait for the
        other side to receive everything it produces.

        @param producerClass: A callable which accepts a list of byte strings
            to produce and returns an object with a C{start(consumer)}
            method.
        @return: A L{Deferred} which fires with C{(client, server)} once the
            client has received all of the produced bytes.
        """
        toProduce = list(map(intToBytes, range(0, 10)))
        # Computed once rather than on every dataReceived call.
        expected = b''.join(toProduce)

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                # Hand the producer its own copy so it can consume the list.
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == expected:
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()

        loopback.loopbackAsync(server, client)
        return client.received


    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """
        @implementer(IPushProducer)
        class PushProducer(object):
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()
        d = self._producertest(PushProducer)

        def finished(results):
            (client, server) = results
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.")
        d.addCallback(finished)
        return d


    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """
        @implementer(IPullProducer)
        class PullProducer(object):
            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()
        return self._producertest(PullProducer)


    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """
        class Server(Protocol):
            def dataReceived(self, data):
                self.transport.write(b"bytes")

        class Client(Protocol):
            ready = False
            # Overwritten by dataReceived. Defaulting to False makes a client
            # that never receives data fail the assertion below instead of
            # raising AttributeError.
            wasReady = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                self.transport.write(b"foo")
                self.ready = True

            def dataReceived(self, data):
                self.wasReady = self.ready
                self.transport.loseConnection()


        server = Server()
        client = Client()
        d = loopback.loopbackAsync(client, server)
        def cbFinished(ignored):
            self.assertTrue(client.wasReady)
        d.addCallback(cbFinished)
        return d


    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []
        def dummyPolicy(queue, target):
            chunks = []
            while queue:
                chunks.append(queue.get())
            pumpCalls.append((target, chunks))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        self.assertEqual(pumpCalls, [])

        client.transport.write(b"foo")
        client.transport.write(b"bar")
        server.transport.write(b"baz")
        server.transport.write(b"quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary.  The implementation
                # happens to always deliver data to the client first.
                [(client, [b"baz", b"quux", None]),
                 (server, [b"foo", b"bar"])])
        finished.addCallback(cbComplete)
        return finished


    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} method once for each string in the queue passed to it.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        self.assertEqual(received, [b"foo", b"bar"])


    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(received, [b"foobar"])



class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    loopbackFunc = staticmethod(loopback.loopbackTCP)



class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    if interfaces.IReactorUNIX(reactor, None) is None:
        skip = "Current reactor does not support UNIX sockets"
    elif _PY3:
        skip = "UNIX sockets not supported on Python 3. See #6136"