# This module is based on the excellent work by Adam Bartoš who
# provided a lot of what went into the implementation here in
# the discussion of issue1602 in the Python bug tracker.
#
# There are some general differences with regard to how this works
# compared to the original patches, as we do not need to patch
# the entire interpreter but just work in our little world of
# echo and prompt.
import io
import sys
import time
import typing as t
from ctypes import byref
from ctypes import c_char
from ctypes import c_char_p
from ctypes import c_int
from ctypes import c_ssize_t
from ctypes import c_ulong
from ctypes import c_void_p
from ctypes import POINTER
from ctypes import py_object
from ctypes import Structure
from ctypes.wintypes import DWORD
from ctypes.wintypes import HANDLE
from ctypes.wintypes import LPCWSTR
from ctypes.wintypes import LPWSTR

from ._compat import _NonClosingTextIOWrapper

# Windows-only module. The assert also lets type checkers narrow the
# platform for the Windows-specific imports below (it is skipped under -O).
assert sys.platform == "win32"
import msvcrt  # noqa: E402
from ctypes import windll  # noqa: E402
from ctypes import WINFUNCTYPE  # noqa: E402

c_ssize_p = POINTER(c_ssize_t)

kernel32 = windll.kernel32
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", windll.kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", windll.kernel32))

# -10 / -11 / -12 are the Win32 STD_INPUT_HANDLE / STD_OUTPUT_HANDLE /
# STD_ERROR_HANDLE selectors for GetStdHandle.
STDIN_HANDLE = GetStdHandle(-10)
STDOUT_HANDLE = GetStdHandle(-11)
STDERR_HANDLE = GetStdHandle(-12)

# Flags for PyObject_GetBuffer: a plain read-only view, or a writable one.
PyBUF_SIMPLE = 0
PyBUF_WRITABLE = 1

# Windows error codes this module inspects or reports by name.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

STDIN_FILENO = 0
STDOUT_FILENO = 1
STDERR_FILENO = 2

# Ctrl+Z (SUB, 0x1A): a console read whose first character is this is
# treated as end of input by _WindowsConsoleReader.readinto.
EOF = b"\x1a"
# Upper bound on the bytes handed to a single WriteConsoleW call.
# NOTE(review): presumably chosen because very large console writes can fail
# with ERROR_NOT_ENOUGH_MEMORY (which the writer reports by name) — confirm.
MAX_BYTES_WRITTEN = 32767

try:
    from ctypes import pythonapi
except ImportError:
    # On PyPy we cannot get buffers so our ability to operate here is
    # severely limited.
    get_buffer = None
else:

    class Py_buffer(Structure):
        """ctypes mirror of CPython's ``Py_buffer`` struct."""

        _fields_ = [
            ("buf", c_void_p),
            ("obj", py_object),
            ("len", c_ssize_t),
            ("itemsize", c_ssize_t),
            ("readonly", c_int),
            ("ndim", c_int),
            ("format", c_char_p),
            ("shape", c_ssize_p),
            ("strides", c_ssize_p),
            ("suboffsets", c_ssize_p),
            ("internal", c_void_p),
        ]

    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release

    def get_buffer(obj, writable=False):
        """Return a ``c_char`` array aliasing the memory of *obj*.

        The memory is obtained through the buffer protocol (writable when
        *writable* is true). The buffer view is released before returning,
        and the returned array does not own the memory, so the caller must
        keep *obj* alive and unresized while the array is in use.
        """
        buf = Py_buffer()
        flags = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
        PyObject_GetBuffer(py_object(obj), byref(buf), flags)

        try:
            buffer_type = c_char * buf.len
            return buffer_type.from_address(buf.buf)
        finally:
            PyBuffer_Release(byref(buf))


class _WindowsConsoleRawIOBase(io.RawIOBase):
    """Raw I/O base class bound to a Win32 console handle."""

    def __init__(self, handle):
        self.handle = handle

    def isatty(self):
        # Call the base implementation for its closed-stream check (CPython's
        # IOBase.isatty raises ValueError on a closed stream), then report True.
        super().isatty()
        return True


class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
    """Raw reader that pulls UTF-16-LE text from the console via ReadConsoleW."""

    def readable(self):
        return True

    def readinto(self, b):
        """Read UTF-16-LE encoded console input into the writable buffer *b*.

        Returns the number of bytes stored in *b*. Returns 0 if *b* is empty
        or if the input starts with Ctrl+Z (treated as end of input).

        Raises ValueError if ``len(b)`` is odd, because a UTF-16 code unit is
        two bytes. Raises OSError if ReadConsoleW reports failure.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError(
                "cannot read odd number of bytes from UTF-16-LE encoded console"
            )

        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        rv = ReadConsoleW(
            HANDLE(self.handle),
            buffer,
            code_units_to_be_read,
            byref(code_units_read),
            None,
        )
        # Capture the error code once, right after ReadConsoleW, so the value
        # reported below is not affected by anything that runs in between.
        last_error = GetLastError()
        if last_error == ERROR_OPERATION_ABORTED:
            # wait for KeyboardInterrupt (presumably gives the pending Ctrl+C
            # handler a chance to raise before we continue)
            time.sleep(0.1)
        if not rv:
            raise OSError(f"Windows error: {last_error}")

        # Ctrl+Z arrives as the UTF-16 code unit 0x001A, i.e. the bytes
        # b"\x1a\x00". Compare both bytes: checking only the low byte would
        # mistake characters such as U+021A ("Ț", b"\x1a\x02") for EOF.
        # b has an even, non-zero length here, so index 1 is in range.
        if buffer[0] == EOF and buffer[1] == b"\x00":
            return 0
        return 2 * code_units_read.value


class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
    """Raw writer that sends UTF-16-LE bytes to the console via WriteConsoleW."""

    def writable(self):
        return True

    @staticmethod
    def _get_error_message(errno):
        """Map a Windows error code to a short message for OSError."""
        if errno == ERROR_SUCCESS:
            return "ERROR_SUCCESS"
        elif errno == ERROR_NOT_ENOUGH_MEMORY:
            return "ERROR_NOT_ENOUGH_MEMORY"
        return f"Windows error {errno}"

    def write(self, b):
        """Write at most ``MAX_BYTES_WRITTEN`` bytes of *b* to the console.

        Returns the number of bytes actually written, which is always even and
        may be less than ``len(b)``. The raw I/O protocol allows partial writes.

        Raises OSError if nothing was written from a non-empty *b*.
        """
        bytes_to_be_written = len(b)
        buf = get_buffer(b)
        code_units_to_be_written = min(bytes_to_be_written, MAX_BYTES_WRITTEN) // 2
        code_units_written = c_ulong()

        WriteConsoleW(
            HANDLE(self.handle),
            buf,
            code_units_to_be_written,
            byref(code_units_written),
            None,
        )
        bytes_written = 2 * code_units_written.value

        if bytes_written == 0 and bytes_to_be_written > 0:
            raise OSError(self._get_error_message(GetLastError()))
        return bytes_written


class ConsoleStream:
    """Text stream that writes ``str`` through the console text stream.

    ``bytes`` are written through the original byte stream, which is exposed
    as ``.buffer``. Any other attribute is forwarded to the text stream.
    """

    def __init__(self, text_stream: t.TextIO, byte_stream: t.BinaryIO) -> None:
        self._text_stream = text_stream
        self.buffer = byte_stream

    @property
    def name(self) -> str:
        return self.buffer.name

    def write(self, x: t.AnyStr) -> int:
        """Write *x*: ``str`` via the text stream, anything else via ``.buffer``."""
        if isinstance(x, str):
            return self._text_stream.write(x)
        # Best-effort flush of pending text so it is emitted before the raw
        # bytes; a failing flush must not prevent the byte write.
        try:
            self.flush()
        except Exception:
            pass
        return self.buffer.write(x)

    def writelines(self, lines: t.Iterable[t.AnyStr]) -> None:
        for line in lines:
            self.write(line)

    def __getattr__(self, name: str) -> t.Any:
        return getattr(self._text_stream, name)

    def isatty(self) -> bool:
        return self.buffer.isatty()

    def __repr__(self):
        return f"<ConsoleStream name={self.name!r} encoding={self.encoding!r}>"


def _get_text_stdin(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build a UTF-16-LE console text stream for stdin around *buffer_stream*."""
    text_stream = _NonClosingTextIOWrapper(
        io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)),
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))


def _get_text_stdout(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build a UTF-16-LE console text stream for stdout around *buffer_stream*."""
    text_stream = _NonClosingTextIOWrapper(
        io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)),
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))


def _get_text_stderr(buffer_stream: t.BinaryIO) -> t.TextIO:
    """Build a UTF-16-LE console text stream for stderr around *buffer_stream*."""
    text_stream = _NonClosingTextIOWrapper(
        io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)),
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return t.cast(t.TextIO, ConsoleStream(text_stream, buffer_stream))


# File descriptor -> factory building the console text stream for that fd.
_stream_factories: t.Mapping[int, t.Callable[[t.BinaryIO], t.TextIO]] = {
    0: _get_text_stdin,
    1: _get_text_stdout,
    2: _get_text_stderr,
}


def _is_console(f: t.TextIO) -> bool:
    """Return True if *f* is backed by a Windows console handle.

    Returns False when *f* has no usable file descriptor, when the descriptor
    has no OS handle, or when GetConsoleMode fails on that handle.
    """
    if not hasattr(f, "fileno"):
        return False

    try:
        fileno = f.fileno()
        # get_osfhandle raises OSError if the descriptor is not recognized;
        # treat that as "not a console" rather than propagating.
        handle = msvcrt.get_osfhandle(fileno)
    except (OSError, io.UnsupportedOperation):
        return False

    return bool(GetConsoleMode(handle, byref(DWORD())))


def _get_windows_console_stream(
    f: t.TextIO, encoding: t.Optional[str], errors: t.Optional[str]
) -> t.Optional[t.TextIO]:
    """Return a console-backed replacement text stream for *f*, or None.

    A replacement is built only when all of the following hold:
    - buffer access is available (``get_buffer`` is not None);
    - *encoding* is None or "utf-16-le";
    - *errors* is None or "strict";
    - *f* is a console;
    - *f*'s descriptor has an entry in ``_stream_factories``;
    - *f* has a ``buffer`` attribute.

    Otherwise returns None.
    """
    if (
        get_buffer is not None
        and encoding in {"utf-16-le", None}
        and errors in {"strict", None}
        and _is_console(f)
    ):
        func = _stream_factories.get(f.fileno())
        if func is not None:
            b = getattr(f, "buffer", None)

            if b is None:
                return None

            return func(b)
    return None