1# -*- coding: utf-8 -*- #
2# Copyright 2015 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Implementation of retrying logic."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import unicode_literals
21
22import collections
23import functools
24import itertools
25import random
26import sys
27import time
28
29from googlecloudsdk.core import exceptions
30
31
# Default value, in milliseconds, of the `jitter_ms` argument of Retryer and of
# the RetryOnException decorator: a random fraction of it is added to each
# non-zero wait between trials.
_DEFAULT_JITTER_MS = 1000
33
34
class RetryerState(object):
  """Snapshot of a retryer's progress, taken after a trial has run."""

  def __init__(self, retrial, time_passed_ms, time_to_wait_ms):
    """Records the progress values on the instance.

    Args:
      retrial: int, the retry attempt we are currently at.
      time_passed_ms: int, number of ms that passed since the retryer started.
      time_to_wait_ms: int, number of ms to wait until the next trial. A value
          of -1 means the iterable that supplies the sleep values has raised
          StopIteration.
    """
    self.retrial, self.time_passed_ms, self.time_to_wait_ms = (
        retrial, time_passed_ms, time_to_wait_ms)
51
52
class RetryException(Exception):
  """Raised to stop retrials on failure.

  Attributes:
    message: the short reason the retrials were stopped.
    last_result: the result of the last trial that was made.
    state: the retryer state (retrial, time_passed_ms, time_to_wait_ms) at the
        moment the retrials were stopped.
  """

  def __init__(self, message, last_result, state):
    super(RetryException, self).__init__(message)
    self.message = message
    self.last_result = last_result
    self.state = state

  def __str__(self):
    # The text (including the missing space after the third comma) is kept
    # exactly as emitted before, since callers may match on it.
    template = ('last_result={0}, last_retrial={1}, '
                'time_passed_ms={2},'
                'time_to_wait={3}')
    state = self.state
    return template.format(
        self.last_result,
        state.retrial,
        state.time_passed_ms,
        state.time_to_wait_ms)
70
71
class WaitException(RetryException):
  """Raised when timeout was reached.

  Retryer raises this when the time already passed plus the upcoming wait
  would exceed its max_wait_ms.
  """
74
75
class MaxRetrialsException(RetryException):
  """Raised when too many retrials reached.

  Retryer raises this when the retrial count reaches its max_retrials, and
  also when the iterable that supplies the sleep values is exhausted while
  another retry is still wanted.
  """
78
79
class Retryer(object):
  """Retries a function based on specified retry strategy.

  The strategy (retrial and time budgets plus the modifiers applied to every
  sleep interval) is fixed at construction time and then shared by
  RetryOnResult and RetryOnException.
  """

  def __init__(self, max_retrials=None, max_wait_ms=None,
               exponential_sleep_multiplier=None, jitter_ms=_DEFAULT_JITTER_MS,
               status_update_func=None, wait_ceiling_ms=None):
    """Initializer for Retryer.

    Args:
      max_retrials: int, max number of retrials before raising
          MaxRetrialsException. None means the retrial count is not limited.
      max_wait_ms: int, total time budget in ms. WaitException is raised when
          the time already passed plus the next wait would exceed it. None
          means the total time is not limited.
      exponential_sleep_multiplier: float, The exponential factor to use on
          subsequent retries.
      jitter_ms: int, upper bound of the random additional value to wait.
      status_update_func: func(result, state) called after each trial that is
          going to be retried, before the stop conditions are checked.
      wait_ceiling_ms: int, maximum wait time between retries, regardless of
          modifiers added like exponential multiplier or jitter.
    """

    self._max_retrials = max_retrials
    self._max_wait_ms = max_wait_ms
    self._exponential_sleep_multiplier = exponential_sleep_multiplier
    self._jitter_ms = jitter_ms
    self._status_update_func = status_update_func
    self._wait_ceiling_ms = wait_ceiling_ms

  def _RaiseIfStop(self, result, state):
    """Raises if the retrial or time budget does not allow another trial.

    Args:
      result: the result of the last trial; attached to the raised exception.
      state: RetryerState, the state after the last trial.

    Raises:
      MaxRetrialsException: state.retrial has reached max_retrials.
      WaitException: the time passed plus the upcoming wait exceeds
          max_wait_ms.
    """
    if self._max_retrials is not None and self._max_retrials <= state.retrial:
      raise MaxRetrialsException('Reached', result, state)
    if self._max_wait_ms is not None:
      if state.time_passed_ms + state.time_to_wait_ms > self._max_wait_ms:
        raise WaitException('Timeout', result, state)

  def _GetTimeToWait(self, last_retrial, sleep_ms):
    """Get time to wait after applying modifiers.

    Apply the exponential sleep multiplier, jitter and ceiling limiting to the
    base sleep time. A falsy base sleep time (None or 0) always yields 0; no
    modifier is applied to it.

    Args:
      last_retrial: int, which retry attempt we just tried. First try this is 0.
      sleep_ms: int, how long to wait between the current trials.

    Returns:
      Number of ms to wait before trying next attempt with all waiting logic
      applied. This may be a float once the multiplier or jitter is applied.
    """
    wait_time_ms = sleep_ms
    if wait_time_ms:
      if self._exponential_sleep_multiplier:
        wait_time_ms *= self._exponential_sleep_multiplier ** last_retrial
      if self._jitter_ms:
        wait_time_ms += random.random() * self._jitter_ms
      if self._wait_ceiling_ms:
        # The ceiling is applied last so it also caps the jitter.
        wait_time_ms = min(wait_time_ms, self._wait_ceiling_ms)
      return wait_time_ms
    return 0

  def RetryOnException(self, func, args=None, kwargs=None,
                       should_retry_if=None, sleep_ms=None):
    """Retries the function if an exception occurs.

    Args:
      func: The function to call and retry.
      args: a sequence of positional arguments to be passed to func.
      kwargs: a dictionary of keyword arguments to be passed to func.
      should_retry_if: func(exc_type, exc_value, exc_traceback, state) that
          returns True or False. When None, every exception is retried.
      sleep_ms: int or iterable for how long to wait between trials.

    Returns:
      Whatever the function returns.

    Raises:
      MaxRetrialsException: if the function is retried too many times, or the
        sleep_ms iterable is exhausted.
      WaitException: if the time limit is reached.
    """

    args = args if args is not None else ()
    kwargs = kwargs if kwargs is not None else {}

    def TryFunc():
      # Every exception is captured (not swallowed) as a (result, exc_info)
      # pair so that the retry predicate can inspect it; an exception that is
      # not retried is re-raised below.
      try:
        return func(*args, **kwargs), None
      except:  # pylint: disable=bare-except
        return None, sys.exc_info()

    if should_retry_if is None:
      should_retry = lambda x, s: x[1] is not None
    else:
      def ShouldRetryFunc(try_func_result, state):
        exc_info = try_func_result[1]
        if exc_info is None:
          # No exception, no reason to retry.
          return False
        return should_retry_if(exc_info[0], exc_info[1], exc_info[2], state)
      should_retry = ShouldRetryFunc

    result, exc_info = self.RetryOnResult(
        TryFunc, should_retry_if=should_retry, sleep_ms=sleep_ms)
    if exc_info:
      # Exception that was not retried was raised. Re-raise.
      exceptions.reraise(exc_info[1], tb=exc_info[2])
    return result

  def RetryOnResult(self, func, args=None, kwargs=None,
                    should_retry_if=None, sleep_ms=None):
    """Retries the function if the given condition is satisfied.

    Args:
      func: The function to call and retry.
      args: a sequence of positional arguments to be passed to func.
      kwargs: a dictionary of keyword arguments to be passed to func.
      should_retry_if: result to retry on or func(result, RetryerState) that
          returns True or False if we should retry or not.
      sleep_ms: int or iterable, for how long to wait between trials.

    Returns:
      Whatever the function returns.

    Raises:
      MaxRetrialsException: function retried too many times, or the sleep_ms
        iterable ran out of values while a retry was still wanted.
      WaitException: time limit is reached.
    """
    args = args if args is not None else ()
    kwargs = kwargs if kwargs is not None else {}

    start_time_ms = _GetCurrentTimeMs()
    retrial = 0
    if callable(should_retry_if):
      should_retry = should_retry_if
    else:
      should_retry = lambda x, s: x == should_retry_if

    # The ABC aliases on the `collections` module were removed in Python 3.10,
    # so look Iterable up in collections.abc and only fall back to the
    # `collections` module itself on Python 2, where collections.abc is absent.
    try:
      from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
    except ImportError:
      collections_abc = collections

    if isinstance(sleep_ms, collections_abc.Iterable):
      sleep_gen = iter(sleep_ms)
    else:
      # A single value (or None) is reused for every trial.
      sleep_gen = itertools.repeat(sleep_ms)

    while True:
      result = func(*args, **kwargs)
      time_passed_ms = _GetCurrentTimeMs() - start_time_ms
      try:
        sleep_from_gen = next(sleep_gen)
      except StopIteration:
        # -1 marks an exhausted sleep iterable; see RetryerState.
        time_to_wait_ms = -1
      else:
        time_to_wait_ms = self._GetTimeToWait(retrial, sleep_from_gen)
      state = RetryerState(retrial, time_passed_ms, time_to_wait_ms)

      if not should_retry(result, state):
        return result

      if time_to_wait_ms == -1:
        raise MaxRetrialsException('Sleep iteration stop', result, state)
      if self._status_update_func:
        self._status_update_func(result, state)
      self._RaiseIfStop(result, state)
      _SleepMs(time_to_wait_ms)
      retrial += 1
239
240
def RetryOnException(f=None, max_retrials=None, max_wait_ms=None,
                     sleep_ms=None, exponential_sleep_multiplier=None,
                     jitter_ms=_DEFAULT_JITTER_MS,
                     status_update_func=None,
                     should_retry_if=None):
  """A decorator to retry on exceptions.

  Can be applied bare (@RetryOnException) or with options
  (@RetryOnException(max_retrials=3, ...)).

  Args:
    f: a function to run possibly multiple times
    max_retrials: int, max number of retrials before giving up.
    max_wait_ms: int, number of ms to wait before raising WaitException.
    sleep_ms: int or iterable, for how long to wait between trials.
    exponential_sleep_multiplier: float, The exponential factor to use on
        subsequent retries.
    jitter_ms: int, upper bound of the random additional value to wait.
    status_update_func: func(result, state) called after a trial that is going
        to be retried.
    should_retry_if: func(exc_type, exc_value, exc_traceback, state) that
        returns True or False.

  Returns:
    A version of f that is executed potentially multiple times and that
    yields the first returned value or, once the retrials are used up, the
    last exception raised by f. A WaitException is not translated and
    propagates as is.
  """

  if f is None:
    # Called with options only: hand back a decorator that has max_retrials,
    # max_wait_ms, sleep_ms, etc. already bound.
    bound_options = dict(
        exponential_sleep_multiplier=exponential_sleep_multiplier,
        jitter_ms=jitter_ms,
        max_retrials=max_retrials,
        max_wait_ms=max_wait_ms,
        should_retry_if=should_retry_if,
        sleep_ms=sleep_ms,
        status_update_func=status_update_func)
    return functools.partial(RetryOnException, **bound_options)

  @functools.wraps(f)
  def DecoratedFunction(*args, **kwargs):
    # A fresh Retryer is built for every call of the decorated function.
    retryer = Retryer(
        max_retrials=max_retrials,
        max_wait_ms=max_wait_ms,
        exponential_sleep_multiplier=exponential_sleep_multiplier,
        jitter_ms=jitter_ms,
        status_update_func=status_update_func)
    try:
      return retryer.RetryOnException(
          f,
          args=args,
          kwargs=kwargs,
          should_retry_if=should_retry_if,
          sleep_ms=sleep_ms)
    except MaxRetrialsException as retrials_error:
      # last_result is the (result, exc_info) pair of the final trial; surface
      # the exception f itself raised instead of the retry bookkeeping error.
      last_exc_info = retrials_error.last_result[1]
      exceptions.reraise(last_exc_info[1], tb=last_exc_info[2])

  return DecoratedFunction
295
296
def _GetCurrentTimeMs():
  """Returns time.time() converted to whole milliseconds, as an int."""
  seconds_now = time.time()
  return int(seconds_now * 1000)
299
300
def _SleepMs(time_to_wait_ms):
  """Blocks the calling thread for the given number of milliseconds."""
  seconds_to_wait = time_to_wait_ms / 1000.0
  time.sleep(seconds_to_wait)
303
304