1# ------------------------------------
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4# ------------------------------------
5import requests
6from azure.core.pipeline.transport import (
7    AsyncHttpTransport,
8    AsyncioRequestsTransportResponse,
9    AioHttpTransport,
10)
11from azure.core.pipeline import AsyncPipeline, PipelineResponse
12from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator
13from unittest import mock
14import pytest
15from utils import request_and_responses_product, ASYNC_HTTP_RESPONSES, create_http_response
16
@pytest.mark.asyncio
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNC_HTTP_RESPONSES))
async def test_connection_error_response(http_request, http_response):
    """A ConnectionError raised while reading a chunk must propagate from the stream.

    The mocked ``content.read`` raises ``ConnectionError`` on its first call, and the
    test asserts that ``AioHttpStreamDownloadGenerator.__anext__`` surfaces it.

    :param http_request: parametrized request *type* used to build requests.
    :param http_response: parametrized response *type* passed to ``create_http_response``.
    """
    class MockSession(object):
        def __init__(self):
            # Store the value in a backing field. Assigning to ``self.auto_decompress``
            # would raise AttributeError because the property has no setter.
            self._auto_decompress = True

        @property
        def auto_decompress(self):
            # Read the backing field. Returning ``self.auto_decompress`` here would
            # recurse forever.
            return self._auto_decompress

    class MockTransport(AsyncHttpTransport):
        def __init__(self):
            # Use a real instance, not the class object, so attribute access behaves normally.
            self.session = MockSession()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(self, request, **kwargs):
            # ``http_response`` is the parametrized response *type*. The test body
            # below uses different names for its instances, so this closure never
            # sees a rebound value.
            response = create_http_response(http_response, request, None)
            response.status_code = 200
            return response

    class MockContent():
        def __init__(self):
            self._first = True

        async def read(self, block_size):
            # Fail on the first read, then report end of data.
            if self._first:
                self._first = False
                raise ConnectionError
            return None

    class MockInternalResponse():
        def __init__(self):
            self.headers = {}
            self.content = MockContent()

        # NOTE(review): aiohttp's ClientResponse.close() is synchronous. If the
        # generator calls close() without awaiting it, this coroutine is never
        # awaited (RuntimeWarning). Confirm against the generator before changing.
        async def close(self):
            pass

    class AsyncMock(mock.MagicMock):
        # Awaitable MagicMock, so patched coroutine functions can be awaited.
        async def __call__(self, *args, **kwargs):
            return super(AsyncMock, self).__call__(*args, **kwargs)

    request = http_request('GET', 'http://localhost/')
    pipeline = AsyncPipeline(MockTransport())
    response = create_http_response(http_response, request, None)
    response.internal_response = MockInternalResponse()
    stream = AioHttpStreamDownloadGenerator(pipeline, response, decompress=False)
    # asyncio.sleep is patched, presumably so any retry back-off in the generator
    # does not actually wait.
    with mock.patch('asyncio.sleep', new_callable=AsyncMock):
        with pytest.raises(ConnectionError):
            await stream.__anext__()
76
@pytest.mark.asyncio
@pytest.mark.parametrize("http_response", ASYNC_HTTP_RESPONSES)
async def test_response_streaming_error_behavior(http_response):
    """A ConnectionError raised mid-stream must surface from the async download iterator.

    Reproduces https://github.com/Azure/azure-sdk-for-python/issues/16723. The fake raw
    stream serves full ``block_size`` chunks. Once the remaining payload fits in a
    single block, it raises ``requests.exceptions.ConnectionError``. Iterating
    ``stream_download`` must re-raise that error instead of finishing cleanly.
    """
    # NOTE(review): ``http_response`` is parametrized but never used below, so every
    # parametrization exercises the same AsyncioRequestsTransportResponse path.
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse
        def __init__(self):
            # Bytes still to be served by ``read``. Seeded from the shared total so
            # ``stream`` and ``read`` cannot drift apart.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                # Fail instead of serving the final (partial) block.
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            if self.total_response_size > 0:
                # Fail instead of serving the final (partial) block.
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data
            return None

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = AsyncioRequestsTransportResponse(
        req_request,
        req_response,
        block_size,
    )

    async def mock_run(*args, **kwargs):
        # Assigned on the pipeline instance below, so it is not a bound method.
        # There is no implicit ``self``, and all arguments arrive in ``args``.
        # Stands in for ``pipeline.run`` in case the generator re-issues the request.
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = AioHttpTransport()
    pipeline = AsyncPipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline)
    with pytest.raises(requests.exceptions.ConnectionError):
        # If the iterator ended cleanly instead, pytest.raises reports "DID NOT RAISE".
        async for _ in downloader:
            pass
135