# This module is based on the excellent work by Adam Bartoš, who
# provided much of what went into the implementation here in the
# discussion of issue 1602 on the Python bug tracker.
4#
# This implementation differs from the original patches in that it
# does not need to patch the entire interpreter; it only has to work
# within the limited scope of echo and prompt.
9import io
10import sys
11import time
12import typing as t
13from ctypes import byref
14from ctypes import c_char
15from ctypes import c_char_p
16from ctypes import c_int
17from ctypes import c_ssize_t
18from ctypes import c_ulong
19from ctypes import c_void_p
20from ctypes import POINTER
21from ctypes import py_object
22from ctypes import Structure
23from ctypes.wintypes import DWORD
24from ctypes.wintypes import HANDLE
25from ctypes.wintypes import LPCWSTR
26from ctypes.wintypes import LPWSTR
27
28from ._compat import _NonClosingTextIOWrapper
29
30assert sys.platform == "win32"
31import msvcrt  # noqa: E402
32from ctypes import windll  # noqa: E402
33from ctypes import WINFUNCTYPE  # noqa: E402
34
# Pointer-to-Py_ssize_t type, used for the shape/strides/suboffsets fields of
# the ctypes ``Py_buffer`` structure.
c_ssize_p = POINTER(c_ssize_t)

# Win32 console functions looked up on ``windll`` (stdcall). No argtypes or
# restype are assigned, so ctypes applies its default conversions and returns
# results as a C ``int``.
# NOTE(review): GetStdHandle returns a pointer-sized HANDLE. The default
# ``int`` restype would truncate values wider than 32 bits on 64-bit Windows.
# Console handles appear to be small in practice; confirm before relying on it.
# Also note that ctypes presumably caches these function objects on the
# loader, so assigning argtypes/restype here could affect other users.
kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
# Explicitly prototyped (WINFUNCTYPE) helpers. They are not referenced in the
# code shown here; presumably they are used to read the Unicode command line.
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))

# Standard handles, looked up once at import time. -10, -11 and -12 are the
# Win32 STD_INPUT_HANDLE, STD_OUTPUT_HANDLE and STD_ERROR_HANDLE values.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

# Request flags passed to PyObject_GetBuffer (read-only vs. writable view).
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1

# Win32 error codes inspected after console reads and writes.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# Standard C file descriptor numbers.
STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# Ctrl+Z: a console read starting with this character is treated as end of input.
EOF = b"\x1a"
# Maximum number of bytes handed to a single WriteConsoleW call. Longer writes
# are reported back as partial writes.
# NOTE(review): presumably chosen to stay under the console's per-call write
# limit; confirm.
MAX_BYTES_WRITTEN = 32767
66
try:
    from ctypes import pythonapi
except ImportError:
    # PyPy has no ``pythonapi``, so raw buffer access (and therefore the
    # console streams) is unavailable there.
    get_buffer = None
else:

    class Py_buffer(Structure):
        """ctypes mirror of CPython's ``Py_buffer`` structure."""

        _fields_ = [
            ("buf", c_void_p),
            ("obj", py_object),
            ("len", c_ssize_t),
            ("itemsize", c_ssize_t),
            ("readonly", c_int),
            ("ndim", c_int),
            ("format", c_char_p),
            ("shape", c_ssize_p),
            ("strides", c_ssize_p),
            ("suboffsets", c_ssize_p),
            ("internal", c_void_p),
        ]

    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release

    def get_buffer(obj, writable=False):
        """Return a ``c_char`` array that aliases the memory of *obj*.

        :param obj: Any object supporting the buffer protocol.
        :param writable: Request a writable view (for reading into *obj*).
        :return: A ctypes array of ``len(view)`` bytes at the buffer address.

        The buffer view is released before this function returns, and the
        returned array holds no reference to *obj*. The caller must keep *obj*
        alive (and not resize it) while using the array.
        """
        view = Py_buffer()
        PyObject_GetBuffer(
            py_object(obj),
            byref(view),
            PyBUF_WRITABLE if writable else PyBUF_SIMPLE,
        )
        try:
            return (c_char * view.len).from_address(view.buf)
        finally:
            PyBuffer_Release(byref(view))
103
104
105class _WindowsConsoleRawIOBase(io.RawIOBase):
106    def __init__(self, handle):
107        self.handle = handle
108
109    def isatty(self):
110        super().isatty()
111        return True
112
113
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
    """Raw reader that pulls UTF-16-LE text from the console via ``ReadConsoleW``."""

    def readable(self):
        return True

    def readinto(self, b):
        """Fill *b* with UTF-16-LE encoded console input.

        :param b: A writable buffer. Its length must be even because the
            console is read in 2-byte UTF-16 code units.
        :return: The number of bytes stored in *b*. ``0`` means end of input:
            an empty *b*, an empty read, or a line that starts with Ctrl+Z.
        :raises ValueError: If ``len(b)`` is odd.
        :raises OSError: If ``ReadConsoleW`` reports failure.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError(
                "cannot read odd number of bytes from UTF-16-LE encoded console"
            )

        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        rv = ReadConsoleW(
            HANDLE(self.handle),
            buffer,
            code_units_to_be_read,
            byref(code_units_read),
            None,
        )
        # Capture the last-error value once, right after the call. Later calls
        # (including anything done while sleeping) may overwrite it.
        # NOTE(review): ctypes itself may make Win32 calls in between; a DLL
        # loaded with use_last_error=True plus ctypes.get_last_error() would be
        # more reliable.
        last_error = GetLastError()
        code_units = code_units_read.value

        # An interrupted read (e.g. Ctrl+C) is reported as
        # ERROR_OPERATION_ABORTED. The last-error value is only meaningful
        # when the call failed or returned nothing, because successful calls
        # need not reset it. Checking it unconditionally let a stale value
        # delay later, successful reads.
        if last_error == ERROR_OPERATION_ABORTED and (not rv or not code_units):
            # Give the interpreter time to deliver the pending KeyboardInterrupt.
            time.sleep(0.1)
        if not rv:
            raise OSError(f"Windows error: {last_error}")

        # End of input is a line starting with the code unit U+001A (Ctrl+Z),
        # i.e. the bytes 1A 00. Comparing only the first byte would also match
        # any character whose low byte is 0x1A, such as U+041A (Cyrillic К).
        if code_units and buffer[:2] == EOF + b"\x00":
            return 0
        return 2 * code_units
147
148
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
    """Raw writer that sends UTF-16-LE bytes to the console via ``WriteConsoleW``."""

    def writable(self):
        return True

    @staticmethod
    def _get_error_message(errno):
        """Return a readable label for the Win32 error code *errno*."""
        if errno == ERROR_SUCCESS:
            return "ERROR_SUCCESS"
        elif errno == ERROR_NOT_ENOUGH_MEMORY:
            return "ERROR_NOT_ENOUGH_MEMORY"
        return f"Windows error {errno}"

    def write(self, b):
        """Write UTF-16-LE encoded bytes from *b* to the console.

        At most ``MAX_BYTES_WRITTEN`` bytes are sent per call, so this may be a
        partial write. The caller (e.g. a BufferedWriter) retries the rest.

        :param b: A bytes-like object holding UTF-16-LE data.
        :return: The number of bytes actually written.
        :raises OSError: If nothing was written although *b* was non-empty.
        """
        bytes_to_be_written = len(b)
        buf = get_buffer(b)
        code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2

        # If the write is truncated, do not end on the leading (high) half of
        # a surrogate pair. Splitting a pair across two WriteConsoleW calls
        # leaves each half unpaired, which the console may render as
        # replacement characters. The held-back unit goes out on the next call.
        # Never shrink to zero units, which would look like a failed write.
        if (
            code_units_to_be_written > 1
            and 2 * code_units_to_be_written < bytes_to_be_written
        ):
            # Little-endian: the high byte of code unit k is at byte 2k + 1.
            last_high_byte = ord(buf[2 * code_units_to_be_written - 1])
            if 0xD8 <= last_high_byte <= 0xDB:
                code_units_to_be_written -= 1

        code_units_written = c_ulong()

        WriteConsoleW(
            HANDLE(self.handle),
            buf,
            code_units_to_be_written,
            byref(code_units_written),
            None,
        )
        bytes_written = 2 * code_units_written.value

        if bytes_written == 0 and bytes_to_be_written > 0:
            raise OSError(self._get_error_message(GetLastError()))
        return bytes_written
179
180
class ConsoleStream:
    """Text stream that writes to the Windows console.

    ``str`` writes go to the console text stream. ``bytes`` writes go to the
    original binary stream, exposed as ``buffer``. Any other attribute is
    looked up on the text stream.
    """

    def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
        self._text_stream = text_stream
        self.buffer = byte_stream

    @property
    def name(self) -> str:
        return self.buffer.name

    def write(self, x: t.AnyStr) -> int:
        if isinstance(x, str):
            return self._text_stream.write(x)
        # Flush pending text first so that text and bytes come out in the
        # order they were written. This is best effort: a failing flush must
        # not stop the bytes from being written.
        try:
            self.flush()
        except Exception:
            pass
        return self.buffer.write(x)

    def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
        for line in lines:
            self.write(line)

    def __getattr__(self, name: str) -> t.Any:
        # Only called when normal lookup fails. On an instance created without
        # __init__ (e.g. by copy or pickle), ``_text_stream`` is missing.
        # Delegating in that case would call back into __getattr__ forever, so
        # raise AttributeError instead.
        if name == "_text_stream":
            raise AttributeError(name)
        return getattr(self._text_stream, name)

    def isatty(self) -> bool:
        return self.buffer.isatty()

    def __repr__(self):
        return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"
211
212
def _make_console_stream(
    console_buffer: io.BufferedIOBase, buffer_stream: t.BinaryIO
) -> t.TextIO:
    """Wrap a buffered console reader or writer as a UTF-16-LE ``ConsoleStream``.

    :param console_buffer: Buffered stream over a raw console reader/writer.
    :param buffer_stream: The original binary stream, exposed as ``buffer``.
    """
    # The console is always accessed as strict UTF-16-LE with line buffering.
    text_stream = _NonClosingTextIOWrapper(
        console_buffer,
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))


def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream for standard input."""
    return _make_console_stream(
        io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), buffer_stream
    )


def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream for standard output."""
    return _make_console_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), buffer_stream
    )


def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Return a console-backed text stream for standard error."""
    return _make_console_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), buffer_stream
    )
241
242
# Maps a standard file descriptor number to the factory that builds the
# console-backed text stream for that descriptor.
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
    STDIN_FILENO: _get_text_stdin,
    STDOUT_FILENO: _get_text_stdout,
    STDERR_FILENO: _get_text_stderr,
}
248
249
def _is_console(f: t.TextIO) -> bool:
    """Return whether *f* is backed by a Windows console handle.

    Any stream that cannot provide a usable OS handle is reported as not
    being a console instead of raising.
    """
    if not hasattr(f, "fileno"):
        return False

    try:
        fileno = f.fileno()
    except (OSError, io.UnsupportedOperation, ValueError):
        # No descriptor (e.g. an in-memory stream) or a closed stream.
        return False

    try:
        handle = msvcrt.get_osfhandle(fileno)
    except OSError:
        # The descriptor is not associated with a valid OS handle.
        return False

    # GetConsoleMode succeeds only for console handles. Pass the handle as a
    # pointer-sized HANDLE, as the other console calls in this module do.
    return bool(GetConsoleMode(HANDLE(handle), byref(DWORD())))
261
262
def _get_windows_console_stream(
    f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
    """Return a console-backed replacement for *f*, or ``None``.

    A replacement is built only if all of the following hold:

    - raw buffer access is available;
    - the requested encoding and error handler are compatible with the
      console's strict UTF-16-LE stream (or left unspecified);
    - *f* is attached to a console;
    - its descriptor has a factory in ``_stream_factories``;
    - it exposes a binary ``buffer``.
    """
    if get_buffer is None:
        return None
    if encoding not in {"utf-16-le", None} or errors not in {"strict", None}:
        return None
    if not _is_console(f):
        return None

    factory = _stream_factories.get(f.fileno())
    if factory is None:
        return None

    byte_stream = getattr(f, "buffer", None)
    if byte_stream is None:
        return None
    return factory(byte_stream)
280