1import ssl
2import sys
3from types import TracebackType
4from typing import Iterable, Iterator, List, Optional, Type
5
6from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol
7from .._models import Origin, Request, Response
8from .._ssl import default_ssl_context
9from .._synchronization import Event, Lock
10from ..backends.sync import SyncBackend
11from ..backends.base import NetworkBackend
12from .connection import HTTPConnection
13from .interfaces import ConnectionInterface, RequestInterface
14
15
class RequestStatus:
    """
    Track the connection assigned to a single request queued on the pool.

    The pool assigns a connection with `set_connection()`, and the caller
    sending the request blocks in `wait_for_connection()` until that happens.
    """

    def __init__(self, request: Request):
        self.request = request
        # The connection assigned by the pool, or None while still queued.
        self.connection: Optional[ConnectionInterface] = None
        # Signalled once `connection` has been assigned.
        self._connection_acquired = Event()

    def set_connection(self, connection: ConnectionInterface) -> None:
        """
        Assign `connection` to this request and wake anything waiting on it.

        Must only be called while no connection is assigned.
        """
        assert self.connection is None
        self.connection = connection
        self._connection_acquired.set()

    def unset_connection(self) -> None:
        """
        Clear the assigned connection so the request becomes queued again.

        A fresh event is created because the previous one was already set by
        `set_connection()`, and a later wait must block until reassignment.
        """
        assert self.connection is not None
        self.connection = None
        self._connection_acquired = Event()

    def wait_for_connection(
        self, timeout: Optional[float] = None
    ) -> ConnectionInterface:
        """
        Block until a connection has been assigned, then return it.

        Parameters:
            timeout: Passed straight through to the event's `wait()`.
                `None` is forwarded unchanged.

        Returns:
            The connection assigned via `set_connection()`.

        NOTE(review): what happens on expiry depends on the `Event`
        implementation in `.._synchronization`. If its `wait()` returns
        without raising, the assertion below fails with `AssertionError`.
        """
        self._connection_acquired.wait(timeout=timeout)
        assert self.connection is not None
        return self.connection
38
39
class ConnectionPool(RequestInterface):
    """
    A connection pool for making HTTP requests.
    """

    def __init__(
        self,
        ssl_context: Optional[ssl.SSLContext] = None,
        max_connections: Optional[int] = 10,
        max_keepalive_connections: Optional[int] = None,
        keepalive_expiry: Optional[float] = None,
        http1: bool = True,
        http2: bool = False,
        retries: int = 0,
        local_address: Optional[str] = None,
        uds: Optional[str] = None,
        network_backend: Optional[NetworkBackend] = None,
    ) -> None:
        """
        A connection pool for making HTTP requests.

        Parameters:
            ssl_context: An SSL context to use for verifying connections.
                If not specified, the default `httpcore.default_ssl_context()`
                will be used.
            max_connections: The maximum number of concurrent HTTP connections that
                the pool should allow. Any attempt to send a request on a pool that
                would exceed this amount will block until a connection is available.
                Pass `None` for no limit.
            max_keepalive_connections: The maximum number of idle HTTP connections
                that will be maintained in the pool. Pass `None` for no separate
                limit. The value is always capped at `max_connections`.
            keepalive_expiry: The duration in seconds that an idle HTTP connection
                may be maintained for before being expired from the pool.
            http1: A boolean indicating if HTTP/1.1 requests should be supported
                by the connection pool. Defaults to True.
            http2: A boolean indicating if HTTP/2 requests should be supported by
                the connection pool. Defaults to False.
            retries: The maximum number of retries when trying to establish a
                connection.
            local_address: Local address to connect from. Can also be used to connect
                using a particular address family. Using `local_address="0.0.0.0"`
                will connect using an `AF_INET` address (IPv4), while using
                `local_address="::"` will connect using an `AF_INET6` address (IPv6).
            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
            network_backend: A backend instance to use for handling network I/O.
                Defaults to `SyncBackend()`.
        """
        if ssl_context is None:
            ssl_context = default_ssl_context()

        self._ssl_context = ssl_context

        # `None` means "unlimited", represented internally by sys.maxsize.
        self._max_connections = (
            sys.maxsize if max_connections is None else max_connections
        )
        self._max_keepalive_connections = (
            sys.maxsize
            if max_keepalive_connections is None
            else max_keepalive_connections
        )
        # Keeping more idle connections than total connections makes no sense.
        self._max_keepalive_connections = min(
            self._max_connections, self._max_keepalive_connections
        )

        self._keepalive_expiry = keepalive_expiry
        self._http1 = http1
        self._http2 = http2
        self._retries = retries
        self._local_address = local_address
        self._uds = uds

        # Most-recently-used connections are kept at the front of the list.
        self._pool: List[ConnectionInterface] = []
        # Requests in arrival order. Queued entries have `connection is None`.
        self._requests: List[RequestStatus] = []
        self._pool_lock = Lock()
        self._network_backend = (
            SyncBackend() if network_backend is None else network_backend
        )

    def create_connection(self, origin: Origin) -> ConnectionInterface:
        """
        Build a new, not-yet-pooled connection to `origin` using the pool's settings.
        """
        return HTTPConnection(
            origin=origin,
            ssl_context=self._ssl_context,
            keepalive_expiry=self._keepalive_expiry,
            http1=self._http1,
            http2=self._http2,
            retries=self._retries,
            local_address=self._local_address,
            uds=self._uds,
            network_backend=self._network_backend,
        )

    @property
    def connections(self) -> List[ConnectionInterface]:
        """
        Return a list of the connections currently in the pool.

        For example:

        ```python
        >>> pool.connections
        [
            <HTTPConnection ['https://example.com:443', HTTP/1.1, ACTIVE, Request Count: 6]>,
            <HTTPConnection ['https://example.com:443', HTTP/1.1, IDLE, Request Count: 9]>,
            <HTTPConnection ['http://example.com:80', HTTP/1.1, IDLE, Request Count: 1]>,
        ]
        ```
        """
        return list(self._pool)

    def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool:
        """
        Attempt to provide a connection that can handle the given origin.

        Must be called with `_pool_lock` held. Returns True if a connection
        was assigned to `status`, False if the request must keep waiting.
        """
        origin = status.request.url.origin

        # If there are queued requests in front of us, then don't acquire a
        # connection. We handle requests strictly in order.
        waiting = [s for s in self._requests if s.connection is None]
        if waiting and waiting[0] is not status:
            return False

        # Reuse an existing connection if one is currently available,
        # moving it to the front of the pool (most recently used).
        for idx, connection in enumerate(self._pool):
            if connection.can_handle_request(origin) and connection.is_available():
                self._pool.pop(idx)
                self._pool.insert(0, connection)
                status.set_connection(connection)
                return True

        # If the pool is currently full, attempt to close one idle connection,
        # starting from the least recently used end.
        if len(self._pool) >= self._max_connections:
            for idx, connection in reversed(list(enumerate(self._pool))):
                if connection.is_idle():
                    connection.close()
                    self._pool.pop(idx)
                    break

        # If the pool is still full, then we cannot acquire a connection.
        if len(self._pool) >= self._max_connections:
            return False

        # Otherwise create a new connection.
        connection = self.create_connection(origin)
        self._pool.insert(0, connection)
        status.set_connection(connection)
        return True

    def _close_expired_connections(self) -> None:
        """
        Clean up the connection pool by closing off any connections that have expired.

        Must be called with `_pool_lock` held.
        """
        # Close any connections that have expired their keep-alive time.
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.has_expired():
                connection.close()
                self._pool.pop(idx)

        # If the pool size exceeds the maximum number of allowed keep-alive connections,
        # then close off idle connections as required, least recently used first.
        pool_size = len(self._pool)
        for idx, connection in reversed(list(enumerate(self._pool))):
            if connection.is_idle() and pool_size > self._max_keepalive_connections:
                connection.close()
                self._pool.pop(idx)
                pool_size -= 1

    def handle_request(self, request: Request) -> Response:
        """
        Send an HTTP request, and return an HTTP response.

        This is the core implementation that is called into by `.request()` or `.stream()`.

        Raises:
            UnsupportedProtocol: If the URL scheme is missing or is not http/https.
            Any exception raised while waiting for a connection or while sending
            the request is re-raised after the request has been removed from the
            pool's queue.
        """
        scheme = request.url.scheme.decode()
        if scheme == "":
            raise UnsupportedProtocol(
                "Request URL is missing an 'http://' or 'https://' protocol."
            )
        if scheme not in ("http", "https"):
            raise UnsupportedProtocol(
                f"Request URL has an unsupported protocol '{scheme}://'."
            )

        status = RequestStatus(request)

        with self._pool_lock:
            self._requests.append(status)
            self._close_expired_connections()
            self._attempt_to_acquire_connection(status)

        while True:
            timeouts = request.extensions.get("timeout", {})
            timeout = timeouts.get("pool", None)
            try:
                connection = status.wait_for_connection(timeout=timeout)
            except BaseException:
                # On a pool timeout or an interrupt, the request must leave the
                # queue. Otherwise it would remain at the head of `_requests`
                # with no connection. Requests are served strictly in order,
                # so every later request would then block behind it forever.
                with self._pool_lock:
                    if status in self._requests:
                        self._requests.remove(status)
                raise

            try:
                response = connection.handle_request(request)
            except ConnectionNotAvailable:
                # The ConnectionNotAvailable exception is a special case, that
                # indicates we need to retry the request on a new connection.
                #
                # The most common case where this can occur is when multiple
                # requests are queued waiting for a single connection, which
                # might end up as an HTTP/2 connection, but which actually ends
                # up as HTTP/1.1.
                with self._pool_lock:
                    # Maintain our position in the request queue, but reset the
                    # status so that the request becomes queued again.
                    status.unset_connection()
                    self._attempt_to_acquire_connection(status)
            except BaseException:
                # BaseException (not just Exception) so that interrupts also
                # release this request's slot before propagating.
                self.response_closed(status)
                raise
            else:
                break

        # When we return the response, we wrap the stream in a special class
        # that handles notifying the connection pool once the response
        # has been released.
        assert isinstance(response.stream, Iterable)
        return Response(
            status=response.status,
            headers=response.headers,
            content=ConnectionPoolByteStream(response.stream, self, status),
            extensions=response.extensions,
        )

    def response_closed(self, status: RequestStatus) -> None:
        """
        This method acts as a callback once the request/response cycle is complete.

        It is called into from the `ConnectionPoolByteStream.close()` method,
        and from `handle_request()` when sending the request fails.
        """
        assert status.connection is not None
        connection = status.connection

        with self._pool_lock:
            # Update the state of the connection pool. The request may already
            # be gone from the queue if `close()` was called while this
            # response was still open, since `close()` empties `_requests`.
            if status in self._requests:
                self._requests.remove(status)

            if connection.is_closed() and connection in self._pool:
                self._pool.remove(connection)

            # Since we've had a response closed, it's possible we'll now be able
            # to service one or more requests that are currently pending.
            for queued in self._requests:
                if queued.connection is None:
                    acquired = self._attempt_to_acquire_connection(queued)
                    # If we could not acquire a connection for a queued request
                    # then we don't need to check anymore requests that are
                    # queued later behind it.
                    if not acquired:
                        break

            # Housekeeping.
            self._close_expired_connections()

    def close(self) -> None:
        """
        Close any connections in the pool.
        """
        with self._pool_lock:
            for connection in self._pool:
                connection.close()
            self._pool = []
            self._requests = []

    def __enter__(self) -> "ConnectionPool":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        self.close()
313
314
class ConnectionPoolByteStream:
    """
    Wrap a response byte stream so that closing it also informs the pool.

    Iteration simply relays the chunks of the wrapped stream. `close()` closes
    the wrapped stream when it exposes a `close` attribute. Whether or not that
    succeeds, it then reports the request as finished via
    `ConnectionPool.response_closed()`.
    """

    def __init__(
        self,
        stream: Iterable[bytes],
        pool: ConnectionPool,
        status: RequestStatus,
    ) -> None:
        self._status = status
        self._pool = pool
        self._stream = stream

    def __iter__(self) -> Iterator[bytes]:
        # A plain loop rather than `yield from`. With `yield from`, closing
        # this generator would forward `close()` to the wrapped stream, and
        # stream closing is handled explicitly in `close()` below.
        for chunk in self._stream:
            yield chunk

    def close(self) -> None:
        """
        Close the wrapped stream if possible, then notify the pool.
        """
        wrapped = self._stream
        try:
            if hasattr(wrapped, "close"):
                wrapped.close()  # type: ignore
        finally:
            # Always release the request's slot, even if closing the stream raised.
            self._pool.response_closed(self._status)
341