1# Copyright (c) Twisted Matrix Laboratories. 2# See LICENSE for details. 3 4""" 5Tests for L{twisted.internet.stdio}. 6""" 7 8import os, sys, itertools 9 10from twisted.trial import unittest 11from twisted.python import filepath, log 12from twisted.python.runtime import platform 13from twisted.internet import error, defer, protocol, stdio, reactor 14from twisted.test.test_tcp import ConnectionLostNotifyingProtocol 15 16 17# A short string which is intended to appear here and nowhere else, 18# particularly not in any random garbage output CPython unavoidable 19# generates (such as in warning text and so forth). This is searched 20# for in the output from stdio_test_lastwrite.py and if it is found at 21# the end, the functionality works. 22UNIQUE_LAST_WRITE_STRING = 'xyz123abc Twisted is great!' 23 24skipWindowsNopywin32 = None 25if platform.isWindows(): 26 try: 27 import win32process 28 except ImportError: 29 skipWindowsNopywin32 = ("On windows, spawnProcess is not available " 30 "in the absence of win32process.") 31 32 33class StandardIOTestProcessProtocol(protocol.ProcessProtocol): 34 """ 35 Test helper for collecting output from a child process and notifying 36 something when it exits. 37 38 @ivar onConnection: A L{defer.Deferred} which will be called back with 39 C{None} when the connection to the child process is established. 40 41 @ivar onCompletion: A L{defer.Deferred} which will be errbacked with the 42 failure associated with the child process exiting when it exits. 43 44 @ivar onDataReceived: A L{defer.Deferred} which will be called back with 45 this instance whenever C{childDataReceived} is called, or C{None} to 46 suppress these callbacks. 47 48 @ivar data: A C{dict} mapping file descriptors to strings containing all 49 bytes received from the child process on each file descriptor. 50 """ 51 onDataReceived = None 52 53 def __init__(self): 54 self.onConnection = defer.Deferred() 55 self.onCompletion = defer.Deferred() 56 self.data = {} 57 58 59 def connectionMade(self): 60 self.onConnection.callback(None) 61 62 63 def childDataReceived(self, name, bytes): 64 """ 65 Record all bytes received from the child process in the C{data} 66 dictionary. Fire C{onDataReceived} if it is not C{None}. 67 """ 68 self.data[name] = self.data.get(name, '') + bytes 69 if self.onDataReceived is not None: 70 d, self.onDataReceived = self.onDataReceived, None 71 d.callback(self) 72 73 74 def processEnded(self, reason): 75 self.onCompletion.callback(reason) 76 77 78 79class StandardInputOutputTestCase(unittest.TestCase): 80 81 skip = skipWindowsNopywin32 82 83 def _spawnProcess(self, proto, sibling, *args, **kw): 84 """ 85 Launch a child Python process and communicate with it using the 86 given ProcessProtocol. 87 88 @param proto: A L{ProcessProtocol} instance which will be connected 89 to the child process. 90 91 @param sibling: The basename of a file containing the Python program 92 to run in the child process. 93 94 @param *args: strings which will be passed to the child process on 95 the command line as C{argv[2:]}. 96 97 @param **kw: additional arguments to pass to L{reactor.spawnProcess}. 98 99 @return: The L{IProcessTransport} provider for the spawned process. 100 """ 101 import twisted 102 subenv = dict(os.environ) 103 subenv['PYTHONPATH'] = os.pathsep.join( 104 [os.path.abspath( 105 os.path.dirname(os.path.dirname(twisted.__file__))), 106 subenv.get('PYTHONPATH', '') 107 ]) 108 args = [sys.executable, 109 filepath.FilePath(__file__).sibling(sibling).path, 110 reactor.__class__.__module__] + list(args) 111 return reactor.spawnProcess( 112 proto, 113 sys.executable, 114 args, 115 env=subenv, 116 **kw) 117 118 119 def _requireFailure(self, d, callback): 120 def cb(result): 121 self.fail("Process terminated with non-Failure: %r" % (result,)) 122 def eb(err): 123 return callback(err) 124 return d.addCallbacks(cb, eb) 125 126 127 def test_loseConnection(self): 128 """ 129 Verify that a protocol connected to L{StandardIO} can disconnect 130 itself using C{transport.loseConnection}. 131 """ 132 errorLogFile = self.mktemp() 133 log.msg("Child process logging to " + errorLogFile) 134 p = StandardIOTestProcessProtocol() 135 d = p.onCompletion 136 self._spawnProcess(p, 'stdio_test_loseconn.py', errorLogFile) 137 138 def processEnded(reason): 139 # Copy the child's log to ours so it's more visible. 140 for line in file(errorLogFile): 141 log.msg("Child logged: " + line.rstrip()) 142 143 self.failIfIn(1, p.data) 144 reason.trap(error.ProcessDone) 145 return self._requireFailure(d, processEnded) 146 147 148 def test_readConnectionLost(self): 149 """ 150 When stdin is closed and the protocol connected to it implements 151 L{IHalfCloseableProtocol}, the protocol's C{readConnectionLost} method 152 is called. 153 """ 154 errorLogFile = self.mktemp() 155 log.msg("Child process logging to " + errorLogFile) 156 p = StandardIOTestProcessProtocol() 157 p.onDataReceived = defer.Deferred() 158 159 def cbBytes(ignored): 160 d = p.onCompletion 161 p.transport.closeStdin() 162 return d 163 p.onDataReceived.addCallback(cbBytes) 164 165 def processEnded(reason): 166 reason.trap(error.ProcessDone) 167 d = self._requireFailure(p.onDataReceived, processEnded) 168 169 self._spawnProcess( 170 p, 'stdio_test_halfclose.py', errorLogFile) 171 return d 172 173 174 def test_lastWriteReceived(self): 175 """ 176 Verify that a write made directly to stdout using L{os.write} 177 after StandardIO has finished is reliably received by the 178 process reading that stdout. 179 """ 180 p = StandardIOTestProcessProtocol() 181 182 # Note: the OS X bug which prompted the addition of this test 183 # is an apparent race condition involving non-blocking PTYs. 184 # Delaying the parent process significantly increases the 185 # likelihood of the race going the wrong way. If you need to 186 # fiddle with this code at all, uncommenting the next line 187 # will likely make your life much easier. It is commented out 188 # because it makes the test quite slow. 189 190 # p.onConnection.addCallback(lambda ign: __import__('time').sleep(5)) 191 192 try: 193 self._spawnProcess( 194 p, 'stdio_test_lastwrite.py', UNIQUE_LAST_WRITE_STRING, 195 usePTY=True) 196 except ValueError, e: 197 # Some platforms don't work with usePTY=True 198 raise unittest.SkipTest(str(e)) 199 200 def processEnded(reason): 201 """ 202 Asserts that the parent received the bytes written by the child 203 immediately after the child starts. 204 """ 205 self.assertTrue( 206 p.data[1].endswith(UNIQUE_LAST_WRITE_STRING), 207 "Received %r from child, did not find expected bytes." % ( 208 p.data,)) 209 reason.trap(error.ProcessDone) 210 return self._requireFailure(p.onCompletion, processEnded) 211 212 213 def test_hostAndPeer(self): 214 """ 215 Verify that the transport of a protocol connected to L{StandardIO} 216 has C{getHost} and C{getPeer} methods. 217 """ 218 p = StandardIOTestProcessProtocol() 219 d = p.onCompletion 220 self._spawnProcess(p, 'stdio_test_hostpeer.py') 221 222 def processEnded(reason): 223 host, peer = p.data[1].splitlines() 224 self.failUnless(host) 225 self.failUnless(peer) 226 reason.trap(error.ProcessDone) 227 return self._requireFailure(d, processEnded) 228 229 230 def test_write(self): 231 """ 232 Verify that the C{write} method of the transport of a protocol 233 connected to L{StandardIO} sends bytes to standard out. 234 """ 235 p = StandardIOTestProcessProtocol() 236 d = p.onCompletion 237 238 self._spawnProcess(p, 'stdio_test_write.py') 239 240 def processEnded(reason): 241 self.assertEqual(p.data[1], 'ok!') 242 reason.trap(error.ProcessDone) 243 return self._requireFailure(d, processEnded) 244 245 246 def test_writeSequence(self): 247 """ 248 Verify that the C{writeSequence} method of the transport of a 249 protocol connected to L{StandardIO} sends bytes to standard out. 250 """ 251 p = StandardIOTestProcessProtocol() 252 d = p.onCompletion 253 254 self._spawnProcess(p, 'stdio_test_writeseq.py') 255 256 def processEnded(reason): 257 self.assertEqual(p.data[1], 'ok!') 258 reason.trap(error.ProcessDone) 259 return self._requireFailure(d, processEnded) 260 261 262 def _junkPath(self): 263 junkPath = self.mktemp() 264 junkFile = file(junkPath, 'w') 265 for i in xrange(1024): 266 junkFile.write(str(i) + '\n') 267 junkFile.close() 268 return junkPath 269 270 271 def test_producer(self): 272 """ 273 Verify that the transport of a protocol connected to L{StandardIO} 274 is a working L{IProducer} provider. 275 """ 276 p = StandardIOTestProcessProtocol() 277 d = p.onCompletion 278 279 written = [] 280 toWrite = range(100) 281 282 def connectionMade(ign): 283 if toWrite: 284 written.append(str(toWrite.pop()) + "\n") 285 proc.write(written[-1]) 286 reactor.callLater(0.01, connectionMade, None) 287 288 proc = self._spawnProcess(p, 'stdio_test_producer.py') 289 290 p.onConnection.addCallback(connectionMade) 291 292 def processEnded(reason): 293 self.assertEqual(p.data[1], ''.join(written)) 294 self.failIf(toWrite, "Connection lost with %d writes left to go." % (len(toWrite),)) 295 reason.trap(error.ProcessDone) 296 return self._requireFailure(d, processEnded) 297 298 299 def test_consumer(self): 300 """ 301 Verify that the transport of a protocol connected to L{StandardIO} 302 is a working L{IConsumer} provider. 303 """ 304 p = StandardIOTestProcessProtocol() 305 d = p.onCompletion 306 307 junkPath = self._junkPath() 308 309 self._spawnProcess(p, 'stdio_test_consumer.py', junkPath) 310 311 def processEnded(reason): 312 self.assertEqual(p.data[1], file(junkPath).read()) 313 reason.trap(error.ProcessDone) 314 return self._requireFailure(d, processEnded) 315 316 317 def test_normalFileStandardOut(self): 318 """ 319 If L{StandardIO} is created with a file descriptor which refers to a 320 normal file (ie, a file from the filesystem), L{StandardIO.write} 321 writes bytes to that file. In particular, it does not immediately 322 consider the file closed or call its protocol's C{connectionLost} 323 method. 324 """ 325 onConnLost = defer.Deferred() 326 proto = ConnectionLostNotifyingProtocol(onConnLost) 327 path = filepath.FilePath(self.mktemp()) 328 self.normal = normal = path.open('w') 329 self.addCleanup(normal.close) 330 331 kwargs = dict(stdout=normal.fileno()) 332 if not platform.isWindows(): 333 # Make a fake stdin so that StandardIO doesn't mess with the *real* 334 # stdin. 335 r, w = os.pipe() 336 self.addCleanup(os.close, r) 337 self.addCleanup(os.close, w) 338 kwargs['stdin'] = r 339 connection = stdio.StandardIO(proto, **kwargs) 340 341 # The reactor needs to spin a bit before it might have incorrectly 342 # decided stdout is closed. Use this counter to keep track of how 343 # much we've let it spin. If it closes before we expected, this 344 # counter will have a value that's too small and we'll know. 345 howMany = 5 346 count = itertools.count() 347 348 def spin(): 349 for value in count: 350 if value == howMany: 351 connection.loseConnection() 352 return 353 connection.write(str(value)) 354 break 355 reactor.callLater(0, spin) 356 reactor.callLater(0, spin) 357 358 # Once the connection is lost, make sure the counter is at the 359 # appropriate value. 360 def cbLost(reason): 361 self.assertEqual(count.next(), howMany + 1) 362 self.assertEqual( 363 path.getContent(), 364 ''.join(map(str, range(howMany)))) 365 onConnLost.addCallback(cbLost) 366 return onConnLost 367 368 if platform.isWindows(): 369 test_normalFileStandardOut.skip = ( 370 "StandardIO does not accept stdout as an argument to Windows. " 371 "Testing redirection to a file is therefore harder.") 372