1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5import requests
6from azure.core.pipeline.transport import (
7    HttpTransport,
8    RequestsTransport,
9)
10from azure.core.pipeline import Pipeline, PipelineResponse
11from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator
12try:
13    from unittest import mock
14except ImportError:
15    import mock
16import pytest
17from utils import HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product
18
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_connection_error_response(http_request, http_response):
    """A ConnectionError raised while reading the raw stream reaches the caller.

    ``http_request`` and ``http_response`` are the parametrized request/response
    factories. ``MockTransport.send`` captures them as closure variables, so the
    test body binds the concrete objects it creates to *different* names; rebinding
    the parameters would make ``send`` call an instance instead of a factory.
    """
    class MockTransport(HttpTransport):
        """Fake transport whose first raw read fails with ConnectionError."""

        def __init__(self):
            # Shared by __next__ and stream(): 0 means "fail on the next read".
            self._count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            # Answer any request with a fresh 200 response built from the
            # parametrized factories (the incoming request is not inspected).
            fresh_request = http_request('GET', 'http://localhost/')
            response = create_http_response(http_response, fresh_request, None)
            response.status_code = 200
            return response

        def next(self):
            return self.__next__()

        def __next__(self):
            if self._count == 0:
                self._count += 1
                raise requests.exceptions.ConnectionError

        def stream(self, chunk_size, decode_content=False):
            # Generator mimicking urllib3's HTTPResponse.stream: the first one
            # iterated raises ConnectionError, later ones yield b"test" forever.
            if self._count == 0:
                self._count += 1
                raise requests.exceptions.ConnectionError
            while True:
                yield b"test"

    class MockInternalResponse:
        """Stand-in for requests.Response exposing only ``raw`` and ``close``."""

        def __init__(self):
            self.raw = MockTransport()

        def close(self):
            pass

    request = http_request('GET', 'http://localhost/')
    pipeline = Pipeline(MockTransport())
    response = create_http_response(http_response, request, None)
    response.internal_response = MockInternalResponse()
    stream = StreamDownloadGenerator(pipeline, response, decompress=False)
    # Patch out time.sleep in case the generator backs off before retrying.
    with mock.patch('time.sleep', return_value=None):
        with pytest.raises(requests.exceptions.ConnectionError):
            next(stream)
68
@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES)
def test_response_streaming_error_behavior(http_response):
    """A ConnectionError raised mid-download propagates out of ``stream_download``.

    Reproduces https://github.com/Azure/azure-sdk-for-python/issues/16723.
    """
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse
        def __init__(self):
            # Bytes still to be served by read(); taken from the enclosing
            # test so stream() and read() share one size definition.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            # Yield blocks until at most one block's worth of data is left,
            # then fail the way a dropped connection would.
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            # Non-generator counterpart of stream(); returns None when nothing
            # is left to serve.
            assert chunk_size == block_size
            if self.total_response_size > 0:
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = create_transport_response(
        http_response,
        req_request,
        req_response,
        block_size,
    )

    def mock_run(*args, **kwargs):
        # Assigned as an instance attribute below, so it is NOT bound to the
        # pipeline: every positional argument of pipeline.run(...) lands in args.
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = RequestsTransport()
    pipeline = Pipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline, decompress=False)
    with pytest.raises(requests.exceptions.ConnectionError):
        b"".join(downloader)
127