1import socket
2
3import pytest
4
5from urllib3.poolmanager import PoolKey, key_fn_by_scheme, PoolManager
6from urllib3 import connection_from_url
7from urllib3.exceptions import ClosedPoolError, LocationValueError
8from urllib3.util import retry, timeout
9from test import resolvesLocalhostFQDN
10
11
class TestPoolManager(object):
    """Tests for :class:`urllib3.poolmanager.PoolManager`.

    The tests check which connection-pool objects the manager returns for a
    given URL/host/context, how those pools are keyed, and how the manager's
    ``connection_pool_kw`` and per-call ``pool_kwargs`` reach the pools.
    """

    @staticmethod
    def _all_distinct(objs):
        """Return True if no object occurs more than once (by identity) in *objs*.

        Equivalent to checking ``x is not y`` for every pair of distinct
        positions; ``id`` is unique here because every object is kept alive
        by *objs* for the duration of the check.
        """
        return len({id(obj) for obj in objs}) == len(objs)

    @resolvesLocalhostFQDN
    def test_same_url(self):
        """Assert URLs on the same host share a pool only via a PoolManager."""
        # Convince ourselves that normally we don't get the same object
        conn1 = connection_from_url("http://localhost:8081/foo")
        conn2 = connection_from_url("http://localhost:8081/bar")

        assert conn1 != conn2

        # Now try again using the PoolManager
        p = PoolManager(1)

        conn1 = p.connection_from_url("http://localhost:8081/foo")
        conn2 = p.connection_from_url("http://localhost:8081/bar")

        assert conn1 == conn2

        # Ensure that FQDNs are handled separately from relative domains
        p = PoolManager(2)

        conn1 = p.connection_from_url("http://localhost.:8081/foo")
        conn2 = p.connection_from_url("http://localhost:8081/bar")

        assert conn1 != conn2

    def test_many_urls(self):
        """Assert one pool is created per distinct (scheme, host, port)."""
        urls = [
            "http://localhost:8081/foo",
            "http://www.google.com/mail",
            "http://localhost:8081/bar",
            "https://www.google.com/",
            "https://www.google.com/mail",
            "http://yahoo.com",
            "http://bing.com",
            "http://yahoo.com/",
        ]

        p = PoolManager(10)

        connections = {p.connection_from_url(url) for url in urls}

        # Paths are ignored, so the eight URLs collapse to five origins:
        # http localhost:8081, http www.google.com, https www.google.com,
        # http yahoo.com and http bing.com.
        assert len(connections) == 5

    def test_manager_clear(self):
        """Assert clear() drops and closes every pool the manager held."""
        p = PoolManager(5)

        conn_pool = p.connection_from_url("http://google.com")
        assert len(p.pools) == 1

        conn = conn_pool._get_conn()

        p.clear()
        assert len(p.pools) == 0

        with pytest.raises(ClosedPoolError):
            conn_pool._get_conn()

        # Returning a connection to a closed pool must not make it usable again.
        conn_pool._put_conn(conn)

        with pytest.raises(ClosedPoolError):
            conn_pool._get_conn()

        assert len(p.pools) == 0

    @pytest.mark.parametrize("url", ["http://@", None])
    def test_nohost(self, url):
        """Assert a URL without a host is rejected with LocationValueError."""
        p = PoolManager(5)
        with pytest.raises(LocationValueError):
            p.connection_from_url(url=url)

    def test_contextmanager(self):
        """Assert leaving the ``with`` block clears and closes all pools."""
        with PoolManager(1) as p:
            conn_pool = p.connection_from_url("http://google.com")
            assert len(p.pools) == 1
            conn = conn_pool._get_conn()

        assert len(p.pools) == 0

        with pytest.raises(ClosedPoolError):
            conn_pool._get_conn()

        # Returning a connection to a closed pool must not make it usable again.
        conn_pool._put_conn(conn)

        with pytest.raises(ClosedPoolError):
            conn_pool._get_conn()

        assert len(p.pools) == 0

    def test_http_pool_key_fields(self):
        """Assert the PoolKey fields are honored when selecting an HTTP pool."""
        connection_pool_kw = {
            "timeout": timeout.Timeout(3.14),
            "retries": retry.Retry(total=6, connect=2),
            "block": True,
            "strict": True,
            "source_address": "127.0.0.1",
        }
        p = PoolManager()
        conn_pools = [
            p.connection_from_url("http://example.com/"),
            p.connection_from_url("http://example.com:8000/"),
            p.connection_from_url("http://other.example.com/"),
        ]

        # Each added keyword changes the key, so every request yields a new pool.
        for key, value in connection_pool_kw.items():
            p.connection_pool_kw[key] = value
            conn_pools.append(p.connection_from_url("http://example.com/"))

        assert self._all_distinct(conn_pools)
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_https_pool_key_fields(self):
        """Assert the PoolKey fields are honored when selecting an HTTPS pool."""
        connection_pool_kw = {
            "timeout": timeout.Timeout(3.14),
            "retries": retry.Retry(total=6, connect=2),
            "block": True,
            "strict": True,
            "source_address": "127.0.0.1",
            "key_file": "/root/totally_legit.key",
            "cert_file": "/root/totally_legit.crt",
            "cert_reqs": "CERT_REQUIRED",
            "ca_certs": "/root/path_to_pem",
            "ssl_version": "SSLv23_METHOD",
        }
        p = PoolManager()
        conn_pools = [
            p.connection_from_url("https://example.com/"),
            p.connection_from_url("https://example.com:4333/"),
            p.connection_from_url("https://other.example.com/"),
        ]
        # Asking for a connection pool with the same key should give us an
        # existing pool.
        dup_pools = []

        for key, value in connection_pool_kw.items():
            p.connection_pool_kw[key] = value
            conn_pools.append(p.connection_from_url("https://example.com/"))
            dup_pools.append(p.connection_from_url("https://example.com/"))

        assert self._all_distinct(conn_pools)
        assert all(pool in conn_pools for pool in dup_pools)
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_default_pool_key_funcs_copy(self):
        """Assert each PoolManager gets a copy of ``key_fn_by_scheme``."""
        p = PoolManager()
        # Same contents as the module-level mapping ...
        assert p.key_fn_by_scheme == key_fn_by_scheme
        # ... but a separate object, so per-manager edits don't leak globally.
        assert p.key_fn_by_scheme is not key_fn_by_scheme

    def test_pools_keyed_with_from_host(self):
        """Assert pools are still keyed correctly with connection_from_host."""
        ssl_kw = {
            "key_file": "/root/totally_legit.key",
            "cert_file": "/root/totally_legit.crt",
            "cert_reqs": "CERT_REQUIRED",
            "ca_certs": "/root/path_to_pem",
            "ssl_version": "SSLv23_METHOD",
        }
        p = PoolManager(5, **ssl_kw)
        conns = [p.connection_from_host("example.com", 443, scheme="https")]

        # Changing any single TLS setting must produce a different pool.
        for k in ssl_kw:
            p.connection_pool_kw[k] = "newval"
            conns.append(p.connection_from_host("example.com", 443, scheme="https"))

        assert self._all_distinct(conns)

    def test_https_connection_from_url_case_insensitive(self):
        """Assert scheme case is ignored when pooling HTTPS connections."""
        p = PoolManager()
        pool = p.connection_from_url("https://example.com/")
        other_pool = p.connection_from_url("HTTPS://EXAMPLE.COM/")

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_https_connection_from_host_case_insensitive(self):
        """Assert scheme case is ignored when getting the https key class."""
        p = PoolManager()
        pool = p.connection_from_host("example.com", scheme="https")
        other_pool = p.connection_from_host("EXAMPLE.COM", scheme="HTTPS")

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_https_connection_from_context_case_insensitive(self):
        """Assert scheme case is ignored when getting the https key class."""
        p = PoolManager()
        context = {"scheme": "https", "host": "example.com", "port": "443"}
        other_context = {"scheme": "HTTPS", "host": "EXAMPLE.COM", "port": "443"}
        pool = p.connection_from_context(context)
        other_pool = p.connection_from_context(other_context)

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_http_connection_from_url_case_insensitive(self):
        """Assert scheme case is ignored when pooling HTTP connections."""
        p = PoolManager()
        pool = p.connection_from_url("http://example.com/")
        other_pool = p.connection_from_url("HTTP://EXAMPLE.COM/")

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_http_connection_from_host_case_insensitive(self):
        """Assert scheme case is ignored when getting the http key class."""
        p = PoolManager()
        pool = p.connection_from_host("example.com", scheme="http")
        other_pool = p.connection_from_host("EXAMPLE.COM", scheme="HTTP")

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_assert_hostname_and_fingerprint_flag(self):
        """Assert that pool manager can accept hostname and fingerprint flags."""
        fingerprint = "92:81:FE:85:F7:0C:26:60:EC:D6:B3:BF:93:CF:F9:71:CC:07:7D:0A"
        p = PoolManager(assert_hostname=True, assert_fingerprint=fingerprint)
        pool = p.connection_from_url("https://example.com/")
        assert 1 == len(p.pools)
        assert pool.assert_hostname
        assert fingerprint == pool.assert_fingerprint

    def test_http_connection_from_context_case_insensitive(self):
        """Assert scheme case is ignored when getting the http key class."""
        p = PoolManager()
        context = {"scheme": "http", "host": "example.com", "port": "8080"}
        other_context = {"scheme": "HTTP", "host": "EXAMPLE.COM", "port": "8080"}
        pool = p.connection_from_context(context)
        other_pool = p.connection_from_context(other_context)

        assert 1 == len(p.pools)
        assert pool is other_pool
        assert all(isinstance(key, PoolKey) for key in p.pools.keys())

    def test_custom_pool_key(self):
        """Assert it is possible to define a custom key function."""
        p = PoolManager(10)

        # The custom key only looks at the "key" entry, so extra kwargs such
        # as "x" do not influence pool selection.
        p.key_fn_by_scheme["http"] = lambda x: tuple(x["key"])
        pool1 = p.connection_from_url(
            "http://example.com", pool_kwargs={"key": "value"}
        )
        pool2 = p.connection_from_url(
            "http://example.com", pool_kwargs={"key": "other"}
        )
        pool3 = p.connection_from_url(
            "http://example.com", pool_kwargs={"key": "value", "x": "y"}
        )

        assert 2 == len(p.pools)
        assert pool1 is pool3
        assert pool1 is not pool2

    def test_override_pool_kwargs_url(self):
        """Assert overriding pool kwargs works with connection_from_url."""
        p = PoolManager(strict=True)
        pool_kwargs = {"strict": False, "retries": 100, "block": True}

        default_pool = p.connection_from_url("http://example.com/")
        override_pool = p.connection_from_url(
            "http://example.com/", pool_kwargs=pool_kwargs
        )

        assert default_pool.strict
        assert retry.Retry.DEFAULT == default_pool.retries
        assert not default_pool.block

        assert not override_pool.strict
        assert 100 == override_pool.retries
        assert override_pool.block

    def test_override_pool_kwargs_host(self):
        """Assert overriding pool kwargs works with connection_from_host"""
        p = PoolManager(strict=True)
        pool_kwargs = {"strict": False, "retries": 100, "block": True}

        default_pool = p.connection_from_host("example.com", scheme="http")
        override_pool = p.connection_from_host(
            "example.com", scheme="http", pool_kwargs=pool_kwargs
        )

        assert default_pool.strict
        assert retry.Retry.DEFAULT == default_pool.retries
        assert not default_pool.block

        assert not override_pool.strict
        assert 100 == override_pool.retries
        assert override_pool.block

    def test_pool_kwargs_socket_options(self):
        """Assert passing socket options works with connection_from_host"""
        p = PoolManager(socket_options=[])
        override_opts = [
            (socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),
            (socket.IPPROTO_TCP, socket.TCP_NODELAY, 1),
        ]
        pool_kwargs = {"socket_options": override_opts}

        default_pool = p.connection_from_host("example.com", scheme="http")
        override_pool = p.connection_from_host(
            "example.com", scheme="http", pool_kwargs=pool_kwargs
        )

        # socket_options is forwarded to the pool's per-connection kwargs.
        assert default_pool.conn_kw["socket_options"] == []
        assert override_pool.conn_kw["socket_options"] == override_opts

    def test_merge_pool_kwargs(self):
        """Assert _merge_pool_kwargs works in the happy case"""
        p = PoolManager(strict=True)
        merged = p._merge_pool_kwargs({"new_key": "value"})
        assert {"strict": True, "new_key": "value"} == merged

    def test_merge_pool_kwargs_none(self):
        """Assert false-y values to _merge_pool_kwargs result in defaults"""
        p = PoolManager(strict=True)
        merged = p._merge_pool_kwargs({})
        assert p.connection_pool_kw == merged
        merged = p._merge_pool_kwargs(None)
        assert p.connection_pool_kw == merged

    def test_merge_pool_kwargs_remove_key(self):
        """Assert keys can be removed with _merge_pool_kwargs"""
        p = PoolManager(strict=True)
        # A value of None means "drop this key from the merged kwargs".
        merged = p._merge_pool_kwargs({"strict": None})
        assert "strict" not in merged

    def test_merge_pool_kwargs_invalid_key(self):
        """Assert removing invalid keys with _merge_pool_kwargs doesn't break"""
        p = PoolManager(strict=True)
        merged = p._merge_pool_kwargs({"invalid_key": None})
        assert p.connection_pool_kw == merged