1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Test case for L{twisted.protocols.loopback}.
6"""
7
8from __future__ import division, absolute_import
9
10from zope.interface import implementer
11
12from twisted.python.compat import _PY3, intToBytes
13from twisted.trial import unittest
14from twisted.trial.util import suppress as SUPPRESS
15from twisted.protocols import basic, loopback
16from twisted.internet import defer
17from twisted.internet.protocol import Protocol
18from twisted.internet.defer import Deferred
19from twisted.internet.interfaces import IAddress, IPushProducer, IPullProducer
20from twisted.internet import reactor, interfaces
21
22
class SimpleProtocol(basic.LineReceiver):
    """
    A line-oriented protocol which records what happens to it.

    @ivar conn: A L{defer.Deferred} which is fired with C{None} from
        C{connectionMade}.
    @ivar lines: Every line passed to C{lineReceived}, in arrival order.
    @ivar connLost: Every reason passed to C{connectionLost}, in order.
    """
    def __init__(self):
        self.lines, self.connLost = [], []
        self.conn = defer.Deferred()

    def connectionMade(self):
        """
        Signal the connection by firing C{self.conn}.
        """
        connected = self.conn
        connected.callback(None)

    def lineReceived(self, line):
        """
        Record C{line}.
        """
        received = self.lines
        received.append(line)

    def connectionLost(self, reason):
        """
        Record the C{reason} the connection was lost.
        """
        lost = self.connLost
        lost.append(reason)
37
38
class DoomProtocol(SimpleProtocol):
    """
    A L{SimpleProtocol} which answers each of the first three lines it
    receives with C{b"Hello <n>"} and drops its connection as soon as it
    has itself received C{b"Hello 3"}.

    @ivar i: The number of lines received so far.
    """
    i = 0

    def lineReceived(self, line):
        """
        Count C{line}, reply while fewer than four lines have arrived,
        record it, and disconnect once C{b"Hello 3"} has been seen.
        """
        self.i = self.i + 1
        # By the fourth line the connection ought to be closed already;
        # the bound guarantees 'Hello 4' is never sent regardless.
        if self.i <= 3:
            self.sendLine(b"Hello " + intToBytes(self.i))
        SimpleProtocol.lineReceived(self, line)
        # SimpleProtocol.lineReceived just appended ``line``, so it is the
        # most recent entry in self.lines.
        if line == b"Hello 3":
            self.transport.loseConnection()
50
51
class LoopbackTestCaseMixin:
    """
    Tests shared by every loopback implementation.

    Concrete subclasses set C{loopbackFunc} to a callable which accepts a
    server protocol and a client protocol.  Its result is wrapped with
    L{defer.maybeDeferred}, so it may return a L{defer.Deferred} or not.
    """
    def _loopbackAndCheck(self, server, client, check):
        """
        Connect C{server} to C{client} using C{self.loopbackFunc} and run
        C{check} once that has finished.

        @return: The L{defer.Deferred} with C{check} added as a callback.
        """
        finished = defer.maybeDeferred(self.loopbackFunc, server, client)
        return finished.addCallback(check)

    def testRegularFunction(self):
        """
        A line sent by the server reaches the client, and closing the
        server's transport delivers exactly one C{connectionLost} to each
        side.
        """
        server, client = SimpleProtocol(), SimpleProtocol()

        def sendALine(ignored):
            server.sendLine(b"THIS IS LINE ONE!")
            server.transport.loseConnection()
        server.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(client.lines, [b"THIS IS LINE ONE!"])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        return self._loopbackAndCheck(server, client, check)

    def testSneakyHiddenDoom(self):
        """
        Two L{DoomProtocol}s exchange greetings until one of them receives
        C{b"Hello 3"} and disconnects. Both sides then see exactly one
        C{connectionLost}.
        """
        server, client = DoomProtocol(), DoomProtocol()

        def sendALine(ignored):
            server.sendLine(b"DOOM LINE")
        server.conn.addCallback(sendALine)

        def check(ignored):
            self.assertEqual(
                server.lines, [b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(
                client.lines,
                [b'DOOM LINE', b'Hello 1', b'Hello 2', b'Hello 3'])
            self.assertEqual(len(server.connLost), 1)
            self.assertEqual(len(client.connLost), 1)

        return self._loopbackAndCheck(server, client, check)
87
88
89
class LoopbackAsyncTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Tests for L{loopback.loopbackAsync}, in addition to the shared tests
    inherited from L{LoopbackTestCaseMixin}.
    """
    loopbackFunc = staticmethod(loopback.loopbackAsync)


    def test_makeConnection(self):
        """
        Test that the client and server protocol both have makeConnection
        invoked on them by loopbackAsync.
        """
        class TestProtocol(Protocol):
            transport = None
            def makeConnection(self, transport):
                self.transport = transport

        server = TestProtocol()
        client = TestProtocol()
        loopback.loopbackAsync(server, client)
        self.assertNotEqual(client.transport, None)
        self.assertNotEqual(server.transport, None)


    def _hostpeertest(self, get, testServer):
        """
        Test one of the permutations of client/server host/peer.

        @param get: The name of the transport method to test, C{"getHost"}
            or C{"getPeer"}.
        @param testServer: If true, inspect the server's transport.
            Otherwise inspect the client's transport.

        @return: A L{Deferred} which fires once the inspected protocol has
            been given a transport and the address it reports was checked.
        """
        class TestProtocol(Protocol):
            def makeConnection(self, transport):
                Protocol.makeConnection(self, transport)
                # onConnection is assigned below, before loopbackAsync runs.
                self.onConnection.callback(transport)

        if testServer:
            server = TestProtocol()
            d = server.onConnection = Deferred()
            client = Protocol()
        else:
            server = Protocol()
            client = TestProtocol()
            d = client.onConnection = Deferred()

        loopback.loopbackAsync(server, client)

        def connected(transport):
            address = getattr(transport, get)()
            self.assertTrue(IAddress.providedBy(address))

        return d.addCallback(connected)


    def test_serverHost(self):
        """
        Test that the server gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.
        """
        return self._hostpeertest("getHost", True)


    def test_serverPeer(self):
        """
        Like C{test_serverHost} but for L{ITransport.getPeer}
        """
        return self._hostpeertest("getPeer", True)


    def test_clientHost(self, get="getHost"):
        """
        Test that the client gets a transport with a properly functioning
        implementation of L{ITransport.getHost}.

        @param get: The transport method name to exercise.  Defaults to
            C{"getHost"}.  It was previously accepted but ignored, and is
            now honoured.
        """
        return self._hostpeertest(get, False)


    def test_clientPeer(self):
        """
        Like C{test_clientHost} but for L{ITransport.getPeer}.
        """
        return self._hostpeertest("getPeer", False)


    def _greetingtest(self, write, testServer):
        """
        Test one of the permutations of write/writeSequence client/server.

        @param write: The name of the method to test, C{"write"} or
            C{"writeSequence"}.
        @param testServer: If true, the server sends the greeting to the
            client.  Otherwise the client sends it to the server.

        @return: A L{Deferred} which fires once the receiving side has
            accumulated exactly C{b"bytes"}.
        """
        class GreeteeProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == b"bytes":
                    self.received.callback(None)

        class GreeterProtocol(Protocol):
            def connectionMade(self):
                if write == "write":
                    self.transport.write(b"bytes")
                else:
                    self.transport.writeSequence([b"byt", b"es"])

        if testServer:
            server = GreeterProtocol()
            client = GreeteeProtocol()
            d = client.received = Deferred()
        else:
            server = GreeteeProtocol()
            d = server.received = Deferred()
            client = GreeterProtocol()

        loopback.loopbackAsync(server, client)
        return d


    def test_clientGreeting(self):
        """
        Test that on a connection where the client speaks first, the server
        receives the bytes sent by the client.
        """
        return self._greetingtest("write", False)


    def test_clientGreetingSequence(self):
        """
        Like C{test_clientGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", False)


    def test_serverGreeting(self, write="write"):
        """
        Test that on a connection where the server speaks first, the client
        receives the bytes sent by the server.

        @param write: The transport method name used to send the greeting.
            Defaults to C{"write"}.  It was previously accepted but
            ignored, and is now honoured.
        """
        return self._greetingtest(write, True)


    def test_serverGreetingSequence(self):
        """
        Like C{test_serverGreeting}, but use C{writeSequence} instead of
        C{write} to issue the greeting.
        """
        return self._greetingtest("writeSequence", True)


    def _producertest(self, producerClass):
        """
        Connect a server that registers a producer on its transport to a
        client that collects everything the server sends.

        @param producerClass: A callable which takes a list of byte strings
            to write. It returns an object with a C{start(consumer)}
            method.

        @return: A L{Deferred} which fires with C{(client, server)} once the
            client has received all of the produced bytes.
        """
        toProduce = [intToBytes(i) for i in range(10)]
        expected = b"".join(toProduce)

        class ProducingProtocol(Protocol):
            def connectionMade(self):
                # Hand the producer its own copy.  The producers used by
                # these tests pop items off the list they are given.
                self.producer = producerClass(list(toProduce))
                self.producer.start(self.transport)

        class ReceivingProtocol(Protocol):
            receivedBytes = b""
            def dataReceived(self, data):
                self.receivedBytes += data
                if self.receivedBytes == expected:
                    self.received.callback((client, server))

        server = ProducingProtocol()
        client = ReceivingProtocol()
        client.received = Deferred()

        loopback.loopbackAsync(server, client)
        return client.received


    def test_pushProducer(self):
        """
        Test a push producer registered against a loopback transport.
        """
        @implementer(IPushProducer)
        class PushProducer(object):
            resumed = False

            def __init__(self, toProduce):
                self.toProduce = toProduce

            def resumeProducing(self):
                self.resumed = True

            def start(self, consumer):
                self.consumer = consumer
                consumer.registerProducer(self, True)
                self._produceAndSchedule()

            def _produceAndSchedule(self):
                if self.toProduce:
                    self.consumer.write(self.toProduce.pop(0))
                    reactor.callLater(0, self._produceAndSchedule)
                else:
                    self.consumer.unregisterProducer()
        d = self._producertest(PushProducer)

        def finished(results):
            _client, server = results
            self.assertFalse(
                server.producer.resumed,
                "Streaming producer should not have been resumed.")
        d.addCallback(finished)
        return d


    def test_pullProducer(self):
        """
        Test a pull producer registered against a loopback transport.
        """
        @implementer(IPullProducer)
        class PullProducer(object):
            def __init__(self, toProduce):
                self.toProduce = toProduce

            def start(self, consumer):
                self.consumer = consumer
                self.consumer.registerProducer(self, False)

            def resumeProducing(self):
                self.consumer.write(self.toProduce.pop(0))
                if not self.toProduce:
                    self.consumer.unregisterProducer()
        return self._producertest(PullProducer)


    def test_writeNotReentrant(self):
        """
        L{loopback.loopbackAsync} does not call a protocol's C{dataReceived}
        method while that protocol's transport's C{write} method is higher up
        on the stack.
        """
        class Server(Protocol):
            def dataReceived(self, data):
                self.transport.write(b"bytes")

        class Client(Protocol):
            ready = False
            # Initialised so the final assertion fails cleanly instead of
            # raising AttributeError if dataReceived is never reached.
            wasReady = False

            def connectionMade(self):
                reactor.callLater(0, self.go)

            def go(self):
                self.transport.write(b"foo")
                self.ready = True

            def dataReceived(self, data):
                self.wasReady = self.ready
                self.transport.loseConnection()


        server = Server()
        client = Client()
        d = loopback.loopbackAsync(client, server)
        def cbFinished(ignored):
            self.assertTrue(client.wasReady)
        d.addCallback(cbFinished)
        return d


    def test_pumpPolicy(self):
        """
        The callable passed as the value for the C{pumpPolicy} parameter to
        L{loopbackAsync} is called with a L{_LoopbackQueue} of pending bytes
        and a protocol to which they should be delivered.
        """
        pumpCalls = []
        def dummyPolicy(queue, target):
            chunks = []
            while queue:
                chunks.append(queue.get())
            pumpCalls.append((target, chunks))

        client = Protocol()
        server = Protocol()

        finished = loopback.loopbackAsync(server, client, dummyPolicy)
        self.assertEqual(pumpCalls, [])

        client.transport.write(b"foo")
        client.transport.write(b"bar")
        server.transport.write(b"baz")
        server.transport.write(b"quux")
        server.transport.loseConnection()

        def cbComplete(ignored):
            self.assertEqual(
                pumpCalls,
                # The order here is somewhat arbitrary.  The implementation
                # happens to always deliver data to the client first.
                [(client, [b"baz", b"quux", None]),
                 (server, [b"foo", b"bar"])])
        finished.addCallback(cbComplete)
        return finished


    def test_identityPumpPolicy(self):
        """
        L{identityPumpPolicy} is a pump policy that calls the target's
        C{dataReceived} method once for each string in the queue passed to
        it.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.identityPumpPolicy(queue, client)

        self.assertEqual(received, [b"foo", b"bar"])


    def test_collapsingPumpPolicy(self):
        """
        L{collapsingPumpPolicy} is a pump policy which calls the target's
        C{dataReceived} only once with all of the strings in the queue passed
        to it joined together.
        """
        received = []
        client = Protocol()
        client.dataReceived = received.append
        queue = loopback._LoopbackQueue()
        queue.put(b"foo")
        queue.put(b"bar")
        queue.put(None)

        loopback.collapsingPumpPolicy(queue, client)

        self.assertEqual(received, [b"foobar"])
418
419
420
class LoopbackTCPTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the shared L{LoopbackTestCaseMixin} tests through
    L{loopback.loopbackTCP}.
    """
    loopbackFunc = staticmethod(loopback.loopbackTCP)
423
424
class LoopbackUNIXTestCase(LoopbackTestCaseMixin, unittest.TestCase):
    """
    Run the shared L{LoopbackTestCaseMixin} tests through
    L{loopback.loopbackUNIX}.

    The tests are skipped when the reactor cannot be adapted to
    L{interfaces.IReactorUNIX}, or when running on Python 3.
    """
    loopbackFunc = staticmethod(loopback.loopbackUNIX)

    if interfaces.IReactorUNIX(reactor, None) is not None:
        if _PY3:
            skip = "UNIX sockets not supported on Python 3.  See #6136"
    else:
        skip = "Current reactor does not support UNIX sockets"
432