1# -*- coding: utf-8 -*-
2# This module is based on the excellent work by Adam Bartoš who
3# provided a lot of what went into the implementation here in
4# the discussion to issue1602 in the Python bug tracker.
5#
6# There are some general differences in regards to how this works
7# compared to the original patches as we do not need to patch
8# the entire interpreter but just work in our little world of
# echo and prompt.
10import ctypes
11import io
12import os
13import sys
14import time
15import zlib
16from ctypes import byref
17from ctypes import c_char
18from ctypes import c_char_p
19from ctypes import c_int
20from ctypes import c_ssize_t
21from ctypes import c_ulong
22from ctypes import c_void_p
23from ctypes import POINTER
24from ctypes import py_object
25from ctypes import windll
26from ctypes import WinError
27from ctypes import WINFUNCTYPE
28from ctypes.wintypes import DWORD
29from ctypes.wintypes import HANDLE
30from ctypes.wintypes import LPCWSTR
31from ctypes.wintypes import LPWSTR
32
33import msvcrt
34
35from ._compat import _NonClosingTextIOWrapper
36from ._compat import PY2
37from ._compat import text_type
38
# ``ctypes.pythonapi`` gives access to the running interpreter's C API.
# It may be unavailable (e.g. on PyPy); in that case ``pythonapi`` is set
# to None and buffer access is disabled further down.
try:
    from ctypes import pythonapi
except ImportError:
    pythonapi = None
else:
    PyObject_GetBuffer = pythonapi.PyObject_GetBuffer
    PyBuffer_Release = pythonapi.PyBuffer_Release
46
47
c_ssize_p = POINTER(c_ssize_t)

kernel32 = windll.kernel32

# Console and error helpers, called with ctypes' default argument
# conversion (no argtypes/restype are declared for them).
# NOTE(review): reading the error through kernel32.GetLastError directly may
# observe a value overwritten by ctypes itself; the ctypes docs recommend
# use_last_error=True + ctypes.get_last_error() for reliable results.
GetStdHandle = kernel32.GetStdHandle
ReadConsoleW = kernel32.ReadConsoleW
WriteConsoleW = kernel32.WriteConsoleW
GetConsoleMode = kernel32.GetConsoleMode
GetLastError = kernel32.GetLastError

# Functions bound through explicit WINFUNCTYPE prototypes so that the
# wide-string arguments and results are converted correctly.
GetCommandLineW = WINFUNCTYPE(LPWSTR)(("GetCommandLineW", kernel32))
CommandLineToArgvW = WINFUNCTYPE(POINTER(LPWSTR), LPCWSTR, POINTER(c_int))(
    ("CommandLineToArgvW", windll.shell32)
)
LocalFree = WINFUNCTYPE(c_void_p, c_void_p)(("LocalFree", kernel32))

# Standard device handles: -10 / -11 / -12 are the Win32
# STD_INPUT_HANDLE / STD_OUTPUT_HANDLE / STD_ERROR_HANDLE selectors.
STDIN_HANDLE, STDOUT_HANDLE, STDERR_HANDLE = map(GetStdHandle, (-10, -11, -12))
68
69
# Request flags passed to PyObject_GetBuffer.
PyBUF_SIMPLE, PyBUF_WRITABLE = 0, 1

# Win32 error codes inspected by the console reader and writer.
ERROR_SUCCESS = 0
ERROR_NOT_ENOUGH_MEMORY = 8
ERROR_OPERATION_ABORTED = 995

# Standard file descriptor numbers.
STDIN_FILENO, STDOUT_FILENO, STDERR_FILENO = range(3)

# Ctrl+Z (SUB), which the console reader treats as end of input.
EOF = b"\x1a"
# Upper bound on the number of bytes handed to a single console write.
MAX_BYTES_WRITTEN = 32767
83
84
class Py_buffer(ctypes.Structure):
    """ctypes mirror of CPython's ``Py_buffer`` C struct (buffer protocol).

    ``PyObject_GetBuffer`` fills this structure by memory layout, so the
    field order and types must match the C definition of the running
    interpreter exactly.
    """

    _fields_ = [
        ("buf", c_void_p),  # start address of the exported memory
        ("obj", py_object),  # the exporting object
        ("len", c_ssize_t),  # size of the exported memory in bytes
        ("itemsize", c_ssize_t),
        ("readonly", c_int),
        ("ndim", c_int),
        ("format", c_char_p),
        ("shape", c_ssize_p),
        ("strides", c_ssize_p),
        ("suboffsets", c_ssize_p),
        ("internal", c_void_p),
    ]

    if PY2:
        # Python 2's struct has an extra ``smalltable`` member; insert it
        # just before the final ``internal`` field to keep the layout.
        _fields_.insert(-1, ("smalltable", c_ssize_t * 2))
102
103
# On PyPy we cannot get buffers, so our ability to operate here is
# severely limited.
if pythonapi is None:
    get_buffer = None
else:

    def get_buffer(obj, writable=False):
        """Return a ``c_char`` array aliasing the memory exported by *obj*.

        :param obj: any object supporting the buffer protocol.
        :param writable: request a writable buffer (PyBUF_WRITABLE) instead
                         of a simple read-only one (PyBUF_SIMPLE).
        :return: a ``c_char * len`` array located at the buffer's address.

        NOTE(review): the buffer view is released before the array is
        returned, so callers must keep *obj* alive (and unresized) for as
        long as they use the returned array.
        """
        view = Py_buffer()
        request = PyBUF_WRITABLE if writable else PyBUF_SIMPLE
        PyObject_GetBuffer(py_object(obj), byref(view), request)
        try:
            array_type = c_char * view.len
            return array_type.from_address(view.buf)
        finally:
            PyBuffer_Release(byref(view))
119
120
121class _WindowsConsoleRawIOBase(io.RawIOBase):
122    def __init__(self, handle):
123        self.handle = handle
124
125    def isatty(self):
126        io.RawIOBase.isatty(self)
127        return True
128
129
class _WindowsConsoleReader(_WindowsConsoleRawIOBase):
    """Raw reader that pulls UTF-16-LE text from a console handle via
    ``ReadConsoleW``."""

    def readable(self):
        return True

    def readinto(self, b):
        """Read console input into the writable buffer *b*.

        :param b: writable buffer whose length must be even, since every
                  UTF-16-LE code unit occupies two bytes.
        :return: the number of bytes stored in *b*; ``0`` signals end of
                 input (empty *b*, nothing read, or a leading Ctrl+Z).
        :raises ValueError: if ``len(b)`` is odd.
        :raises OSError: if ``ReadConsoleW`` reports failure.
        """
        bytes_to_be_read = len(b)
        if not bytes_to_be_read:
            return 0
        elif bytes_to_be_read % 2:
            raise ValueError(
                "cannot read odd number of bytes from UTF-16-LE encoded console"
            )

        buffer = get_buffer(b, writable=True)
        code_units_to_be_read = bytes_to_be_read // 2
        code_units_read = c_ulong()

        rv = ReadConsoleW(
            HANDLE(self.handle),
            buffer,
            code_units_to_be_read,
            byref(code_units_read),
            None,
        )
        # Capture the error once: the sleep below may change the thread's
        # last-error value before it is reported.
        last_error = GetLastError()
        if last_error == ERROR_OPERATION_ABORTED:
            # wait for KeyboardInterrupt
            time.sleep(0.1)
        if not rv:
            raise OSError("Windows error: {}".format(last_error))

        if not code_units_read.value:
            return 0
        # Only a full Ctrl+Z code unit (bytes 1a 00) means EOF. Testing the
        # low byte alone misreads characters such as U+041A (Cyrillic "К")
        # or U+011A ("Ě") as end of input.
        if buffer[:2] == EOF + b"\x00":
            return 0
        return 2 * code_units_read.value
163
164
class _WindowsConsoleWriter(_WindowsConsoleRawIOBase):
    """Raw writer that sends UTF-16-LE encoded bytes to a console handle
    via ``WriteConsoleW``."""

    def writable(self):
        return True

    @staticmethod
    def _get_error_message(errno):
        """Return a readable label for the Win32 error code *errno*."""
        labels = {
            ERROR_SUCCESS: "ERROR_SUCCESS",
            ERROR_NOT_ENOUGH_MEMORY: "ERROR_NOT_ENOUGH_MEMORY",
        }
        return labels.get(errno, "Windows error {}".format(errno))

    def write(self, b):
        """Write (a prefix of) the UTF-16-LE bytes *b* to the console.

        At most ``MAX_BYTES_WRITTEN`` bytes are attempted per call; the
        caller (a buffered writer) is expected to retry the remainder.

        :return: number of bytes actually written.
        :raises OSError: if nothing was written although *b* was non-empty.
        """
        total = len(b)
        view = get_buffer(b)
        # Cap the request, then convert bytes to two-byte code units.
        requested_units = min(total, MAX_BYTES_WRITTEN) // 2
        written_units = c_ulong()

        WriteConsoleW(
            HANDLE(self.handle),
            view,
            requested_units,
            byref(written_units),
            None,
        )
        written = written_units.value * 2

        if total > 0 and written == 0:
            raise OSError(self._get_error_message(GetLastError()))
        return written
195
196
class ConsoleStream(object):
    """Text stream facade that pairs a console text stream with a byte
    stream.

    Text writes go to the text stream and byte writes go to ``buffer``.
    Every other attribute is looked up on the text stream.
    """

    def __init__(self, text_stream, byte_stream):
        self._text_stream = text_stream
        self.buffer = byte_stream

    @property
    def name(self):
        return self.buffer.name

    def write(self, x):
        if not isinstance(x, text_type):
            # Best-effort flush of pending text so it is emitted before
            # the raw bytes; a failing flush must not block the write.
            try:
                self.flush()
            except Exception:
                pass
            return self.buffer.write(x)
        return self._text_stream.write(x)

    def writelines(self, lines):
        for item in lines:
            self.write(item)

    def __getattr__(self, name):
        # Only reached for attributes not defined on ConsoleStream itself.
        return getattr(self._text_stream, name)

    def isatty(self):
        return self.buffer.isatty()

    def __repr__(self):
        return "<ConsoleStream name={!r} encoding={!r}>".format(
            self.name, self.encoding
        )
229
230
class WindowsChunkedWriter(object):
    """
    Transparent proxy around a stream (such as stdout).

    Every attribute is forwarded to the wrapped stream except ``write()``.
    ``write()`` splits its argument into pieces of at most
    ``MAX_BYTES_WRITTEN`` items because of a Windows limitation on binary
    console streams.
    """

    def __init__(self, wrapped):
        # Name-mangled (double underscore) so it cannot clash with an
        # attribute of the wrapped stream.
        self.__wrapped = wrapped

    def __getattr__(self, name):
        # Only consulted for attributes the proxy does not define itself.
        return getattr(self.__wrapped, name)

    def write(self, text):
        step = MAX_BYTES_WRITTEN
        for start in range(0, len(text), step):
            self.__wrapped.write(text[start : start + step])
254
255
# Names of the ``sys`` streams already replaced by a WindowsChunkedWriter.
_wrapped_std_streams = set()


def _wrap_std_stream(name):
    """Replace ``sys.<name>`` with a chunking proxy, at most once.

    This only applies to Python 2 on Windows 7 (NT 6.1) and older; in
    every other environment the call does nothing.
    """
    if not PY2:
        return
    if sys.getwindowsversion()[:2] > (6, 1):
        return
    if name in _wrapped_std_streams:
        return
    setattr(sys, name, WindowsChunkedWriter(getattr(sys, name)))
    _wrapped_std_streams.add(name)
268
269
def _make_console_text_stream(raw_stream, buffer_stream):
    """Wrap *raw_stream* in a line-buffered, strict UTF-16-LE text layer.

    The text layer is paired with *buffer_stream* in a ConsoleStream.
    """
    text_stream = _NonClosingTextIOWrapper(
        raw_stream,
        "utf-16-le",
        "strict",
        line_buffering=True,
    )
    return ConsoleStream(text_stream, buffer_stream)


def _get_text_stdin(buffer_stream):
    """Return a ConsoleStream that reads text from the console's stdin."""
    return _make_console_text_stream(
        io.BufferedReader(_WindowsConsoleReader(STDIN_HANDLE)), buffer_stream
    )


def _get_text_stdout(buffer_stream):
    """Return a ConsoleStream that writes text to the console's stdout."""
    return _make_console_text_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDOUT_HANDLE)), buffer_stream
    )


def _get_text_stderr(buffer_stream):
    """Return a ConsoleStream that writes text to the console's stderr."""
    return _make_console_text_stream(
        io.BufferedWriter(_WindowsConsoleWriter(STDERR_HANDLE)), buffer_stream
    )
298
299
if PY2:

    def _hash_py_argv():
        # CRC32 fingerprint of sys.argv[1:] (NUL-separated).
        return zlib.crc32("\x00".join(sys.argv[1:]))

    _initial_argv_hash = _hash_py_argv()

    def _args_after_program(args):
        """Return the part of *args* that Python exposes as ``sys.argv[1:]``.

        *args* is the command line without the interpreter executable.
        Interpreter options are skipped following Python 2's option rules:
        - Short flags may be combined (``-Bu``).
        - ``-Q`` and ``-W`` take a value, either attached or as the next
          argument.
        - ``-c`` and ``-m`` end option parsing. Their operand is either
          the rest of the same token (``-mpkg``) or the next argument.
        - ``--`` ends options explicitly.
        """
        index = 0
        while index < len(args):
            arg = args[index]
            if not arg.startswith("-") or arg == "-":
                # First non-option: the script path (or "-" for stdin).
                return args[index + 1 :]
            if arg == "--":
                # Explicit end of options: the next argument is the script.
                return args[index + 2 :]
            index += 1
            for pos in range(1, len(arg)):
                flag = arg[pos]
                attached = pos + 1 < len(arg)
                if flag in "cm":
                    # Command/module is the rest of this token, or the
                    # next argument; everything after it is the program's.
                    if attached:
                        return args[index:]
                    return args[index + 1 :]
                if flag in "QW":
                    if not attached:
                        index += 1  # the option's value is the next argument
                    break
        return []

    def _get_windows_argv():
        """Return the program's arguments as unicode strings.

        The command line is re-parsed with ``CommandLineToArgvW``.

        :raises WindowsError: if the command line cannot be split.
        """
        argc = c_int(0)
        argv_unicode = CommandLineToArgvW(GetCommandLineW(), byref(argc))
        if not argv_unicode:
            raise WinError()
        try:
            argv = [argv_unicode[i] for i in range(0, argc.value)]
        finally:
            LocalFree(argv_unicode)
            del argv_unicode

        if hasattr(sys, "frozen"):
            # Frozen executables have no interpreter options to skip.
            return argv[1:]
        return _args_after_program(argv[1:])
329
330
# Standard file descriptor -> factory building its console text stream.
_stream_factories = {
    STDIN_FILENO: _get_text_stdin,
    STDOUT_FILENO: _get_text_stdout,
    STDERR_FILENO: _get_text_stderr,
}
336
337
338def _is_console(f):
339    if not hasattr(f, "fileno"):
340        return False
341
342    try:
343        fileno = f.fileno()
344    except OSError:
345        return False
346
347    handle = msvcrt.get_osfhandle(fileno)
348    return bool(GetConsoleMode(handle, byref(DWORD())))
349
350
def _get_windows_console_stream(f, encoding, errors):
    """Return a console-backed text stream for the standard stream *f*.

    Returns None when native console I/O cannot be used. That is the case
    when any of the following holds:
    - buffers are unavailable;
    - the requested encoding/errors differ from UTF-16-LE/strict;
    - *f* is not a console;
    - *f* is not stdin/stdout/stderr;
    - on Python 3, *f* has no ``buffer``.
    """
    usable = (
        get_buffer is not None
        and encoding in ("utf-16-le", None)
        and errors in ("strict", None)
        and _is_console(f)
    )
    if not usable:
        return None

    factory = _stream_factories.get(f.fileno())
    if factory is None:
        return None

    if PY2:
        # Python 2 streams must be switched to binary mode, otherwise the
        # console conversion is pointless; the same caveats apply as for
        # get_binary_stdin and friends in _compat.
        msvcrt.setmode(f.fileno(), os.O_BINARY)
    else:
        f = getattr(f, "buffer", None)
        if f is None:
            return None
    return factory(f)
371