1import signal
2import weakref
3
4from unittest2.compatibility import wraps
5
# Module-level marker flag. NOTE(review): presumably, as in the stdlib
# unittest package, the test framework checks for this name in a frame's
# globals to omit this module's frames from failure tracebacks -- confirm
# against unittest2's result/traceback code.
__unittest = True
7
8
class _InterruptHandler(object):
    """SIGINT handler that asks registered test results to stop.

    While this object is the installed SIGINT handler, the first interrupt
    calls ``stop()`` on every result registered through ``registerResult``.
    Later interrupts are first passed on to the previously installed
    handler.

    :param default_handler: the SIGINT handler that was in place before
        this one, as returned by ``signal.getsignal``.  It may be a
        callable, ``signal.SIG_DFL``, ``signal.SIG_IGN`` or ``None``.  It is
        stored unchanged on ``self.default_handler`` so that it can be
        restored verbatim later.
    """

    def __init__(self, default_handler):
        # Becomes True after the first handled interrupt.
        self.called = False
        self.default_handler = default_handler

    def _delegate(self, signum, frame):
        """Pass the signal on to ``self.default_handler``.

        ``signal.getsignal`` can report the previous handler as one of the
        integer constants ``SIG_DFL``/``SIG_IGN`` or as ``None`` (handler not
        installed from Python).  None of these can be called directly, so
        they are translated here instead of raising ``TypeError``.
        """
        handler = self.default_handler
        if isinstance(handler, int):
            if handler == signal.SIG_IGN:
                # The previous disposition was "ignore": do nothing.
                return
            # SIG_DFL: behave like Python's standard SIGINT handling.
            handler = signal.default_int_handler
        elif handler is None:
            # Unknown non-Python handler: fall back to KeyboardInterrupt.
            handler = signal.default_int_handler
        handler(signum, frame)

    def __call__(self, signum, frame):
        installed_handler = signal.getsignal(signal.SIGINT)
        if installed_handler is not self:
            # We have been replaced but are still being invoked (e.g. chained
            # from the new handler): delaying the interrupt is not safe, so
            # delegate immediately and do nothing else.
            self._delegate(signum, frame)
            return

        if self.called:
            # A repeated interrupt is handed to the previous handler.
            self._delegate(signum, frame)
        self.called = True
        # Iterate over a snapshot: a result's stop() could register or remove
        # results, and mutating a WeakKeyDictionary while iterating over it
        # raises RuntimeError.
        for result in list(_results.keys()):
            result.stop()
27
# Registry of test results whose stop() is called when the installed
# _InterruptHandler handles a SIGINT.  Keys are held weakly, so registering
# a result does not keep it alive; the value (always 1 in this module) is
# only a placeholder.
_results = weakref.WeakKeyDictionary()
29
30
def registerResult(result):
    """Track *result* so that a handled SIGINT will call its ``stop()``.

    The result is stored as a key of the module's weak-keyed registry, so
    registering it does not keep it alive.  Registering the same result
    twice is harmless.
    """
    registry = _results
    registry[result] = 1
33
34
def removeResult(result):
    """Stop tracking *result* for SIGINT handling.

    Returns True when *result* had been registered with ``registerResult``
    (and is now removed), False when it was not registered.
    """
    previous = _results.pop(result, None)
    return True if previous else False
37
# The _InterruptHandler created by installHandler(), or None until that is
# first called.  removeHandler() restores the previous SIGINT handler but
# does not reset this back to None.
_interrupt_handler = None
39
40
def installHandler():
    """Install this module's SIGINT handler.

    The SIGINT handler in place at call time is handed to the new
    ``_InterruptHandler`` (which delegates to it and lets ``removeHandler``
    restore it).  Does nothing if the module-level handler object already
    exists.
    """
    global _interrupt_handler
    if _interrupt_handler is not None:
        return
    previous = signal.getsignal(signal.SIGINT)
    handler = _InterruptHandler(previous)
    _interrupt_handler = handler
    signal.signal(signal.SIGINT, handler)
47
48
def removeHandler(method=None):
    """Uninstall the SIGINT handler, or wrap *method* to run without it.

    Called with no argument: if ``installHandler`` has created a handler,
    the SIGINT handler that was in place before it is reinstated.  The
    module's handler object is not cleared.

    Used as a decorator: returns a wrapper that removes the handler, calls
    *method*, and afterwards (even if *method* raises) reinstalls whatever
    SIGINT handler was active when the wrapper was called.
    """
    if method is None:
        handler = _interrupt_handler
        if handler is not None:
            signal.signal(signal.SIGINT, handler.default_handler)
        return None

    @wraps(method)
    def _without_handler(*args, **kwargs):
        saved = signal.getsignal(signal.SIGINT)
        removeHandler()
        try:
            return method(*args, **kwargs)
        finally:
            signal.signal(signal.SIGINT, saved)

    return _without_handler
64