1# Copyright (c) Twisted Matrix Laboratories.
2# See LICENSE for details.
3
4"""
5Tests for L{twisted.internet.stdio}.
6"""
7
8import os, sys, itertools
9
10from twisted.trial import unittest
11from twisted.python import filepath, log
12from twisted.python.runtime import platform
13from twisted.internet import error, defer, protocol, stdio, reactor
14from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
15
16
# A short string which is intended to appear here and nowhere else,
# particularly not in any random garbage output CPython unavoidably
# generates (such as in warning text and so forth).  This is searched
# for in the output from stdio_test_lastwrite.py and if it is found at
# the end, the functionality works.
UNIQUE_LAST_WRITE_STRING = 'xyz123abc Twisted is great!'

# Skip message for the tests which spawn a child process, or None when
# they can run.  It is only set to a message on Windows when
# win32process cannot be imported.
if not platform.isWindows():
    skipWindowsNopywin32 = None
else:
    try:
        import win32process
    except ImportError:
        skipWindowsNopywin32 = ("On windows, spawnProcess is not available "
                                "in the absence of win32process.")
    else:
        skipWindowsNopywin32 = None
31
32
class StandardIOTestProcessProtocol(protocol.ProcessProtocol):
    """
    Process protocol used by the tests in this module to accumulate the
    output of a child process and to report events in its lifetime through
    L{defer.Deferred}s.

    @ivar onConnection: A L{defer.Deferred} which is called back with
        C{None} from C{connectionMade}.

    @ivar onCompletion: A L{defer.Deferred} which is fired from
        C{processEnded} with the C{reason} given to that method.  The tests
        in this module expect that reason to be a failure and handle it on
        the errback side.

    @ivar onDataReceived: C{None} (the default) to request no notification,
        or a L{defer.Deferred} which is called back with this instance the
        next time C{childDataReceived} is called.  The attribute is reset
        to C{None} before the L{defer.Deferred} fires, so each assigned
        L{defer.Deferred} fires at most once.

    @ivar data: A C{dict} mapping each file descriptor to a string made of
        all the bytes received from the child process on that descriptor.
    """
    onDataReceived = None

    def __init__(self):
        self.data = {}
        self.onConnection = defer.Deferred()
        self.onCompletion = defer.Deferred()


    def connectionMade(self):
        """
        Announce the established connection by firing C{onConnection}.
        """
        self.onConnection.callback(None)


    def childDataReceived(self, name, bytes):
        """
        Append C{bytes} to whatever has already been collected for file
        descriptor C{name} in the C{data} dictionary, then fire and clear
        C{onDataReceived} if one is set.
        """
        alreadyReceived = self.data.get(name, '')
        self.data[name] = alreadyReceived + bytes

        notification = self.onDataReceived
        if notification is None:
            return
        # Clear the attribute before firing so the Deferred can never be
        # called back a second time by a later chunk of data.
        self.onDataReceived = None
        notification.callback(self)


    def processEnded(self, reason):
        """
        Hand C{reason} to whoever is waiting on C{onCompletion}.
        """
        self.onCompletion.callback(reason)
76
77
78
class StandardInputOutputTestCase(unittest.TestCase):
    """
    Tests for L{stdio.StandardIO}.  Most of them run a helper script which
    lives next to this file in a child process and inspect what the child
    writes and how it exits.
    """

    skip = skipWindowsNopywin32

    def _spawnProcess(self, proto, sibling, *args, **kw):
        """
        Launch a child Python process and communicate with it using the
        given ProcessProtocol.

        @param proto: A L{ProcessProtocol} instance which will be connected
        to the child process.

        @param sibling: The basename of a file containing the Python program
        to run in the child process.

        @param *args: strings which will be passed to the child process on
        the command line as C{argv[2:]}.

        @param **kw: additional arguments to pass to L{reactor.spawnProcess}.

        @return: The L{IProcessTransport} provider for the spawned process.
        """
        import twisted
        # Put the directory containing the twisted package under test at the
        # front of the child's PYTHONPATH.
        subenv = dict(os.environ)
        subenv['PYTHONPATH'] = os.pathsep.join(
            [os.path.abspath(
                    os.path.dirname(os.path.dirname(twisted.__file__))),
             subenv.get('PYTHONPATH', '')
             ])
        # The child receives the module name of the parent's reactor as
        # argv[1]; the caller's extra arguments follow it.
        args = [sys.executable,
             filepath.FilePath(__file__).sibling(sibling).path,
             reactor.__class__.__module__] + list(args)
        return reactor.spawnProcess(
            proto,
            sys.executable,
            args,
            env=subenv,
            **kw)


    def _requireFailure(self, d, callback):
        """
        Require that C{d} fails, and pass the failure to C{callback}.

        @param d: A L{defer.Deferred} which is expected to fail.

        @param callback: A one-argument callable invoked with the failure;
        its return value becomes the next result of C{d}.

        @return: The value returned by C{d.addCallbacks}, suitable for
        returning from a test method.
        """
        def cb(result):
            self.fail("Process terminated with non-Failure: %r" % (result,))
        def eb(err):
            return callback(err)
        return d.addCallbacks(cb, eb)


    def test_loseConnection(self):
        """
        Verify that a protocol connected to L{StandardIO} can disconnect
        itself using C{transport.loseConnection}.
        """
        errorLogFile = self.mktemp()
        log.msg("Child process logging to " + errorLogFile)
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, 'stdio_test_loseconn.py', errorLogFile)

        def processEnded(reason):
            # Copy the child's log to ours so it's more visible.  The file
            # is closed explicitly rather than left to the garbage
            # collector.
            with open(errorLogFile) as childLog:
                for line in childLog:
                    log.msg("Child logged: " + line.rstrip())

            self.assertNotIn(1, p.data)
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def test_readConnectionLost(self):
        """
        When stdin is closed and the protocol connected to it implements
        L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method
        is called.
        """
        errorLogFile = self.mktemp()
        log.msg("Child process logging to " + errorLogFile)
        p = StandardIOTestProcessProtocol()
        p.onDataReceived = defer.Deferred()

        def cbBytes(ignored):
            # Once the child has written something, close its stdin and
            # wait for it to exit.
            d = p.onCompletion
            p.transport.closeStdin()
            return d
        p.onDataReceived.addCallback(cbBytes)

        def processEnded(reason):
            reason.trap(error.ProcessDone)
        d = self._requireFailure(p.onDataReceived, processEnded)

        self._spawnProcess(
            p, 'stdio_test_halfclose.py', errorLogFile)
        return d


    def test_lastWriteReceived(self):
        """
        Verify that a write made directly to stdout using L{os.write}
        after StandardIO has finished is reliably received by the
        process reading that stdout.
        """
        p = StandardIOTestProcessProtocol()

        # Note: the OS X bug which prompted the addition of this test
        # is an apparent race condition involving non-blocking PTYs.
        # Delaying the parent process significantly increases the
        # likelihood of the race going the wrong way.  If you need to
        # fiddle with this code at all, uncommenting the next line
        # will likely make your life much easier.  It is commented out
        # because it makes the test quite slow.

        # p.onConnection.addCallback(lambda ign: __import__('time').sleep(5))

        try:
            self._spawnProcess(
                p, 'stdio_test_lastwrite.py', UNIQUE_LAST_WRITE_STRING,
                usePTY=True)
        except ValueError as e:
            # Some platforms don't work with usePTY=True
            raise unittest.SkipTest(str(e))

        def processEnded(reason):
            """
            Asserts that the parent received the bytes written by the child
            immediately after the child starts.
            """
            self.assertTrue(
                p.data[1].endswith(UNIQUE_LAST_WRITE_STRING),
                "Received %r from child, did not find expected bytes." % (
                    p.data,))
            reason.trap(error.ProcessDone)
        return self._requireFailure(p.onCompletion, processEnded)


    def test_hostAndPeer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        has C{getHost} and C{getPeer} methods.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion
        self._spawnProcess(p, 'stdio_test_hostpeer.py')

        def processEnded(reason):
            # The child is expected to have written exactly two lines.
            host, peer = p.data[1].splitlines()
            self.assertTrue(host)
            self.assertTrue(peer)
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def test_write(self):
        """
        Verify that the C{write} method of the transport of a protocol
        connected to L{StandardIO} sends bytes to standard out.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        self._spawnProcess(p, 'stdio_test_write.py')

        def processEnded(reason):
            self.assertEqual(p.data[1], 'ok!')
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def test_writeSequence(self):
        """
        Verify that the C{writeSequence} method of the transport of a
        protocol connected to L{StandardIO} sends bytes to standard out.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        self._spawnProcess(p, 'stdio_test_writeseq.py')

        def processEnded(reason):
            self.assertEqual(p.data[1], 'ok!')
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def _junkPath(self):
        """
        Create a temporary file containing the decimal integers from 0
        through 1023, one per line.

        @return: The path of the newly written file.
        """
        junkPath = self.mktemp()
        # The context manager closes the file even if a write fails.
        with open(junkPath, 'w') as junkFile:
            for i in range(1024):
                junkFile.write(str(i) + '\n')
        return junkPath


    def test_producer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        is a working L{IProducer} provider.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        written = []
        # An explicit list, because items are popped off it below.
        toWrite = list(range(100))

        def connectionMade(ign):
            # Write one item, then reschedule until nothing is left.
            if toWrite:
                written.append(str(toWrite.pop()) + "\n")
                proc.write(written[-1])
                reactor.callLater(0.01, connectionMade, None)

        proc = self._spawnProcess(p, 'stdio_test_producer.py')

        p.onConnection.addCallback(connectionMade)

        def processEnded(reason):
            self.assertEqual(p.data[1], ''.join(written))
            self.assertFalse(
                toWrite,
                "Connection lost with %d writes left to go." % (
                    len(toWrite),))
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def test_consumer(self):
        """
        Verify that the transport of a protocol connected to L{StandardIO}
        is a working L{IConsumer} provider.
        """
        p = StandardIOTestProcessProtocol()
        d = p.onCompletion

        junkPath = self._junkPath()

        self._spawnProcess(p, 'stdio_test_consumer.py', junkPath)

        def processEnded(reason):
            # Read the expected bytes through a context manager so the
            # file is closed before the comparison is made.
            with open(junkPath) as junkFile:
                expected = junkFile.read()
            self.assertEqual(p.data[1], expected)
            reason.trap(error.ProcessDone)
        return self._requireFailure(d, processEnded)


    def test_normalFileStandardOut(self):
        """
        If L{StandardIO} is created with a file descriptor which refers to a
        normal file (ie, a file from the filesystem), L{StandardIO.write}
        writes bytes to that file.  In particular, it does not immediately
        consider the file closed or call its protocol's C{connectionLost}
        method.
        """
        onConnLost = defer.Deferred()
        proto = ConnectionLostNotifyingProtocol(onConnLost)
        path = filepath.FilePath(self.mktemp())
        self.normal = normal = path.open('w')
        self.addCleanup(normal.close)

        kwargs = dict(stdout=normal.fileno())
        if not platform.isWindows():
            # Make a fake stdin so that StandardIO doesn't mess with the *real*
            # stdin.
            r, w = os.pipe()
            self.addCleanup(os.close, r)
            self.addCleanup(os.close, w)
            kwargs['stdin'] = r
        connection = stdio.StandardIO(proto, **kwargs)

        # The reactor needs to spin a bit before it might have incorrectly
        # decided stdout is closed.  Use this counter to keep track of how
        # much we've let it spin.  If it closes before we expected, this
        # counter will have a value that's too small and we'll know.
        howMany = 5
        count = itertools.count()

        def spin():
            # Each call consumes exactly one value from the counter.
            value = next(count)
            if value == howMany:
                connection.loseConnection()
                return
            connection.write(str(value))
            reactor.callLater(0, spin)
        reactor.callLater(0, spin)

        # Once the connection is lost, make sure the counter is at the
        # appropriate value.
        def cbLost(reason):
            self.assertEqual(next(count), howMany + 1)
            self.assertEqual(
                path.getContent(),
                ''.join(map(str, range(howMany))))
        onConnLost.addCallback(cbLost)
        return onConnLost

    if platform.isWindows():
        test_normalFileStandardOut.skip = (
            "StandardIO does not accept stdout as an argument to Windows.  "
            "Testing redirection to a file is therefore harder.")
372