1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for implementations of L{IReactorUNIX} and L{IReactorUNIXDatagram}.
6"""
7
8import stat, os, sys, types
9import socket
10
11from twisted.internet import interfaces, reactor, protocol, error, address, defer, utils
12from twisted.python import lockfile
13from twisted.trial import unittest
14
15from twisted.test.test_tcp import MyServerFactory, MyClientFactory
16
17
class FailedConnectionClientFactory(protocol.ClientFactory):
    """
    Client factory which reports a failed connection attempt by errbacking a
    caller-supplied L{defer.Deferred}.

    @ivar onFail: The L{defer.Deferred} errbacked with the C{reason} passed to
        C{clientConnectionFailed}.
    """

    def __init__(self, onFail):
        self.onFail = onFail


    def clientConnectionFailed(self, connector, reason):
        # Hand the failure straight to whoever is waiting on onFail.
        failureDeferred = self.onFail
        failureDeferred.errback(reason)
24
25
26
class UnixSocketTestCase(unittest.TestCase):
    """
    Test unix sockets.

    These tests exercise C{listenUNIX} and C{connectUNIX} of the global
    reactor.
    """
    def test_peerBind(self):
        """
        The address passed to the server factory's C{buildProtocol} method and
        the address returned by the connected protocol's transport's C{getPeer}
        method match the address the client socket is bound to.
        """
        filename = self.mktemp()
        peername = self.mktemp()
        serverFactory = MyServerFactory()
        connMade = serverFactory.protocolConnectionMade = defer.Deferred()
        unixPort = reactor.listenUNIX(filename, serverFactory)
        self.addCleanup(unixPort.stopListening)
        # A plain socket is used as the client so that it can be bound to a
        # known address before connecting.
        unixSocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        self.addCleanup(unixSocket.close)
        unixSocket.bind(peername)
        unixSocket.connect(filename)
        def cbConnMade(proto):
            expected = address.UNIXAddress(peername)
            self.assertEqual(serverFactory.peerAddresses, [expected])
            self.assertEqual(proto.transport.getPeer(), expected)
        connMade.addCallback(cbConnMade)
        return connMade


    def test_dumber(self):
        """
        L{IReactorUNIX.connectUNIX} can be used to connect a client to a server
        started with L{IReactorUNIX.listenUNIX}.
        """
        filename = self.mktemp()
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        unixPort = reactor.listenUNIX(filename, serverFactory)
        self.addCleanup(unixPort.stopListening)
        clientFactory = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientFactory.protocolConnectionMade = clientConnMade
        reactor.connectUNIX(filename, clientFactory)
        d = defer.gatherResults([serverConnMade, clientConnMade])
        def allConnected(protocols):
            # Unpack explicitly: tuple parameters are a syntax error on
            # Python 3 (PEP 3113).
            serverProtocol, clientProtocol = protocols

            # Incidental assertion which may or may not be redundant with some
            # other test.  This probably deserves its own test method.
            self.assertEqual(clientFactory.peerAddresses,
                             [address.UNIXAddress(filename)])

            clientProtocol.transport.loseConnection()
            serverProtocol.transport.loseConnection()
        d.addCallback(allConnected)
        return d


    def test_pidFile(self):
        """
        A lockfile is created and locked when L{IReactorUNIX.listenUNIX} is
        called and released when the Deferred returned by the L{IListeningPort}
        provider's C{stopListening} method is called back.
        """
        filename = self.mktemp()
        serverFactory = MyServerFactory()
        serverConnMade = defer.Deferred()
        serverFactory.protocolConnectionMade = serverConnMade
        unixPort = reactor.listenUNIX(filename, serverFactory, wantPID=True)
        self.assertTrue(lockfile.isLocked(filename + ".lock"))

        # XXX This part would test something about the checkPID parameter, but
        # it doesn't actually.  It should be rewritten to test the several
        # different possible behaviors.  -exarkun
        clientFactory = MyClientFactory()
        clientConnMade = defer.Deferred()
        clientFactory.protocolConnectionMade = clientConnMade
        reactor.connectUNIX(filename, clientFactory, checkPID=True)

        d = defer.gatherResults([serverConnMade, clientConnMade])
        def _portStuff(protocols):
            # Unpack explicitly: tuple parameters are a syntax error on
            # Python 3 (PEP 3113).
            serverProtocol, clientProto = protocols

            # Incidental assertion which may or may not be redundant with some
            # other test.  This probably deserves its own test method.
            self.assertEqual(clientFactory.peerAddresses,
                             [address.UNIXAddress(filename)])

            clientProto.transport.loseConnection()
            serverProtocol.transport.loseConnection()
            # Returning this Deferred makes _check wait until the port has
            # actually stopped listening.
            return unixPort.stopListening()
        d.addCallback(_portStuff)

        def _check(ignored):
            self.assertFalse(lockfile.isLocked(filename + ".lock"), 'locked')
        d.addCallback(_check)
        return d


    def test_socketLocking(self):
        """
        L{IReactorUNIX.listenUNIX} raises L{error.CannotListenError} if passed
        the name of a file on which a server is already listening.
        """
        filename = self.mktemp()
        serverFactory = MyServerFactory()
        unixPort = reactor.listenUNIX(filename, serverFactory, wantPID=True)

        self.assertRaises(
            error.CannotListenError,
            reactor.listenUNIX, filename, serverFactory, wantPID=True)

        def stoppedListening(ign):
            # Once the first port is gone, the same filename can be reused.
            unixPort = reactor.listenUNIX(filename, serverFactory, wantPID=True)
            return unixPort.stopListening()

        return unixPort.stopListening().addCallback(stoppedListening)


    def _uncleanSocketTest(self, callback):
        """
        Run a child Python process which calls C{listenUNIX} with
        C{wantPID=True} on a fresh filename and then exits without ever
        running or stopping its reactor, so it never cleans up after itself.

        @param callback: Called with the result of L{utils.getProcessValue}
            (the child's exit code) once the child has exited.  The filename
            the child listened on is available as C{self.filename}.

        @return: A L{defer.Deferred} which fires with the result of
            C{callback}.
        """
        self.filename = self.mktemp()
        source = ("from twisted.internet import protocol, reactor\n"
                  "reactor.listenUNIX(%r, protocol.ServerFactory(), wantPID=True)\n") % (self.filename,)
        # Give the child our sys.path so it imports the same Twisted.
        env = {'PYTHONPATH': os.pathsep.join(sys.path)}

        d = utils.getProcessValue(sys.executable, ("-u", "-c", source), env=env)
        d.addCallback(callback)
        return d


    def test_uncleanServerSocketLocking(self):
        """
        If passed C{True} for the C{wantPID} parameter, a server can be started
        listening with L{IReactorUNIX.listenUNIX} when passed the name of a
        file on which a previous server which has not exited cleanly has been
        listening using the C{wantPID} option.
        """
        def ranStupidChild(ign):
            # If this next call succeeds, our lock handling is correct.
            p = reactor.listenUNIX(self.filename, MyServerFactory(), wantPID=True)
            return p.stopListening()
        return self._uncleanSocketTest(ranStupidChild)


    def test_connectToUncleanServer(self):
        """
        If passed C{True} for the C{checkPID} parameter, a client connection
        attempt made with L{IReactorUNIX.connectUNIX} to a file on which a
        previous server (started with C{wantPID=True}) listened without
        exiting cleanly fails with L{error.BadFileError}.
        """
        def ranStupidChild(ign):
            d = defer.Deferred()
            f = FailedConnectionClientFactory(d)
            reactor.connectUNIX(self.filename, f, checkPID=True)
            return self.assertFailure(d, error.BadFileError)
        return self._uncleanSocketTest(ranStupidChild)


    def _reprTest(self, serverFactory, factoryName):
        """
        Test the C{__str__} and C{__repr__} implementations of a UNIX port when
        used with the given factory.
        """
        filename = self.mktemp()
        unixPort = reactor.listenUNIX(filename, serverFactory)

        connectedString = "<%s on %r>" % (factoryName, filename)
        self.assertEqual(repr(unixPort), connectedString)
        self.assertEqual(str(unixPort), connectedString)

        d = defer.maybeDeferred(unixPort.stopListening)
        def stoppedListening(ign):
            unconnectedString = "<%s (not listening)>" % (factoryName,)
            self.assertEqual(repr(unixPort), unconnectedString)
            self.assertEqual(str(unixPort), unconnectedString)
        d.addCallback(stoppedListening)
        return d


    def test_reprWithClassicFactory(self):
        """
        The two string representations of the L{IListeningPort} returned by
        L{IReactorUNIX.listenUNIX} contains the name of the classic factory
        class being used and the filename on which the port is listening or
        indicates that the port is not listening.
        """
        class ClassicFactory:
            def doStart(self):
                pass

            def doStop(self):
                pass

        # Sanity check
        self.assertIsInstance(ClassicFactory, types.ClassType)

        return self._reprTest(
            ClassicFactory(), "twisted.test.test_unix.ClassicFactory")


    def test_reprWithNewStyleFactory(self):
        """
        The two string representations of the L{IListeningPort} returned by
        L{IReactorUNIX.listenUNIX} contains the name of the new-style factory
        class being used and the filename on which the port is listening or
        indicates that the port is not listening.
        """
        class NewStyleFactory(object):
            def doStart(self):
                pass

            def doStop(self):
                pass

        # Sanity check
        self.assertIsInstance(NewStyleFactory, type)

        return self._reprTest(
            NewStyleFactory(), "twisted.test.test_unix.NewStyleFactory")
244
245
246
class ClientProto(protocol.ConnectedDatagramProtocol):
    """
    Connected datagram protocol which records whether it was started or
    stopped and the datagram it received.

    @ivar deferredStarted: Fired with C{None} from C{startProtocol}.
    @ivar deferredGotBack: Fired with C{None} after a datagram is received;
        its payload is stored in C{gotback}.
    """
    started = stopped = False
    gotback = None

    def __init__(self):
        self.deferredStarted, self.deferredGotBack = (
            defer.Deferred(), defer.Deferred())

    def stopProtocol(self):
        self.stopped = True

    def startProtocol(self):
        self.started = True
        notify = self.deferredStarted
        notify.callback(None)

    def datagramReceived(self, data):
        # Store the payload first so it is visible to callbacks.
        self.gotback = data
        notify = self.deferredGotBack
        notify.callback(None)
265
class ServerProto(protocol.DatagramProtocol):
    """
    Datagram protocol which answers each datagram with C{"hi back"} to its
    sender and records what it received and from where.

    @ivar deferredStarted: Fired with C{None} from C{startProtocol}.
    @ivar deferredGotWhat: Fired with C{None} after a datagram is received;
        the payload is stored in C{gotwhat} and the sender in C{gotfrom}.
    """
    started = stopped = False
    gotwhat = gotfrom = None

    def __init__(self):
        self.deferredStarted, self.deferredGotWhat = (
            defer.Deferred(), defer.Deferred())

    def stopProtocol(self):
        self.stopped = True

    def startProtocol(self):
        self.started = True
        notify = self.deferredStarted
        notify.callback(None)

    def datagramReceived(self, data, addr):
        # Record the sender and reply to it before storing the payload and
        # signalling the waiting test.
        self.gotfrom = addr
        self.transport.write("hi back", addr)
        self.gotwhat = data
        notify = self.deferredGotWhat
        notify.callback(None)
286
287
288
class DatagramUnixSocketTestCase(unittest.TestCase):
    """
    Test datagram UNIX sockets.
    """
    def test_exchange(self):
        """
        Test that a datagram can be sent to and received by a server and vice
        versa.
        """
        clientaddr = self.mktemp()
        serveraddr = self.mktemp()
        sp = ServerProto()
        cp = ClientProto()
        s = reactor.listenUNIXDatagram(serveraddr, sp)
        self.addCleanup(s.stopListening)
        c = reactor.connectUNIXDatagram(serveraddr, cp, bindAddress=clientaddr)
        self.addCleanup(c.stopListening)

        # Only send once both ends have started.
        d = defer.gatherResults([sp.deferredStarted, cp.deferredStarted])
        def write(ignored):
            cp.transport.write("hi")
            return defer.gatherResults([sp.deferredGotWhat,
                                        cp.deferredGotBack])

        def _cbTestExchange(ignored):
            self.assertEqual("hi", sp.gotwhat)
            self.assertEqual(clientaddr, sp.gotfrom)
            self.assertEqual("hi back", cp.gotback)

        d.addCallback(write)
        d.addCallback(_cbTestExchange)
        return d


    def test_cannotListen(self):
        """
        L{IReactorUNIXDatagram.listenUNIXDatagram} raises
        L{error.CannotListenError} if the unix socket specified is already in
        use.
        """
        addr = self.mktemp()
        p = ServerProto()
        s = reactor.listenUNIXDatagram(addr, p)
        # Register cleanup before asserting so it also runs on failure.
        # Cleanups run last-in first-out: stop listening first (trial waits
        # for the returned Deferred), then remove the socket file at addr.
        self.addCleanup(os.unlink, addr)
        self.addCleanup(s.stopListening)
        self.assertRaises(
            error.CannotListenError, reactor.listenUNIXDatagram, addr, p)

    # TODO(review): no test yet for connecting to an address which is bound
    # and already connected somewhere else.

    def _reprTest(self, serverProto, protocolName):
        """
        Test the C{__str__} and C{__repr__} implementations of a UNIX datagram
        port when used with the given protocol.
        """
        filename = self.mktemp()
        unixPort = reactor.listenUNIXDatagram(filename, serverProto)

        connectedString = "<%s on %r>" % (protocolName, filename)
        self.assertEqual(repr(unixPort), connectedString)
        self.assertEqual(str(unixPort), connectedString)

        stopDeferred = defer.maybeDeferred(unixPort.stopListening)
        def stoppedListening(ign):
            unconnectedString = "<%s (not listening)>" % (protocolName,)
            self.assertEqual(repr(unixPort), unconnectedString)
            self.assertEqual(str(unixPort), unconnectedString)
        stopDeferred.addCallback(stoppedListening)
        return stopDeferred


    def test_reprWithClassicProtocol(self):
        """
        The two string representations of the L{IListeningPort} returned by
        L{IReactorUNIXDatagram.listenUNIXDatagram} contains the name of the
        classic protocol class being used and the filename on which the port is
        listening or indicates that the port is not listening.
        """
        class ClassicProtocol:
            def makeConnection(self, transport):
                pass

            def doStop(self):
                pass

        # Sanity check
        self.assertIsInstance(ClassicProtocol, types.ClassType)

        return self._reprTest(
            ClassicProtocol(), "twisted.test.test_unix.ClassicProtocol")


    def test_reprWithNewStyleProtocol(self):
        """
        The two string representations of the L{IListeningPort} returned by
        L{IReactorUNIXDatagram.listenUNIXDatagram} contains the name of the
        new-style protocol class being used and the filename on which the port
        is listening or indicates that the port is not listening.
        """
        class NewStyleProtocol(object):
            def makeConnection(self, transport):
                pass

            def doStop(self):
                pass

        # Sanity check
        self.assertIsInstance(NewStyleProtocol, type)

        return self._reprTest(
            NewStyleProtocol(), "twisted.test.test_unix.NewStyleProtocol")
399
400
401
# Skip each test case when the reactor under test does not provide the
# interface it exercises.
if not interfaces.IReactorUNIX(reactor, None):
    UnixSocketTestCase.skip = (
        "This reactor does not support UNIX domain sockets")

if not interfaces.IReactorUNIXDatagram(reactor, None):
    DatagramUnixSocketTestCase.skip = (
        "This reactor does not support UNIX datagram sockets")
406