# ------------------------------------
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
import requests
from azure.core.pipeline.transport import (
    HttpTransport,
    RequestsTransport,
)
from azure.core.pipeline import Pipeline, PipelineResponse
from azure.core.pipeline.transport._requests_basic import StreamDownloadGenerator
try:
    from unittest import mock
except ImportError:
    import mock
import pytest
from utils import HTTP_RESPONSES, REQUESTS_TRANSPORT_RESPONSES, create_http_response, create_transport_response, request_and_responses_product


@pytest.mark.parametrize("http_request,http_response", request_and_responses_product(HTTP_RESPONSES))
def test_connection_error_response(http_request, http_response):
    """A ConnectionError raised by the raw stream propagates out of StreamDownloadGenerator.

    ``http_request`` and ``http_response`` are the request/response classes
    supplied by the parametrization. They are deliberately never rebound in
    this function: ``MockTransport.send`` resolves them through its closure at
    call time, so rebinding them to instances would make ``send`` call an
    instance instead of a class.
    """

    class MockTransport(HttpTransport):
        # Plays two roles: the pipeline's transport (``send``) and the fake
        # ``raw`` object of the internal response (``stream`` / ``__next__``).
        def __init__(self):
            # Becomes 1 after the first simulated ConnectionError.
            self._count = 0

        def __exit__(self, exc_type, exc_val, exc_tb):
            pass

        def close(self):
            pass

        def open(self):
            pass

        def send(self, request, **kwargs):
            # Ignore the incoming request and answer with a fresh 200 response.
            new_request = http_request('GET', 'http://localhost/')
            response = create_http_response(http_response, new_request, None)
            response.status_code = 200
            return response

        def next(self):
            # Python 2 style iterator alias; forward the result of __next__.
            return self.__next__()

        def __next__(self):
            if self._count == 0:
                self._count += 1
                raise requests.exceptions.ConnectionError

        def stream(self, chunk_size, decode_content=False):
            # The first time a generator from here is advanced it raises
            # ConnectionError; afterwards generators yield b"test" forever.
            if self._count == 0:
                self._count += 1
                raise requests.exceptions.ConnectionError
            while True:
                yield b"test"

    class MockInternalResponse():
        # Minimal internal response: a ``raw`` attribute plus ``close``.
        def __init__(self):
            self.raw = MockTransport()

        def close(self):
            pass

    request = http_request('GET', 'http://localhost/')
    pipeline = Pipeline(MockTransport())
    response = create_http_response(http_response, request, None)
    response.internal_response = MockInternalResponse()
    stream = StreamDownloadGenerator(pipeline, response, decompress=False)
    # NOTE(review): presumably patched so any retry back-off inside the
    # generator does not slow the test down — confirm against azure-core.
    with mock.patch('time.sleep', return_value=None):
        with pytest.raises(requests.exceptions.ConnectionError):
            next(stream)


@pytest.mark.parametrize("http_response", REQUESTS_TRANSPORT_RESPONSES)
def test_response_streaming_error_behavior(http_response):
    """Reproduce https://github.com/Azure/azure-sdk-for-python/issues/16723.

    The fake raw stream hands out ``block_size`` chunks and raises
    ``requests.exceptions.ConnectionError`` once at most ``block_size`` bytes
    remain. Consuming the ``stream_download`` iterator must surface that error.
    """
    block_size = 103
    total_response_size = 500
    req_response = requests.Response()
    req_request = requests.Request()

    class FakeStreamWithConnectionError:
        # Fake object standing in for urllib3.response.HTTPResponse.
        def __init__(self):
            # Bytes that read() has yet to hand out.
            self.total_response_size = total_response_size

        def stream(self, chunk_size, decode_content=False):
            assert chunk_size == block_size
            left = total_response_size
            while left > 0:
                if left <= block_size:
                    # Simulate the connection dropping before the final chunk.
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, left)
                left -= len(data)
                yield data

        def read(self, chunk_size, decode_content=False):
            # Same failure pattern as stream(), for callers that use read().
            assert chunk_size == block_size
            if self.total_response_size > 0:
                if self.total_response_size <= block_size:
                    raise requests.exceptions.ConnectionError()
                data = b"X" * min(chunk_size, self.total_response_size)
                self.total_response_size -= len(data)
                return data

        def close(self):
            pass

    req_response.raw = FakeStreamWithConnectionError()

    response = create_transport_response(
        http_response,
        req_request,
        req_response,
        block_size,
    )

    def mock_run(self, *args, **kwargs):
        # Stub for Pipeline.run. It is assigned on the instance (not the
        # class), so it is not bound: ``self`` receives the first positional
        # argument of the call.
        return PipelineResponse(
            None,
            requests.Response(),
            None,
        )

    transport = RequestsTransport()
    pipeline = Pipeline(transport)
    pipeline.run = mock_run
    downloader = response.stream_download(pipeline, decompress=False)
    with pytest.raises(requests.exceptions.ConnectionError):
        b"".join(downloader)