1# -*- test-case-name: twisted.test.test_stdio -*-
2
3"""
4Windows-specific implementation of the L{twisted.internet.stdio} interface.
5"""
6
7
8import msvcrt
9import os
10
11from zope.interface import implementer
12
13import win32api  # type: ignore[import]
14
15from twisted.internet import _pollingfile, main
16from twisted.internet.interfaces import (
17    IAddress,
18    IConsumer,
19    IHalfCloseableProtocol,
20    IPushProducer,
21    ITransport,
22)
23from twisted.python.failure import Failure
24
25
@implementer(IAddress)
class Win32PipeAddress:
    """
    Placeholder L{IAddress} returned for both ends of the Win32 standard I/O
    transport.  It carries no address data of its own.
    """
29
30
@implementer(ITransport, IConsumer, IPushProducer)
class StandardIO(_pollingfile._PollingTimer):
    """
    Transport connecting a protocol to the process's standard input and
    standard output on Windows.

    Standard input and standard output are wrapped in pollable pipe objects
    from L{_pollingfile} and serviced by the L{_pollingfile._PollingTimer}
    base class.
    """

    disconnecting = False
    disconnected = False

    def __init__(self, proto, reactor=None):
        """
        Start talking to standard IO with the given protocol.

        Also, put stdin/stdout/stderr into binary mode.

        @param proto: The protocol to connect to this transport; its
            C{makeConnection} is called at the end of construction.
        @param reactor: The reactor to poll with.  If C{None}, the global
            reactor is imported and used.
        """
        if reactor is None:
            from twisted.internet import reactor

        # Descriptors 0, 1 and 2 are stdin, stdout and stderr.  They are listed
        # explicitly: range(0, 1, 2) would yield only 0, leaving stdout and
        # stderr in text mode (subject to newline translation).
        for stdfd in (0, 1, 2):
            msvcrt.setmode(stdfd, os.O_BINARY)

        _pollingfile._PollingTimer.__init__(self, reactor)
        self.proto = proto

        hstdin = win32api.GetStdHandle(win32api.STD_INPUT_HANDLE)
        hstdout = win32api.GetStdHandle(win32api.STD_OUTPUT_HANDLE)

        # Data read from stdin is handed to self.dataReceived; closure of the
        # read side is reported through self.readConnectionLost.
        self.stdin = _pollingfile._PollableReadPipe(
            hstdin, self.dataReceived, self.readConnectionLost
        )

        # Closure of the write side is reported through
        # self.writeConnectionLost.
        self.stdout = _pollingfile._PollableWritePipe(hstdout, self.writeConnectionLost)

        self._addPollableResource(self.stdin)
        self._addPollableResource(self.stdout)

        self.proto.makeConnection(self)

    def dataReceived(self, data):
        """
        Forward bytes read from standard input to the protocol.

        @param data: The bytes read from stdin.
        """
        self.proto.dataReceived(data)

    def readConnectionLost(self):
        """
        Handle closure of standard input.

        Protocols providing L{IHalfCloseableProtocol} are told that the read
        side closed; the loss is then counted toward full disconnection.
        """
        if IHalfCloseableProtocol.providedBy(self.proto):
            self.proto.readConnectionLost()
        self.checkConnLost()

    def writeConnectionLost(self):
        """
        Handle closure of standard output.

        Protocols providing L{IHalfCloseableProtocol} are told that the write
        side closed; the loss is then counted toward full disconnection.
        """
        if IHalfCloseableProtocol.providedBy(self.proto):
            self.proto.writeConnectionLost()
        self.checkConnLost()

    # Number of halves (read side / write side) reported as closed so far.
    connsLost = 0

    def checkConnLost(self):
        """
        Record that one half of the connection has closed.

        Once two losses have been recorded (stdin and stdout), mark the
        transport as disconnected and deliver C{connectionLost} to the
        protocol with a L{main.CONNECTION_DONE} failure.
        """
        self.connsLost += 1
        if self.connsLost >= 2:
            self.disconnecting = True
            self.disconnected = True
            self.proto.connectionLost(Failure(main.CONNECTION_DONE))

    # ITransport

    def write(self, data):
        """
        Write bytes to standard output.

        @param data: The bytes to write.
        """
        self.stdout.write(data)

    def writeSequence(self, seq):
        """
        Write an iterable of byte strings to standard output as one write.

        @param seq: An iterable of bytes objects.
        """
        self.stdout.write(b"".join(seq))

    def loseConnection(self):
        """
        Begin disconnecting by closing both standard input and output.
        """
        self.disconnecting = True
        self.stdin.close()
        self.stdout.close()

    def getPeer(self):
        """
        @return: A new L{Win32PipeAddress} for the remote end.
        """
        return Win32PipeAddress()

    def getHost(self):
        """
        @return: A new L{Win32PipeAddress} for the local end.
        """
        return Win32PipeAddress()

    # IConsumer

    def registerProducer(self, producer, streaming):
        """
        Register a producer with the standard output pipe.

        @param producer: The producer to register.
        @param streaming: C{True} for a push producer, C{False} for a pull
            producer.
        """
        return self.stdout.registerProducer(producer, streaming)

    def unregisterProducer(self):
        """
        Unregister the producer from the standard output pipe.
        """
        return self.stdout.unregisterProducer()

    # def write() above

    # IProducer

    def stopProducing(self):
        """
        Stop reading from standard input.
        """
        self.stdin.stopProducing()

    # IPushProducer

    def pauseProducing(self):
        """
        Pause reading from standard input.
        """
        self.stdin.pauseProducing()

    def resumeProducing(self):
        """
        Resume reading from standard input.
        """
        self.stdin.resumeProducing()
129