1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5"""Tests for the retry policy."""
6try:
7    from io import BytesIO
8except ImportError:
9    from cStringIO import StringIO as BytesIO
10import pytest
11from itertools import product
12from azure.core.configuration import ConnectionConfiguration
13from azure.core.exceptions import (
14    AzureError,
15    ServiceRequestError,
16    ServiceRequestTimeoutError,
17    ServiceResponseError,
18    ServiceResponseTimeoutError,
19)
20from azure.core.pipeline.policies import (
21    RetryPolicy,
22    RetryMode,
23)
24from azure.core.pipeline import Pipeline, PipelineResponse
25from azure.core.pipeline.transport import (
26    HttpTransport,
27)
28import tempfile
29import os
30import time
31
32try:
33    from unittest.mock import Mock
34except ImportError:
35    from mock import Mock
36from utils import HTTP_REQUESTS, request_and_responses_product, HTTP_RESPONSES, create_http_response
37
38
def test_retry_code_class_variables():
    """The policy's retry-status set exists, contains 408 and 429, and excludes 501."""
    codes = RetryPolicy()._RETRY_CODES
    assert codes is not None
    for retried_status in (408, 429):
        assert retried_status in codes
    assert 501 not in codes
45
def test_retry_types():
    """get_backoff_time depends on the configured RetryMode.

    With three history entries and a backoff of 1, the default policy and
    RetryMode.Exponential both give 4, while RetryMode.Fixed gives 1.
    """
    settings = {
        'history': ["1", "2", "3"],
        'backoff': 1,
        'max_backoff': 10,
    }
    expectations = (
        ({}, 4),
        ({'retry_mode': RetryMode.Fixed}, 1),
        ({'retry_mode': RetryMode.Exponential}, 4),
    )
    for policy_kwargs, expected_backoff in expectations:
        assert RetryPolicy(**policy_kwargs).get_backoff_time(settings) == expected_backoff
64
_RETRY_AFTER_VALUES = ['0', '800', '1000', '1200']


def _check_retry_after_headers(ms_header, retry_after_input, http_request, http_response):
    """Assert how ``RetryPolicy.get_retry_after`` treats a millisecond header and ``Retry-After``.

    :param str ms_header: name of the millisecond-valued retry header under test.
    :param str retry_after_input: numeric header value, as a string.
    :param http_request: request type to instantiate (an entry of HTTP_REQUESTS).
    :param http_response: response type handed to ``create_http_response``.
    """
    retry_policy = RetryPolicy()
    request = http_request("GET", "http://localhost")
    response = create_http_response(http_response, request, None)

    # The millisecond header on its own is converted to seconds.
    response.headers[ms_header] = retry_after_input
    pipeline_response = PipelineResponse(request, response, None)
    milliseconds = float(retry_after_input)
    assert retry_policy.get_retry_after(pipeline_response) == milliseconds / 1000.0

    # Retry-After on its own is taken as seconds.
    response.headers.pop(ms_header)
    response.headers["Retry-After"] = retry_after_input
    assert retry_policy.get_retry_after(pipeline_response) == float(retry_after_input)

    # With both headers present, Retry-After takes precedence. The competing
    # millisecond value is a string, like every other header value in this test.
    response.headers[ms_header] = "500"
    assert retry_policy.get_retry_after(pipeline_response) == float(retry_after_input)


@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(_RETRY_AFTER_VALUES, HTTP_REQUESTS, HTTP_RESPONSES))
def test_retry_after(retry_after_input, http_request, http_response):
    """``retry-after-ms`` is read as milliseconds and is overridden by ``Retry-After``."""
    _check_retry_after_headers("retry-after-ms", retry_after_input, http_request, http_response)


@pytest.mark.parametrize("retry_after_input,http_request,http_response", product(_RETRY_AFTER_VALUES, HTTP_REQUESTS, HTTP_RESPONSES))
def test_x_ms_retry_after(retry_after_input, http_request, http_response):
    """``x-ms-retry-after-ms`` is read as milliseconds and is overridden by ``Retry-After``."""
    _check_retry_after_headers("x-ms-retry-after-ms", retry_after_input, http_request, http_response)
100
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_retry_on_429(http_request, http_response):
    """A 429 response is retried: with retry_total=1 the transport sees exactly two sends."""

    class MockTransport(HttpTransport):
        """Transport that counts its sends and always answers 429."""

        def __init__(self):
            self._count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            self._count += 1
            too_many_requests = create_http_response(http_response, request, None)
            too_many_requests.status_code = 429
            return too_many_requests

    transport = MockTransport()
    pipeline = Pipeline(transport, [RetryPolicy(retry_total=1)])
    pipeline.run(http_request('GET', 'http://localhost/'))
    assert transport._count == 2
125
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_no_retry_on_201(http_request, http_response):
    """A 201 response is not retried, even when it carries a Retry-After header."""

    class MockTransport(HttpTransport):
        """Transport that counts its sends and answers 201 with ``Retry-After: 1``."""

        def __init__(self):
            self._count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            self._count += 1
            created = create_http_response(http_response, request, None)
            created.status_code = 201
            created.headers = {"Retry-After": "1"}
            return created

    transport = MockTransport()
    pipeline = Pipeline(transport, [RetryPolicy(retry_total=1)])
    pipeline.run(http_request('GET', 'http://localhost/'))
    assert transport._count == 1
152
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_retry_seekable_stream(http_request, http_response):
    """A seekable streamed body is back at position 0 when the request is retried."""

    class MockTransport(HttpTransport):
        """Transport that fails once after moving the body to EOF, then checks it was rewound."""

        def __init__(self):
            self._first = True

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            if not self._first:
                assert request.body.tell() == 0
                bad_request = create_http_response(http_response, request, None)
                bad_request.status_code = 400
                return bad_request
            self._first = False
            # Leave the stream at EOF: the retry only passes if the body is rewound.
            request.body.seek(0, 2)
            raise AzureError('fail on first')

    request = http_request('GET', 'http://localhost/')
    request.set_streamed_data_body(BytesIO(b"Lots of dataaaa"))
    Pipeline(MockTransport(), [RetryPolicy(retry_total=1)]).run(request)
182
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_retry_seekable_file(http_request, http_response):
    """A file sent in a multipart form body is rewound before the request is retried.

    The first send moves the file to EOF and fails; the retried send asserts the
    file position is back at 0. The temporary file is deleted even if the run fails.
    """

    class MockTransport(HttpTransport):
        """Transport that fails once after seeking the form file to EOF."""

        def __init__(self):
            self._first = True

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        @staticmethod
        def _file_bodies(request):
            """Yield the readable bodies of named entries in ``request.files``."""
            for value in request.files.values():
                name, body = value[0], value[1]
                if name and body and hasattr(body, 'read'):
                    yield body

        def send(self, request, **kwargs):
            if self._first:
                self._first = False
                for body in self._file_bodies(request):
                    # Leave the file at EOF: the retry only passes if it is rewound.
                    body.seek(0, 2)
                    raise AzureError('fail on first')
            for body in self._file_bodies(request):
                assert not body.tell()
                response = create_http_response(http_response, request, None)
                response.status_code = 400
                return response
            # Without a file body there is nothing to verify; fail loudly here
            # instead of returning None into the pipeline.
            raise AssertionError("request has no readable file body")

    # delete=False keeps the file on disk after it is closed so it can be reopened by name.
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp_file:
            temp_file.write(b'Lots of dataaaa')
        request = http_request('GET', 'http://localhost/')
        request.headers = {'Content-Type': "multipart/form-data"}
        with open(temp_file.name, 'rb') as f:
            form_data_content = {
                'fileContent': f,
                'fileName': f.name,
            }
            request.set_formdata_body(form_data_content)
            pipeline = Pipeline(MockTransport(), [RetryPolicy(retry_total=1)])
            pipeline.run(request)
    finally:
        # Always remove the file, even when the pipeline run raised.
        os.unlink(temp_file.name)
228
@pytest.mark.parametrize("http_request", HTTP_REQUESTS)
def test_retry_timeout(http_request):
    """An overall ``timeout`` caps each attempt's connection timeout and ends the retries.

    The transport is configured with a connection timeout of twice the policy's
    timeout. Every send must still receive a ``connection_timeout`` no larger than
    the policy's timeout, and the run must end with ServiceResponseTimeoutError.
    """
    timeout = 1

    def send(request, **kwargs):
        assert kwargs["connection_timeout"] <= timeout, "policy should set connection_timeout not to exceed timeout"
        raise ServiceResponseError("oops")

    transport = Mock(
        spec=HttpTransport,
        send=Mock(wraps=send),
        connection_config=ConnectionConfiguration(connection_timeout=timeout * 2),
        sleep=time.sleep,
    )
    pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)])

    with pytest.raises(ServiceResponseTimeoutError):
        pipeline.run(http_request("GET", "http://localhost/"))

    # The connection_timeout check lives in send(); make sure it actually ran.
    assert transport.send.call_count >= 1
247
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_timeout_defaults(http_request, http_response):
    """Without a ``timeout``, the policy passes no timeout kwargs and leaves the transport's own configuration in charge."""

    def send(request, **kwargs):
        passed_timeouts = {"connection_timeout", "read_timeout"}.intersection(kwargs)
        assert not passed_timeouts, "policy should defer to transport configuration when not given a timeout"
        ok_response = create_http_response(http_response, request, None)
        ok_response.status_code = 200
        return ok_response

    # Any call to sleep means the policy tried to back off after a successful send.
    forbidden_sleep = Mock(side_effect=Exception("policy should not sleep: its first send succeeded"))
    transport = Mock(spec_set=HttpTransport, send=Mock(wraps=send), sleep=forbidden_sleep)

    Pipeline(transport, [RetryPolicy()]).run(http_request("GET", "http://localhost/"))
    assert transport.send.call_count == 1, "policy should not retry: its first send succeeded"
268
# Pairs of (error raised by the transport, timeout error the policy should raise instead).
combinations = [(ServiceRequestError, ServiceRequestTimeoutError), (ServiceResponseError, ServiceResponseTimeoutError)]


@pytest.mark.parametrize(
    "transport_error,expected_timeout_error,http_request",
    # Unpack each error pair so the test receives named arguments instead of a tuple
    # that shadows the module-level ``combinations`` list.
    (error_pair + (request_type,) for error_pair, request_type in product(combinations, HTTP_REQUESTS)),
)
def test_does_not_sleep_after_timeout(transport_error, expected_timeout_error, http_request):
    """Once the overall timeout has elapsed, the policy raises instead of sleeping again.

    With default settings the policy's backoff sleeps are 1.6s and then 3.2s. With
    timeout=1 it should sleep only once before raising ``expected_timeout_error``.
    """
    timeout = 1

    transport = Mock(
        spec=HttpTransport,
        send=Mock(side_effect=transport_error("oops")),
        sleep=Mock(wraps=time.sleep),
    )
    pipeline = Pipeline(transport, [RetryPolicy(timeout=timeout)])

    with pytest.raises(expected_timeout_error):
        pipeline.run(http_request("GET", "http://localhost/"))

    assert transport.sleep.call_count == 1
292