1# -*- coding: utf-8 -*-
2from __future__ import unicode_literals
3from __future__ import division
4from __future__ import print_function
5from __future__ import absolute_import
6
7from functools import wraps
8from datetime import datetime, timedelta
9from typing import AnyStr, Iterable
10
# Circuit states. Only CLOSED and OPEN are ever stored on a breaker;
# HALF_OPEN is derived by CircuitBreaker.state once the recovery timeout
# of an OPEN circuit has elapsed.
STATE_CLOSED = 'closed'
STATE_OPEN = 'open'
STATE_HALF_OPEN = 'half_open'


class CircuitBreaker(object):
    """
    Counts failures of a wrapped callable and short-circuits further calls
    once a failure threshold has been reached.

    State machine:
      * CLOSED    - calls pass through; each expected exception increments the
                    failure count, a success resets it.
      * OPEN      - entered when the failure count reaches the threshold. Calls
                    are answered by the fallback function, if one is set, or
                    rejected with CircuitBreakerError.
      * HALF_OPEN - reported once ``recovery_timeout`` seconds have passed since
                    the circuit opened. Calls pass through again: a success
                    closes the circuit, an expected exception re-opens it for
                    another timeout (the failure count is still at or above the
                    threshold).

    Class attributes provide the defaults used when a constructor argument is
    ``None``; subclasses may override them.
    """
    FAILURE_THRESHOLD = 5
    RECOVERY_TIMEOUT = 30
    EXPECTED_EXCEPTION = Exception
    FALLBACK_FUNCTION = None

    def __init__(self,
                 failure_threshold=None,
                 recovery_timeout=None,
                 expected_exception=None,
                 name=None,
                 fallback_function=None):
        """
        :param failure_threshold: Number of failures after which the circuit
            opens. ``None`` selects FAILURE_THRESHOLD.
        :param recovery_timeout: Seconds the circuit stays OPEN before it
            reports HALF_OPEN. ``None`` selects RECOVERY_TIMEOUT. An explicit
            ``0`` is honored.
        :param expected_exception: Exception class (or tuple of classes) that
            counts as a failure. Anything else propagates without touching the
            failure count. ``None`` selects EXPECTED_EXCEPTION.
        :param name: Name used for registration and in error messages. It
            defaults to the decorated function's ``__name__``.
        :param fallback_function: Called with the original arguments instead of
            raising CircuitBreakerError while the circuit is OPEN. ``None``
            selects FALLBACK_FUNCTION.
        """
        self._last_failure = None
        self._failure_count = 0
        # Compare against None instead of relying on truthiness. Otherwise
        # explicit falsy values such as recovery_timeout=0 would be silently
        # replaced by the class defaults.
        self._failure_threshold = self._or_default(failure_threshold, self.FAILURE_THRESHOLD)
        self._recovery_timeout = self._or_default(recovery_timeout, self.RECOVERY_TIMEOUT)
        self._expected_exception = self._or_default(expected_exception, self.EXPECTED_EXCEPTION)
        self._fallback_function = self._or_default(fallback_function, self.FALLBACK_FUNCTION)
        self._name = name
        self._state = STATE_CLOSED
        # Only meaningful once the circuit has opened; initialised so that
        # open_until / open_remaining are always computable.
        self._opened = datetime.utcnow()

    @staticmethod
    def _or_default(value, default):
        """Return ``value`` unless it is None, in which case return ``default``."""
        return default if value is None else value

    def __call__(self, wrapped):
        """Allow the instance itself to be used as a decorator."""
        return self.decorate(wrapped)

    def decorate(self, function):
        """
        Applies the circuit breaker to a function.

        Takes ``function.__name__`` as the breaker name if none was given,
        registers the breaker with CircuitBreakerMonitor and returns a wrapper
        that routes every invocation through :meth:`call`.
        """
        if self._name is None:
            self._name = function.__name__

        CircuitBreakerMonitor.register(self)

        @wraps(function)
        def wrapper(*args, **kwargs):
            return self.call(function, *args, **kwargs)

        return wrapper

    def call(self, func, *args, **kwargs):
        """
        Calls the decorated function and applies the circuit breaker
        rules on success or failure.

        :param func: Decorated function
        :return: The result of ``func``, or of the fallback function while the
            circuit is OPEN.
        :raises CircuitBreakerError: If the circuit is OPEN and no fallback
            function is configured.
        Any expected exception raised by ``func`` is recorded and re-raised.
        """
        if self.opened:
            if self.fallback_function:
                return self.fallback_function(*args, **kwargs)
            raise CircuitBreakerError(self)
        try:
            result = func(*args, **kwargs)
        except self._expected_exception as e:
            self._last_failure = e
            self.__call_failed()
            raise

        self.__call_succeeded()
        return result

    def __call_succeeded(self):
        """
        Close circuit after successful execution and reset failure count
        """
        self._state = STATE_CLOSED
        self._last_failure = None
        self._failure_count = 0

    def __call_failed(self):
        """
        Count failure and open circuit, if threshold has been reached.

        A failure while HALF_OPEN re-opens the circuit and restarts the
        recovery timeout, because the count is still at or above the threshold.
        """
        self._failure_count += 1
        if self._failure_count >= self._failure_threshold:
            self._state = STATE_OPEN
            self._opened = datetime.utcnow()

    @property
    def state(self):
        """Current state: STATE_CLOSED, STATE_OPEN or the derived STATE_HALF_OPEN."""
        if self._state == STATE_OPEN and self.open_remaining <= 0:
            return STATE_HALF_OPEN
        return self._state

    @property
    def open_until(self):
        """
        The datetime, when the circuit breaker will try to recover
        :return: naive UTC datetime (based on ``datetime.utcnow()``)
        """
        return self._opened + timedelta(seconds=self._recovery_timeout)

    @property
    def open_remaining(self):
        """
        Number of seconds remaining, the circuit breaker stays in OPEN state
        :return: float; zero or negative once the recovery timeout has elapsed
        """
        return (self.open_until - datetime.utcnow()).total_seconds()

    @property
    def failure_count(self):
        """Number of expected failures recorded since the last success."""
        return self._failure_count

    @property
    def closed(self):
        """True only in STATE_CLOSED. A HALF_OPEN circuit is neither closed nor opened."""
        return self.state == STATE_CLOSED

    @property
    def opened(self):
        """True only while OPEN and the recovery timeout has not yet elapsed."""
        return self.state == STATE_OPEN

    @property
    def name(self):
        """Breaker name, or None until given or derived by decorate()."""
        return self._name

    @property
    def last_failure(self):
        """Most recent expected exception, cleared on the next success."""
        return self._last_failure

    @property
    def fallback_function(self):
        """Callable used instead of raising while OPEN, or None."""
        return self._fallback_function

    def __str__(self, *args, **kwargs):
        # str() guards against an unnamed breaker. Returning None here would
        # make str(breaker) raise TypeError.
        return str(self._name)
141
142
class CircuitBreakerError(Exception):
    """
    Raised by CircuitBreaker.call when the circuit is OPEN and no (truthy)
    fallback function is configured.
    """

    def __init__(self, circuit_breaker, *args, **kwargs):
        """
        :param circuit_breaker: The breaker that rejected the call. Its name,
            timing and failure details are read when the error is rendered.
        :param args: Forwarded unchanged to Exception.__init__.
        :param kwargs: Forwarded unchanged to Exception.__init__.
        """
        super(CircuitBreakerError, self).__init__(*args, **kwargs)
        self._circuit_breaker = circuit_breaker

    def __str__(self, *args, **kwargs):
        # Render from the breaker's live state at the time of formatting.
        breaker = self._circuit_breaker
        template = 'Circuit "%s" OPEN until %s (%d failures, %d sec remaining) (last_failure: %r)'
        values = (
            breaker.name,
            breaker.open_until,
            breaker.failure_count,
            round(breaker.open_remaining),
            breaker.last_failure,
        )
        return template % values
162
163
class CircuitBreakerMonitor(object):
    """
    Registry of circuit breakers keyed by name.

    ``circuit_breakers`` is a class-level dict, so every caller (and any
    subclass that does not override it) shares the same registry.
    """
    circuit_breakers = {}

    @classmethod
    def register(cls, circuit_breaker):
        """
        Store ``circuit_breaker`` under its ``name``. An entry already
        registered under the same name is replaced.
        """
        cls.circuit_breakers[circuit_breaker.name] = circuit_breaker

    @classmethod
    def all_closed(cls):
        # type: () -> bool
        """
        Return True when no registered circuit reports ``opened``.

        Circuits whose ``opened`` is False but ``closed`` is also False (for
        CircuitBreaker, the HALF_OPEN state) do not prevent a True result.
        Iteration stops at the first open circuit instead of materialising
        the full list of open circuits.
        """
        return not any(True for _ in cls.get_open())

    @classmethod
    def get_circuits(cls):
        # type: () -> Iterable[CircuitBreaker]
        """Return a live view of all registered circuit breakers."""
        return cls.circuit_breakers.values()

    @classmethod
    def get(cls, name):
        # type: (AnyStr) -> CircuitBreaker
        """Return the breaker registered under ``name``, or None if unknown."""
        return cls.circuit_breakers.get(name)

    @classmethod
    def get_open(cls):
        # type: () -> Iterable[CircuitBreaker]
        """Lazily yield the registered circuits whose ``opened`` is True."""
        for circuit in cls.get_circuits():
            if circuit.opened:
                yield circuit

    @classmethod
    def get_closed(cls):
        # type: () -> Iterable[CircuitBreaker]
        """Lazily yield the registered circuits whose ``closed`` is True."""
        for circuit in cls.get_circuits():
            if circuit.closed:
                yield circuit
199
200
def circuit(failure_threshold=None,
            recovery_timeout=None,
            expected_exception=None,
            name=None,
            fallback_function=None,
            cls=CircuitBreaker):
    """
    Decorator (factory) that guards a function with a circuit breaker.

    Works both bare (``@circuit``) and configured (``@circuit(...)``):

    * configured: returns a ``cls`` instance built from the given arguments.
      Applying it to a function calls the instance (``__call__``), which
      decorates the function.
    * bare: Python passes the decorated function as the first positional
      argument, so it lands in ``failure_threshold``. It is wrapped right away
      by a default-configured ``cls`` instance.

    :param cls: CircuitBreaker (sub)class to instantiate.
    :return: A ``cls`` instance (configured form) or the wrapped function
        (bare form).
    """
    if not callable(failure_threshold):
        configured = cls(
            failure_threshold=failure_threshold,
            recovery_timeout=recovery_timeout,
            expected_exception=expected_exception,
            name=name,
            fallback_function=fallback_function)
        return configured

    decorated_function = failure_threshold
    return cls().decorate(decorated_function)
219