1import inspect
2import signal
3import sys
4from functools import wraps
5import attr
6
7import async_generator
8
9from .._util import is_main_thread
10
# Typing-only names. The ``if False:`` guard means this block never executes,
# so ``typing`` is not imported here at runtime and ``F`` is never actually
# bound; the names are only referred to from ``# type:`` comments later in
# this module (presumably for the benefit of static type checkers).
if False:
    from typing import Any, TypeVar, Callable

    # Type variable standing for "some callable", used to annotate decorators
    # as returning the same kind of callable they were given.
    F = TypeVar("F", bound=Callable[..., Any])
15
16# In ordinary single-threaded Python code, when you hit control-C, it raises
17# an exception and automatically does all the regular unwinding stuff.
18#
19# In Trio code, we would like hitting control-C to raise an exception and
20# automatically do all the regular unwinding stuff. In particular, we would
21# like to maintain our invariant that all tasks always run to completion (one
22# way or another), by unwinding all of them.
23#
24# But it's basically impossible to write the core task running code in such a
25# way that it can maintain this invariant in the face of KeyboardInterrupt
26# exceptions arising at arbitrary bytecode positions. Similarly, if a
27# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
28# inter-task synchronization or I/O primitives, then the system state could
29# get corrupted and prevent our being able to clean up properly.
30#
31# So, we need a way to defer KeyboardInterrupt processing from these critical
32# sections.
33#
34# Things that don't work:
35#
36# - Listen for SIGINT and process it in a system task: works fine for
37#   well-behaved programs that regularly pass through the event loop, but if
38#   user-code goes into an infinite loop then it can't be interrupted. Which
39#   is unfortunate, since dealing with infinite loops is what
40#   KeyboardInterrupt is for!
41#
42# - Use pthread_sigmask to disable signal delivery during critical section:
43#   (a) windows has no pthread_sigmask, (b) python threads start with all
44#   signals unblocked, so if there are any threads around they'll receive the
45#   signal and then tell the main thread to run the handler, even if the main
46#   thread has that signal blocked.
47#
48# - Install a signal handler which checks a global variable to decide whether
49#   to raise the exception immediately (if we're in a non-critical section),
50#   or to schedule it on the event loop (if we're in a critical section). The
51#   problem here is that it's impossible to transition safely out of user code:
52#
53#     with keyboard_interrupt_enabled:
54#         msg = coro.send(value)
55#
56#   If this raises a KeyboardInterrupt, it might be because the coroutine got
57#   interrupted and has unwound... or it might be the KeyboardInterrupt
58#   arrived just *after* 'send' returned, so the coroutine is still running
59#   but we just lost the message it sent. (And worse, in our actual task
60#   runner, the send is hidden inside a utility function etc.)
61#
62# Solution:
63#
64# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
65# the signal handler check which kind of frame we're currently in when
66# deciding whether to raise or schedule the exception.
67#
68# There are still some cases where this can fail, like if someone hits
69# control-C while the process is in the event loop, and then it immediately
70# enters an infinite loop in user code. In this case the user has to hit
71# control-C a second time. And of course if the user code is written so that
72# it doesn't actually exit after a task crashes and everything gets cancelled,
73# then there's not much to be done. (Hitting control-C repeatedly might help,
74# but in general the solution is to kill the process some other way, just like
75# for any Python program that's written to catch and ignore
76# KeyboardInterrupt.)
77
78# We use this special string as a unique key into the frame locals dictionary.
79# The @ ensures it is not a valid identifier and can't clash with any possible
80# real local name. See: https://github.com/python-trio/trio/issues/469
81LOCALS_KEY_KI_PROTECTION_ENABLED = "@TRIO_KI_PROTECTION_ENABLED"
82
83
84# NB: according to the signal.signal docs, 'frame' can be None on entry to
85# this function:
def ki_protection_enabled(frame):
    """Report whether KeyboardInterrupt protection applies at ``frame``.

    Starting at ``frame`` and following ``f_back`` links outward, the first
    frame that settles the question wins:

    * a frame whose locals contain ``LOCALS_KEY_KI_PROTECTION_ENABLED``
      yields whatever value is stored under that key;
    * a frame running code named ``__del__`` yields True.

    If the chain runs out without either match -- which includes being
    handed ``frame=None`` -- the answer is True.
    """
    current = frame
    while current is not None:
        frame_locals = current.f_locals
        if LOCALS_KEY_KI_PROTECTION_ENABLED in frame_locals:
            return frame_locals[LOCALS_KEY_KI_PROTECTION_ENABLED]
        if current.f_code.co_name == "__del__":
            # Finalizers are always treated as protected.
            return True
        current = current.f_back
    # Nothing on the stack expressed an opinion: default to protected.
    return True
94
95
def currently_ki_protected():
    r"""Report whether :exc:`KeyboardInterrupt` protection is enabled for
    the code that calls this function.

    Getting :exc:`KeyboardInterrupt` protection wrong in either direction is
    easy to do without noticing. This function exposes Trio's own view of
    the current state, which makes it handy in ``assert``\s and unit tests.

    Returns:
      bool: True if protection is enabled, and False otherwise.

    """
    # The lookup has to start from *this* frame (and walk outward from it),
    # so the frame must be captured here rather than in a helper.
    here = sys._getframe()
    return ki_protection_enabled(here)
110
111
112def _ki_protection_decorator(enabled):
113    def decorator(fn):
114        # In some version of Python, isgeneratorfunction returns true for
115        # coroutine functions, so we have to check for coroutine functions
116        # first.
117        if inspect.iscoroutinefunction(fn):
118
119            @wraps(fn)
120            def wrapper(*args, **kwargs):
121                # See the comment for regular generators below
122                coro = fn(*args, **kwargs)
123                coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
124                return coro
125
126            return wrapper
127        elif inspect.isgeneratorfunction(fn):
128
129            @wraps(fn)
130            def wrapper(*args, **kwargs):
131                # It's important that we inject this directly into the
132                # generator's locals, as opposed to setting it here and then
133                # doing 'yield from'. The reason is, if a generator is
134                # throw()n into, then it may magically pop to the top of the
135                # stack. And @contextmanager generators in particular are a
136                # case where we often want KI protection, and which are often
137                # thrown into! See:
138                #     https://bugs.python.org/issue29590
139                gen = fn(*args, **kwargs)
140                gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
141                return gen
142
143            return wrapper
144        elif async_generator.isasyncgenfunction(fn):
145
146            @wraps(fn)
147            def wrapper(*args, **kwargs):
148                # See the comment for regular generators above
149                agen = fn(*args, **kwargs)
150                agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
151                return agen
152
153            return wrapper
154        else:
155
156            @wraps(fn)
157            def wrapper(*args, **kwargs):
158                locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
159                return fn(*args, **kwargs)
160
161            return wrapper
162
163    return decorator
164
165
# The two public decorators are the same factory specialised on the flag
# value it stores.
enable_ki_protection = _ki_protection_decorator(True)  # type: Callable[[F], F]
disable_ki_protection = _ki_protection_decorator(False)  # type: Callable[[F], F]

# The factory hands back an inner function literally named "decorator", so
# give each result its public name.
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection.__name__ = "disable_ki_protection"
171
172
@attr.s
class KIManager:
    """Installs, and later removes, a frame-aware SIGINT handler.

    ``handler`` is the handler function this manager installed, or None if
    it has not installed one (or has already been closed).
    """

    handler = attr.ib(default=None)

    def install(self, deliver_cb, restrict_keyboard_interrupt_to_checkpoints):
        """Replace the default SIGINT handler with a KI-protection-aware one.

        Nothing is installed (and ``self.handler`` stays None) unless this
        is the main thread and SIGINT is currently handled by
        ``signal.default_int_handler``.

        Args:
          deliver_cb: zero-argument callable invoked from the signal handler
              when the interrupt must not be raised on the spot.
          restrict_keyboard_interrupt_to_checkpoints: if true, the handler
              never raises directly and always goes through ``deliver_cb``.
        """
        assert self.handler is None
        if not is_main_thread():
            return
        if signal.getsignal(signal.SIGINT) != signal.default_int_handler:
            # Some other handler is already in place; leave it alone.
            return

        def handler(signum, frame):
            assert signum == signal.SIGINT
            must_defer = (
                ki_protection_enabled(frame)
                or restrict_keyboard_interrupt_to_checkpoints
            )
            if not must_defer:
                # Interrupted code is unprotected: raise right where it is.
                raise KeyboardInterrupt
            deliver_cb()

        self.handler = handler
        signal.signal(signal.SIGINT, handler)

    def close(self):
        """Undo ``install``.

        ``signal.default_int_handler`` is restored only if our handler is
        still the one registered for SIGINT; in every case where a handler
        was installed, ``self.handler`` is reset to None.
        """
        if self.handler is None:
            return
        if signal.getsignal(signal.SIGINT) is self.handler:
            signal.signal(signal.SIGINT, signal.default_int_handler)
        self.handler = None
201