1import signal
2import weakref
3
4from unittest2.compatibility import wraps
5
# Module-level marker flag.  NOTE(review): presumably inspected by the
# unittest(2) result machinery to omit this module's frames from reported
# failure tracebacks -- confirm against the traceback-filtering code.
__unittest = True
7
8
9class _InterruptHandler(object):
10    def __init__(self, default_handler):
11        self.called = False
12        self.original_handler = default_handler
13        if isinstance(default_handler, int):
14            if default_handler == signal.SIG_DFL:
15                # Pretend it's signal.default_int_handler instead.
16                default_handler = signal.default_int_handler
17            elif default_handler == signal.SIG_IGN:
18                # Not quite the same thing as SIG_IGN, but the closest we
19                # can make it: do nothing.
20                def default_handler(unused_signum, unused_frame):
21                    pass
22            else:
23                raise TypeError("expected SIGINT signal handler to be "
24                                "signal.SIG_IGN, signal.SIG_DFL, or a "
25                                "callable object")
26        self.default_handler = default_handler
27
28    def __call__(self, signum, frame):
29        installed_handler = signal.getsignal(signal.SIGINT)
30        if installed_handler is not self:
31            # if we aren't the installed handler, then delegate immediately
32            # to the default handler
33            self.default_handler(signum, frame)
34
35        if self.called:
36            self.default_handler(signum, frame)
37        self.called = True
38        for result in _results.keys():
39            result.stop()
40
# Registry of results to be told to stop() when SIGINT is handled; see
# registerResult()/removeResult().  Keys are held weakly, so registering a
# result does not by itself keep it alive.
_results = weakref.WeakKeyDictionary()
def registerResult(result):
    """Register *result* so that a handled SIGINT calls its ``stop()``.

    The registry holds *result* weakly; registering it does not keep it
    alive.
    """
    registry = _results
    registry[result] = 1
44
def removeResult(result):
    """Unregister *result*.

    :returns: True if *result* was registered (and is now removed),
        False otherwise.
    """
    marker = _results.pop(result, None)
    return True if marker else False
47
# The _InterruptHandler created by the first installHandler() call, or None
# before that.  removeHandler() reinstalls its original_handler but does not
# reset this back to None, so a later installHandler() call does nothing.
_interrupt_handler = None
def installHandler():
    """Install a SIGINT handler that stops every registered result.

    The handler currently set for SIGINT is captured and wrapped so it can
    be delegated to and later restored.  Only the first call has an effect:
    once the module-level handler object exists, further calls return
    without touching the signal state.
    """
    global _interrupt_handler
    if _interrupt_handler is not None:
        return
    previous = signal.getsignal(signal.SIGINT)
    _interrupt_handler = _InterruptHandler(previous)
    signal.signal(signal.SIGINT, _interrupt_handler)
55
56
def removeHandler(method=None):
    """Put back the SIGINT handler that installHandler() replaced.

    Two forms:

    * ``removeHandler()`` -- reinstall the original SIGINT handler now.
      Does nothing if installHandler() has never created a handler.
      Note that the module-level handler object is left in place.
    * ``removeHandler(func)`` / ``@removeHandler`` -- return a wrapper that
      reinstalls the original handler while *func* runs, and afterwards
      restores whatever SIGINT handler was set at call time, even if *func*
      raises.
    """
    global _interrupt_handler

    if method is None:
        if _interrupt_handler is not None:
            signal.signal(signal.SIGINT, _interrupt_handler.original_handler)
        return None

    @wraps(method)
    def wrapper(*args, **kwargs):
        saved = signal.getsignal(signal.SIGINT)
        removeHandler()
        try:
            return method(*args, **kwargs)
        finally:
            signal.signal(signal.SIGINT, saved)
    return wrapper
72