# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from unittest import mock

import pytest
import requests
from azure.core.pipeline import AsyncPipeline, PipelineResponse
from azure.core.pipeline.transport import (
    AsyncHttpTransport,
    AsyncioRequestsTransportResponse,
    AioHttpTransport,
)
from azure.core.pipeline.transport._aiohttp import AioHttpStreamDownloadGenerator

from utils import request_and_responses_product, ASYNC_HTTP_RESPONSES, create_http_response


@pytest.mark.asyncio
@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(ASYNC_HTTP_RESPONSES))
async def test_connection_error_response(http_request, http_response):
    """A ConnectionError raised while reading the aiohttp body propagates out of ``__anext__``.

    ``http_request`` and ``http_response`` are the request type and response type supplied
    by the parametrization; the test builds its request/response objects from them.
    """

    class MockSession(object):
        def __init__(self):
            # Stored in a private field: ``auto_decompress`` is a read-only property, so
            # assigning to it directly would raise AttributeError.
            self._auto_decompress = True

        @property
        def auto_decompress(self):
            return self._auto_decompress

    class MockTransport(AsyncHttpTransport):
        def __init__(self):
            self._count = 0
            self.session = MockSession()

        async def __aexit__(self, exc_type, exc_val, exc_tb):
            pass

        async def close(self):
            pass

        async def open(self):
            pass

        async def send(self, request, **kwargs):
            # Ignores the incoming ``request`` and answers with a 200 response built from
            # the parametrized request/response types (which must still be types here).
            mock_request = http_request('GET', 'http://localhost/')
            mock_response = create_http_response(http_response, mock_request, None)
            mock_response.status_code = 200
            return mock_response

    class MockContent():
        """Stand-in for the aiohttp body reader: the first read fails, later reads return None."""

        def __init__(self):
            self._first = True

        async def read(self, block_size):
            if self._first:
                self._first = False
                raise ConnectionError
            return None

    class MockInternalResponse():
        def __init__(self):
            self.headers = {}
            self.content = MockContent()

        async def close(self):
            pass

    class AsyncMock(mock.MagicMock):
        # Minimal awaitable MagicMock, used to patch ``asyncio.sleep`` below.
        async def __call__(self, *args, **kwargs):
            return super(AsyncMock, self).__call__(*args, **kwargs)

    # Use new names rather than rebinding ``http_request`` / ``http_response``:
    # ``MockTransport.send`` closes over those names and needs them to remain the types.
    request = http_request('GET', 'http://localhost/')
    pipeline = AsyncPipeline(MockTransport())
    response = create_http_response(http_response, request, None)
    response.internal_response = MockInternalResponse()
    stream = AioHttpStreamDownloadGenerator(pipeline, response, decompress=False)
    with mock.patch('asyncio.sleep', new_callable=AsyncMock):
        with pytest.raises(ConnectionError):
            await stream.__anext__()


@pytest.mark.asyncio
@pytest.mark.parametrize("http_response", ASYNC_HTTP_RESPONSES)
async def test_response_streaming_error_behavior(http_response):
    """A requests ConnectionError raised mid-stream propagates out of the download iterator.

    Reproduces https://github.com/Azure/azure-sdk-for-python/issues/16723.
    """
    # NOTE(review): ``http_response`` is not used by this body, so every parametrization
    # runs the same scenario — confirm whether the response should be built from it.
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # fake object for urllib3.response.HTTPResponse
        def __init__(self):
            # Bytes still to be served by ``read``; same total that ``stream`` uses.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                # Fail once no more than one block remains, i.e. before the final chunk.
                if left <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            if self.total_response_size > 0:
                # Same failure point as ``stream``: error before serving the final chunk.
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = AsyncioRequestsTransportResponse(
        req_request,
        req_response,
        block_size,
    )

    async def mock_run(self, *args, **kwargs):
        # Assigned as an instance attribute below, so it is not bound as a method:
        # ``self`` receives the first positional argument of the call.
        # NOTE(review): presumably stubbed so a re-issued request never hits the
        # network — confirm.
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = AioHttpTransport()
    pipeline = AsyncPipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline)
    with pytest.raises(requests.exceptions.ConnectionError):
        # ``async for`` turns a clean end-of-stream into a "DID NOT RAISE" failure
        # instead of leaking StopAsyncIteration out of the test.
        async for _ in downloader:
            pass