1# Test the support for SSL and sockets
2
3import sys
4import unittest
5import unittest.mock
6from test import support
7import socket
8import select
9import time
10import datetime
11import gc
12import os
13import errno
14import pprint
15import urllib.request
16import threading
17import traceback
18import asyncore
19import weakref
20import platform
21import sysconfig
22import functools
23try:
24    import ctypes
25except ImportError:
26    ctypes = None
27
# support.import_module() is expected to skip this whole test module when
# the ssl module cannot be imported.
ssl = support.import_module("ssl")

from ssl import TLSVersion, _TLSContentType, _TLSMessageType

# Sorted keys of ssl's private _PROTOCOL_NAMES mapping (presumably the
# PROTOCOL_* constants available in this build).
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
# Host name from test.support, used where the tests need a local address.
HOST = support.HOST
# Library flavour/version flags.  The IS_OPENSSL_* flags are only true for
# OpenSSL (never LibreSSL) at or above the version in their name.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
IS_OPENSSL_1_1_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
IS_OPENSSL_3_0_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (3, 0, 0)
# Build-time default cipher configuration; sysconfig.get_config_var()
# returns None when the variable is not defined.
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')
39
# Map PROTOCOL_* constants to the ssl.TLSVersion member they are paired
# with here.  Pairs whose protocol or version is missing from this ssl
# build are left out of the mapping.
PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
    ("PROTOCOL_SSLv23", "SSLv3"),
    ("PROTOCOL_TLSv1", "TLSv1"),
    ("PROTOCOL_TLSv1_1", "TLSv1_1"),
):
    try:
        proto = getattr(ssl, proto)
        ver = getattr(ssl.TLSVersion, ver)
    except AttributeError:
        pass
    else:
        PROTOCOL_TO_TLS_VERSION[proto] = ver
52
def data_file(*name):
    """Return the path of a data file that lives next to this module.

    Each positional argument is one path component below the directory
    containing this file, e.g. ``data_file("capath", "4e1295a3.0")``.
    """
    module_dir = os.path.dirname(__file__)
    return os.path.join(module_dir, *name)
55
# The custom key and certificate files used in test_ssl are generated
# using Lib/test/make_ssl_certs.py.
# Other certificates are simply fetched from the Internet servers they
# are meant to authenticate.

# Combined certificate + key file.  CERTFILE_INFO below lists identical
# issuer and subject, i.e. the certificate is self-signed.
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = os.fsencode(CERTFILE)
# Certificate-only and key-only files (per their names).
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = os.fsencode(ONLYCERT)
BYTES_ONLYKEY = os.fsencode(ONLYKEY)
# Password-protected variants; KEY_PASSWORD is presumably their passphrase
# (TODO confirm against make_ssl_certs.py).
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
KEY_PASSWORD = "somepass"
# CA directory plus two individual certificates stored inside it.
CAPATH = data_file("capath")
BYTES_CAPATH = os.fsencode(CAPATH)
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")

# Expected result of decoding CERTFILE (compared in test_parse_cert).
CERTFILE_INFO = {
    'issuer': ((('countryName', 'XY'),),
               (('localityName', 'Castle Anthrax'),),
               (('organizationName', 'Python Software Foundation'),),
               (('commonName', 'localhost'),)),
    'notAfter': 'Aug 26 14:23:15 2028 GMT',
    'notBefore': 'Aug 29 14:23:15 2018 GMT',
    'serialNumber': '98A7CF88C74A32ED',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests); each *_HOSTNAME
# constant is the host name the matching certificate is used with.
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE_HOSTNAME = 'localhost'

# Expected result of decoding SIGNED_CERTFILE (compared in test_parse_cert).
SIGNED_CERTFILE_INFO = {
    'OCSP': ('http://testca.pythontest.net/testca/ocsp/',),
    'caIssuers': ('http://testca.pythontest.net/testca/pycacert.cer',),
    'crlDistributionPoints': ('http://testca.pythontest.net/testca/revocation.crl',),
    'issuer': ((('countryName', 'XY'),),
            (('organizationName', 'Python Software Foundation CA'),),
            (('commonName', 'our-ca-server'),)),
    'notAfter': 'Oct 28 14:23:16 2037 GMT',
    'notBefore': 'Aug 29 14:23:16 2018 GMT',
    'serialNumber': 'CB2D80995A69525C',
    'subject': ((('countryName', 'XY'),),
             (('localityName', 'Castle Anthrax'),),
             (('organizationName', 'Python Software Foundation'),),
             (('commonName', 'localhost'),)),
    'subjectAltName': (('DNS', 'localhost'),),
    'version': 3
}

SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNED_CERTFILE2_HOSTNAME = 'fakehostname'
SIGNED_CERTFILE_ECC = data_file("keycertecc.pem")
SIGNED_CERTFILE_ECC_HOSTNAME = 'localhost-ecc'

# Same certificate as pycacert.pem, but without extra text in file
SIGNING_CA = data_file("capath", "ceff1710.0")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")
IDNSANSFILE = data_file("idnsans.pem")
NOSANFILE = data_file("nosan.pem")
NOSAN_HOSTNAME = 'localhost'

REMOTE_HOST = "self-signed.pythontest.net"

# Certificates/keys for error-handling and certificate-parser edge-case
# tests (empty, malformed, nonexistent, or unusual contents).
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Diffie-Hellman parameters file (name suggests a 3072-bit FFDH group).
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = os.fsencode(DHFILE)

# Not defined in all versions of OpenSSL; a missing option falls back to 0
# (no option bits).
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
OP_IGNORE_UNEXPECTED_EOF = getattr(ssl, "OP_IGNORE_UNEXPECTED_EOF", 0)
149
# Ubuntu has patched OpenSSL and changed behavior of security level 2
# see https://bugs.python.org/issue41561#msg389003
def is_ubuntu():
    """Return True if /etc/os-release mentions "ubuntu".

    Any reference to "ubuntu" is taken to mean an Ubuntu-like distro.  The
    workaround is not required for 18.04, but doesn't hurt either.

    A missing *or unreadable* os-release file (non-Linux platforms,
    restrictive permissions, ...) is treated as "not Ubuntu" rather than
    letting the error abort the import of the whole test module.
    """
    try:
        with open("/etc/os-release", encoding="utf-8") as f:
            return "ubuntu" in f.read()
    except OSError:
        # FileNotFoundError, PermissionError, IsADirectoryError, ...
        return False
160
if is_ubuntu():
    def seclevel_workaround(*ctxs):
        """Lower security level to '1' and allow all ciphers for TLS 1.0/1.1"""
        for ctx in ctxs:
            # Only contexts whose minimum version still permits TLS 1.1 or
            # older are touched; the hasattr() guard covers builds whose
            # SSLContext lacks the minimum_version attribute.
            if (
                hasattr(ctx, "minimum_version") and
                ctx.minimum_version <= ssl.TLSVersion.TLSv1_1
            ):
                ctx.set_ciphers("@SECLEVEL=1:ALL")
else:
    def seclevel_workaround(*ctxs):
        """No-op: the security-level workaround is only needed on Ubuntu."""
173
174
def has_tls_protocol(protocol):
    """Report whether a TLS protocol is available and enabled.

    :param protocol: an ``ssl._SSLMethod`` member, or its name such as
        ``'PROTOCOL_TLSv1_2'``; unknown names yield False
    :return: bool
    """
    prefix = 'PROTOCOL_'
    if isinstance(protocol, str):
        assert protocol.startswith(prefix)
        protocol = getattr(ssl, protocol, None)
        if protocol is None:
            return False
    # The auto-negotiating protocols are always available.
    auto_negotiating = {
        ssl.PROTOCOL_TLS,
        ssl.PROTOCOL_TLS_SERVER,
        ssl.PROTOCOL_TLS_CLIENT,
    }
    if protocol in auto_negotiating:
        return True
    # Everything else defers to the version check, e.g.
    # PROTOCOL_TLSv1_1 -> "TLSv1_1".
    version_name = protocol.name[len(prefix):]
    return has_tls_version(version_name)
194
195
@functools.lru_cache
def has_tls_version(version):
    """Report whether a TLS/SSL protocol version is enabled.

    :param version: TLS version name (e.g. ``'TLSv1_2'``) or an
        ``ssl.TLSVersion`` member
    :return: bool; results are memoized by functools.lru_cache
    """
    # SSLv2 was never supported and is not even in the TLSVersion enum.
    if version == "SSLv2":
        return False

    if isinstance(version, str):
        version = ssl.TLSVersion[version]

    # Compile-time flags such as ssl.HAS_TLSv1_2.
    if not getattr(ssl, f'HAS_{version.name}'):
        return False

    # bpo43791: 3.0.0-alpha14 fails with TLSV1_ALERT_INTERNAL_ERROR
    if IS_OPENSSL_3_0_0 and version < ssl.TLSVersion.TLSv1_2:
        return False

    # A version may be compiled in but disabled by a runtime crypto policy
    # or config option; that shows up as a non-default bound on a freshly
    # created context.
    ctx = ssl.SSLContext()
    lower = getattr(ctx, 'minimum_version', None)
    if (lower is not None
            and lower != ssl.TLSVersion.MINIMUM_SUPPORTED
            and version < lower):
        return False
    upper = getattr(ctx, 'maximum_version', None)
    return not (upper is not None
                and upper != ssl.TLSVersion.MAXIMUM_SUPPORTED
                and version > upper)
235
236
def requires_tls_version(version):
    """Build a decorator that skips a test unless *version* is available.

    :param version: TLS version name or ssl.TLSVersion member
    :return: a decorator; the wrapped test raises unittest.SkipTest when
        has_tls_version(version) is false at call time
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kw):
            if has_tls_version(version):
                return func(*args, **kw)
            raise unittest.SkipTest(f"{version} is not available.")
        return wrapper
    return decorator
252
253
# Skip decorator for tests that need the SSLContext.minimum_version
# attribute (the skip message names OpenSSL >= 1.1.0g as the requirement).
requires_minimum_version = unittest.skipUnless(
    hasattr(ssl.SSLContext, 'minimum_version'),
    "required OpenSSL >= 1.1.0g"
)
258
259
def handle_error(prefix):
    """Write the exception currently being handled to stdout when verbose.

    Intended to be called from inside an ``except`` block: the traceback is
    built from ``sys.exc_info()`` and written right after *prefix*.  Output
    is produced only when ``support.verbose`` is true.
    """
    # format_exception() returns lines that already end in a newline, so
    # they are concatenated directly; joining with ' ' would indent every
    # line after the first by one extra space.
    exc_format = ''.join(traceback.format_exception(*sys.exc_info()))
    if support.verbose:
        sys.stdout.write(prefix + exc_format)
264
def can_clear_options():
    """Return True when the OpenSSL API version is 0.9.8m or newer."""
    minimum_api = (0, 9, 8, 13, 15)  # 0.9.8m
    return ssl._OPENSSL_API_VERSION >= minimum_api
268
def no_sslv2_implies_sslv3_hello():
    """Return True when the linked OpenSSL is 0.9.7h or newer."""
    minimum_version = (0, 9, 7, 8, 15)  # 0.9.7h
    return ssl.OPENSSL_VERSION_INFO >= minimum_version
272
def have_verify_flags():
    """Return True when the linked OpenSSL is 0.9.8 or newer."""
    minimum_version = (0, 9, 8, 0, 15)  # 0.9.8
    return ssl.OPENSSL_VERSION_INFO >= minimum_version
276
def _have_secp_curves():
    """Return True if a server context accepts the secp384r1 ECDH curve.

    Always False when the ssl module reports no ECDH support.
    """
    if not ssl.HAS_ECDH:
        return False
    server_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    try:
        server_ctx.set_ecdh_curve("secp384r1")
    except ValueError:
        # Curve not recognised by this OpenSSL build.
        return False
    return True


# Evaluated once at import time.
HAVE_SECP_CURVES = _have_secp_curves()
290
291
def utc_offset(): #NOTE: ignore issues like #1647654
    """Return the local UTC offset in seconds (local time = UTC + offset).

    The DST offset (time.altzone) is used only when the zone has DST and it
    is currently in effect; otherwise the standard offset (time.timezone).
    """
    in_dst = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if in_dst else time.timezone)
297
def asn1time(cert_time):
    """Normalise a certificate time string to what this OpenSSL prints.

    Some versions of OpenSSL ignore seconds, see #18207.  For the affected
    API version (0.9.8i) the seconds are zeroed and the day of month is
    space-padded instead of zero-padded; otherwise *cert_time* is returned
    unchanged.
    """
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):  # 0.9.8.i
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    truncated = datetime.datetime.strptime(cert_time, fmt).replace(second=0)
    cert_time = truncated.strftime(fmt)
    # %d adds leading zero but ASN1_TIME_print() uses leading space
    if cert_time[4] == "0":
        cert_time = cert_time[:4] + " " + cert_time[5:]
    return cert_time
311
# Skip decorator for tests that require Server Name Indication support.
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
313
314
def test_wrap_socket(sock, ssl_version=ssl.PROTOCOL_TLS, *,
                     cert_reqs=ssl.CERT_NONE, ca_certs=None,
                     ciphers=None, certfile=None, keyfile=None,
                     **kwargs):
    """Wrap *sock* using a fresh SSLContext configured from the arguments.

    Each optional argument is applied only when it is not None.  Any extra
    keyword arguments are forwarded to SSLContext.wrap_socket().
    """
    ctx = ssl.SSLContext(ssl_version)
    if cert_reqs is not None:
        # The ssl module rejects CERT_NONE while check_hostname is enabled,
        # so hostname checking is switched off first.
        if cert_reqs == ssl.CERT_NONE:
            ctx.check_hostname = False
        ctx.verify_mode = cert_reqs
    if ca_certs is not None:
        ctx.load_verify_locations(ca_certs)
    if not (certfile is None and keyfile is None):
        ctx.load_cert_chain(certfile, keyfile)
    if ciphers is not None:
        ctx.set_ciphers(ciphers)
    return ctx.wrap_socket(sock, **kwargs)
331
332
def testing_context(server_cert=SIGNED_CERTFILE):
    """Create matching client and server contexts for a test certificate.

    Usage::

        client_context, server_context, hostname = testing_context()

    Both contexts trust SIGNING_CA; the server presents *server_cert*.
    *hostname* is the name that certificate is expected to be valid for.
    Raises ValueError for a certificate file without a known hostname.
    """
    known_certs = (
        (SIGNED_CERTFILE, SIGNED_CERTFILE_HOSTNAME),
        (SIGNED_CERTFILE2, SIGNED_CERTFILE2_HOSTNAME),
        (NOSANFILE, NOSAN_HOSTNAME),
    )
    for cert_path, cert_hostname in known_certs:
        if server_cert == cert_path:
            hostname = cert_hostname
            break
    else:
        raise ValueError(server_cert)

    client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    client_context.load_verify_locations(SIGNING_CA)

    server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    server_context.load_cert_chain(server_cert)
    server_context.load_verify_locations(SIGNING_CA)

    return client_context, server_context, hostname
355
356
357class BasicSocketTests(unittest.TestCase):
358
359    def test_constants(self):
360        ssl.CERT_NONE
361        ssl.CERT_OPTIONAL
362        ssl.CERT_REQUIRED
363        ssl.OP_CIPHER_SERVER_PREFERENCE
364        ssl.OP_SINGLE_DH_USE
365        if ssl.HAS_ECDH:
366            ssl.OP_SINGLE_ECDH_USE
367        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
368            ssl.OP_NO_COMPRESSION
369        self.assertIn(ssl.HAS_SNI, {True, False})
370        self.assertIn(ssl.HAS_ECDH, {True, False})
371        ssl.OP_NO_SSLv2
372        ssl.OP_NO_SSLv3
373        ssl.OP_NO_TLSv1
374        ssl.OP_NO_TLSv1_3
375        if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
376            ssl.OP_NO_TLSv1_1
377            ssl.OP_NO_TLSv1_2
378        self.assertEqual(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv23)
379
380    def test_private_init(self):
381        with self.assertRaisesRegex(TypeError, "public constructor"):
382            with socket.socket() as s:
383                ssl.SSLSocket(s)
384
385    def test_str_for_enums(self):
386        # Make sure that the PROTOCOL_* constants have enum-like string
387        # reprs.
388        proto = ssl.PROTOCOL_TLS
389        self.assertEqual(str(proto), '_SSLMethod.PROTOCOL_TLS')
390        ctx = ssl.SSLContext(proto)
391        self.assertIs(ctx.protocol, proto)
392
393    def test_random(self):
394        v = ssl.RAND_status()
395        if support.verbose:
396            sys.stdout.write("\n RAND_status is %d (%s)\n"
397                             % (v, (v and "sufficient randomness") or
398                                "insufficient randomness"))
399
400        data, is_cryptographic = ssl.RAND_pseudo_bytes(16)
401        self.assertEqual(len(data), 16)
402        self.assertEqual(is_cryptographic, v == 1)
403        if v:
404            data = ssl.RAND_bytes(16)
405            self.assertEqual(len(data), 16)
406        else:
407            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16)
408
409        # negative num is invalid
410        self.assertRaises(ValueError, ssl.RAND_bytes, -5)
411        self.assertRaises(ValueError, ssl.RAND_pseudo_bytes, -5)
412
413        if hasattr(ssl, 'RAND_egd'):
414            self.assertRaises(TypeError, ssl.RAND_egd, 1)
415            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
416        ssl.RAND_add("this is a random string", 75.0)
417        ssl.RAND_add(b"this is a random bytes object", 75.0)
418        ssl.RAND_add(bytearray(b"this is a random bytearray object"), 75.0)
419
420    @unittest.skipUnless(os.name == 'posix', 'requires posix')
421    def test_random_fork(self):
422        status = ssl.RAND_status()
423        if not status:
424            self.fail("OpenSSL's PRNG has insufficient randomness")
425
426        rfd, wfd = os.pipe()
427        pid = os.fork()
428        if pid == 0:
429            try:
430                os.close(rfd)
431                child_random = ssl.RAND_pseudo_bytes(16)[0]
432                self.assertEqual(len(child_random), 16)
433                os.write(wfd, child_random)
434                os.close(wfd)
435            except BaseException:
436                os._exit(1)
437            else:
438                os._exit(0)
439        else:
440            os.close(wfd)
441            self.addCleanup(os.close, rfd)
442            _, status = os.waitpid(pid, 0)
443            self.assertEqual(status, 0)
444
445            child_random = os.read(rfd, 16)
446            self.assertEqual(len(child_random), 16)
447            parent_random = ssl.RAND_pseudo_bytes(16)[0]
448            self.assertEqual(len(parent_random), 16)
449
450            self.assertNotEqual(child_random, parent_random)
451
    # No limit on assertion diff output, so mismatches in the large
    # certificate dicts compared below are shown in full.
    maxDiff = None
453
454    def test_parse_cert(self):
455        # note that this uses an 'unofficial' function in _ssl.c,
456        # provided solely for this test, to exercise the certificate
457        # parsing code
458        self.assertEqual(
459            ssl._ssl._test_decode_cert(CERTFILE),
460            CERTFILE_INFO
461        )
462        self.assertEqual(
463            ssl._ssl._test_decode_cert(SIGNED_CERTFILE),
464            SIGNED_CERTFILE_INFO
465        )
466
467        # Issue #13034: the subjectAltName in some certificates
468        # (notably projects.developer.nokia.com:443) wasn't parsed
469        p = ssl._ssl._test_decode_cert(NOKIACERT)
470        if support.verbose:
471            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
472        self.assertEqual(p['subjectAltName'],
473                         (('DNS', 'projects.developer.nokia.com'),
474                          ('DNS', 'projects.forum.nokia.com'))
475                        )
476        # extra OCSP and AIA fields
477        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
478        self.assertEqual(p['caIssuers'],
479                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
480        self.assertEqual(p['crlDistributionPoints'],
481                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
482
483    def test_parse_cert_CVE_2019_5010(self):
484        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
485        if support.verbose:
486            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
487        self.assertEqual(
488            p,
489            {
490                'issuer': (
491                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
492                'notAfter': 'Jun 14 18:00:58 2028 GMT',
493                'notBefore': 'Jun 18 18:00:58 2018 GMT',
494                'serialNumber': '02',
495                'subject': ((('countryName', 'UK'),),
496                            (('commonName',
497                              'codenomicon-vm-2.test.lal.cisco.com'),)),
498                'subjectAltName': (
499                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
500                'version': 3
501            }
502        )
503
504    def test_parse_cert_CVE_2013_4238(self):
505        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
506        if support.verbose:
507            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
508        subject = ((('countryName', 'US'),),
509                   (('stateOrProvinceName', 'Oregon'),),
510                   (('localityName', 'Beaverton'),),
511                   (('organizationName', 'Python Software Foundation'),),
512                   (('organizationalUnitName', 'Python Core Development'),),
513                   (('commonName', 'null.python.org\x00example.org'),),
514                   (('emailAddress', 'python-dev@python.org'),))
515        self.assertEqual(p['subject'], subject)
516        self.assertEqual(p['issuer'], subject)
517        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
518            san = (('DNS', 'altnull.python.org\x00example.com'),
519                   ('email', 'null@python.org\x00user@example.org'),
520                   ('URI', 'http://null.python.org\x00http://example.org'),
521                   ('IP Address', '192.0.2.1'),
522                   ('IP Address', '2001:DB8:0:0:0:0:0:1'))
523        else:
524            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
525            san = (('DNS', 'altnull.python.org\x00example.com'),
526                   ('email', 'null@python.org\x00user@example.org'),
527                   ('URI', 'http://null.python.org\x00http://example.org'),
528                   ('IP Address', '192.0.2.1'),
529                   ('IP Address', '<invalid>'))
530
531        self.assertEqual(p['subjectAltName'], san)
532
533    def test_parse_all_sans(self):
534        p = ssl._ssl._test_decode_cert(ALLSANFILE)
535        self.assertEqual(p['subjectAltName'],
536            (
537                ('DNS', 'allsans'),
538                ('othername', '<unsupported>'),
539                ('othername', '<unsupported>'),
540                ('email', 'user@example.org'),
541                ('DNS', 'www.example.org'),
542                ('DirName',
543                    ((('countryName', 'XY'),),
544                    (('localityName', 'Castle Anthrax'),),
545                    (('organizationName', 'Python Software Foundation'),),
546                    (('commonName', 'dirname example'),))),
547                ('URI', 'https://www.python.org/'),
548                ('IP Address', '127.0.0.1'),
549                ('IP Address', '0:0:0:0:0:0:0:1'),
550                ('Registered ID', '1.2.3.4.5')
551            )
552        )
553
554    def test_DER_to_PEM(self):
555        with open(CAFILE_CACERT, 'r') as f:
556            pem = f.read()
557        d1 = ssl.PEM_cert_to_DER_cert(pem)
558        p2 = ssl.DER_cert_to_PEM_cert(d1)
559        d2 = ssl.PEM_cert_to_DER_cert(p2)
560        self.assertEqual(d1, d2)
561        if not p2.startswith(ssl.PEM_HEADER + '\n'):
562            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
563        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
564            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
565
566    def test_openssl_version(self):
567        n = ssl.OPENSSL_VERSION_NUMBER
568        t = ssl.OPENSSL_VERSION_INFO
569        s = ssl.OPENSSL_VERSION
570        self.assertIsInstance(n, int)
571        self.assertIsInstance(t, tuple)
572        self.assertIsInstance(s, str)
573        # Some sanity checks follow
574        # >= 0.9
575        self.assertGreaterEqual(n, 0x900000)
576        # < 4.0
577        self.assertLess(n, 0x40000000)
578        major, minor, fix, patch, status = t
579        self.assertGreaterEqual(major, 1)
580        self.assertLess(major, 4)
581        self.assertGreaterEqual(minor, 0)
582        self.assertLess(minor, 256)
583        self.assertGreaterEqual(fix, 0)
584        self.assertLess(fix, 256)
585        self.assertGreaterEqual(patch, 0)
586        self.assertLessEqual(patch, 63)
587        self.assertGreaterEqual(status, 0)
588        self.assertLessEqual(status, 15)
589        # Version string as returned by {Open,Libre}SSL, the format might change
590        if IS_LIBRESSL:
591            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
592                            (s, t, hex(n)))
593        else:
594            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
595                            (s, t, hex(n)))
596
597    @support.cpython_only
598    def test_refcycle(self):
599        # Issue #7943: an SSL object doesn't create reference cycles with
600        # itself.
601        s = socket.socket(socket.AF_INET)
602        ss = test_wrap_socket(s)
603        wr = weakref.ref(ss)
604        with support.check_warnings(("", ResourceWarning)):
605            del ss
606        self.assertEqual(wr(), None)
607
608    def test_wrapped_unconnected(self):
609        # Methods on an unconnected SSLSocket propagate the original
610        # OSError raise by the underlying socket object.
611        s = socket.socket(socket.AF_INET)
612        with test_wrap_socket(s) as ss:
613            self.assertRaises(OSError, ss.recv, 1)
614            self.assertRaises(OSError, ss.recv_into, bytearray(b'x'))
615            self.assertRaises(OSError, ss.recvfrom, 1)
616            self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
617            self.assertRaises(OSError, ss.send, b'x')
618            self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
619            self.assertRaises(NotImplementedError, ss.dup)
620            self.assertRaises(NotImplementedError, ss.sendmsg,
621                              [b'x'], (), 0, ('0.0.0.0', 0))
622            self.assertRaises(NotImplementedError, ss.recvmsg, 100)
623            self.assertRaises(NotImplementedError, ss.recvmsg_into,
624                              [bytearray(100)])
625
626    def test_timeout(self):
627        # Issue #8524: when creating an SSL socket, the timeout of the
628        # original socket should be retained.
629        for timeout in (None, 0.0, 5.0):
630            s = socket.socket(socket.AF_INET)
631            s.settimeout(timeout)
632            with test_wrap_socket(s) as ss:
633                self.assertEqual(timeout, ss.gettimeout())
634
635    def test_errors_sslwrap(self):
636        sock = socket.socket()
637        self.assertRaisesRegex(ValueError,
638                        "certfile must be specified",
639                        ssl.wrap_socket, sock, keyfile=CERTFILE)
640        self.assertRaisesRegex(ValueError,
641                        "certfile must be specified for server-side operations",
642                        ssl.wrap_socket, sock, server_side=True)
643        self.assertRaisesRegex(ValueError,
644                        "certfile must be specified for server-side operations",
645                         ssl.wrap_socket, sock, server_side=True, certfile="")
646        with ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE) as s:
647            self.assertRaisesRegex(ValueError, "can't connect in server-side mode",
648                                     s.connect, (HOST, 8080))
649        with self.assertRaises(OSError) as cm:
650            with socket.socket() as sock:
651                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
652        self.assertEqual(cm.exception.errno, errno.ENOENT)
653        with self.assertRaises(OSError) as cm:
654            with socket.socket() as sock:
655                ssl.wrap_socket(sock,
656                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
657        self.assertEqual(cm.exception.errno, errno.ENOENT)
658        with self.assertRaises(OSError) as cm:
659            with socket.socket() as sock:
660                ssl.wrap_socket(sock,
661                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
662        self.assertEqual(cm.exception.errno, errno.ENOENT)
663
664    def bad_cert_test(self, certfile):
665        """Check that trying to use the given client certificate fails"""
666        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
667                                   certfile)
668        sock = socket.socket()
669        self.addCleanup(sock.close)
670        with self.assertRaises(ssl.SSLError):
671            test_wrap_socket(sock,
672                             certfile=certfile)
673
674    def test_empty_cert(self):
675        """Wrapping with an empty cert file"""
676        self.bad_cert_test("nullcert.pem")
677
678    def test_malformed_cert(self):
679        """Wrapping with a badly formatted certificate (syntax error)"""
680        self.bad_cert_test("badcert.pem")
681
682    def test_malformed_key(self):
683        """Wrapping with a badly formatted key (syntax error)"""
684        self.bad_cert_test("badkey.pem")
685
686    def test_match_hostname(self):
687        def ok(cert, hostname):
688            ssl.match_hostname(cert, hostname)
689        def fail(cert, hostname):
690            self.assertRaises(ssl.CertificateError,
691                              ssl.match_hostname, cert, hostname)
692
693        # -- Hostname matching --
694
695        cert = {'subject': ((('commonName', 'example.com'),),)}
696        ok(cert, 'example.com')
697        ok(cert, 'ExAmple.cOm')
698        fail(cert, 'www.example.com')
699        fail(cert, '.example.com')
700        fail(cert, 'example.org')
701        fail(cert, 'exampleXcom')
702
703        cert = {'subject': ((('commonName', '*.a.com'),),)}
704        ok(cert, 'foo.a.com')
705        fail(cert, 'bar.foo.a.com')
706        fail(cert, 'a.com')
707        fail(cert, 'Xa.com')
708        fail(cert, '.a.com')
709
710        # only match wildcards when they are the only thing
711        # in left-most segment
712        cert = {'subject': ((('commonName', 'f*.com'),),)}
713        fail(cert, 'foo.com')
714        fail(cert, 'f.com')
715        fail(cert, 'bar.com')
716        fail(cert, 'foo.a.com')
717        fail(cert, 'bar.foo.com')
718
719        # NULL bytes are bad, CVE-2013-4073
720        cert = {'subject': ((('commonName',
721                              'null.python.org\x00example.org'),),)}
722        ok(cert, 'null.python.org\x00example.org') # or raise an error?
723        fail(cert, 'example.org')
724        fail(cert, 'null.python.org')
725
726        # error cases with wildcards
727        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
728        fail(cert, 'bar.foo.a.com')
729        fail(cert, 'a.com')
730        fail(cert, 'Xa.com')
731        fail(cert, '.a.com')
732
733        cert = {'subject': ((('commonName', 'a.*.com'),),)}
734        fail(cert, 'a.foo.com')
735        fail(cert, 'a..com')
736        fail(cert, 'a.com')
737
738        # wildcard doesn't match IDNA prefix 'xn--'
739        idna = 'püthon.python.org'.encode("idna").decode("ascii")
740        cert = {'subject': ((('commonName', idna),),)}
741        ok(cert, idna)
742        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
743        fail(cert, idna)
744        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
745        fail(cert, idna)
746
747        # wildcard in first fragment and  IDNA A-labels in sequent fragments
748        # are supported.
749        idna = 'www*.pythön.org'.encode("idna").decode("ascii")
750        cert = {'subject': ((('commonName', idna),),)}
751        fail(cert, 'www.pythön.org'.encode("idna").decode("ascii"))
752        fail(cert, 'www1.pythön.org'.encode("idna").decode("ascii"))
753        fail(cert, 'ftp.pythön.org'.encode("idna").decode("ascii"))
754        fail(cert, 'pythön.org'.encode("idna").decode("ascii"))
755
756        # Slightly fake real-world example
757        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
758                'subject': ((('commonName', 'linuxfrz.org'),),),
759                'subjectAltName': (('DNS', 'linuxfr.org'),
760                                   ('DNS', 'linuxfr.com'),
761                                   ('othername', '<unsupported>'))}
762        ok(cert, 'linuxfr.org')
763        ok(cert, 'linuxfr.com')
764        # Not a "DNS" entry
765        fail(cert, '<unsupported>')
766        # When there is a subjectAltName, commonName isn't used
767        fail(cert, 'linuxfrz.org')
768
769        # A pristine real-world example
770        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
771                'subject': ((('countryName', 'US'),),
772                            (('stateOrProvinceName', 'California'),),
773                            (('localityName', 'Mountain View'),),
774                            (('organizationName', 'Google Inc'),),
775                            (('commonName', 'mail.google.com'),))}
776        ok(cert, 'mail.google.com')
777        fail(cert, 'gmail.com')
778        # Only commonName is considered
779        fail(cert, 'California')
780
781        # -- IPv4 matching --
782        cert = {'subject': ((('commonName', 'example.com'),),),
783                'subjectAltName': (('DNS', 'example.com'),
784                                   ('IP Address', '10.11.12.13'),
785                                   ('IP Address', '14.15.16.17'),
786                                   ('IP Address', '127.0.0.1'))}
787        ok(cert, '10.11.12.13')
788        ok(cert, '14.15.16.17')
789        # socket.inet_ntoa(socket.inet_aton('127.1')) == '127.0.0.1'
790        fail(cert, '127.1')
791        fail(cert, '14.15.16.17 ')
792        fail(cert, '14.15.16.17 extra data')
793        fail(cert, '14.15.16.18')
794        fail(cert, 'example.net')
795
796        # -- IPv6 matching --
797        if support.IPV6_ENABLED:
798            cert = {'subject': ((('commonName', 'example.com'),),),
799                    'subjectAltName': (
800                        ('DNS', 'example.com'),
801                        ('IP Address', '2001:0:0:0:0:0:0:CAFE\n'),
802                        ('IP Address', '2003:0:0:0:0:0:0:BABA\n'))}
803            ok(cert, '2001::cafe')
804            ok(cert, '2003::baba')
805            fail(cert, '2003::baba ')
806            fail(cert, '2003::baba extra data')
807            fail(cert, '2003::bebe')
808            fail(cert, 'example.net')
809
810        # -- Miscellaneous --
811
812        # Neither commonName nor subjectAltName
813        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
814                'subject': ((('countryName', 'US'),),
815                            (('stateOrProvinceName', 'California'),),
816                            (('localityName', 'Mountain View'),),
817                            (('organizationName', 'Google Inc'),))}
818        fail(cert, 'mail.google.com')
819
820        # No DNS entry in subjectAltName but a commonName
821        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
822                'subject': ((('countryName', 'US'),),
823                            (('stateOrProvinceName', 'California'),),
824                            (('localityName', 'Mountain View'),),
825                            (('commonName', 'mail.google.com'),)),
826                'subjectAltName': (('othername', 'blabla'), )}
827        ok(cert, 'mail.google.com')
828
829        # No DNS entry subjectAltName and no commonName
830        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
831                'subject': ((('countryName', 'US'),),
832                            (('stateOrProvinceName', 'California'),),
833                            (('localityName', 'Mountain View'),),
834                            (('organizationName', 'Google Inc'),)),
835                'subjectAltName': (('othername', 'blabla'),)}
836        fail(cert, 'google.com')
837
838        # Empty cert / no cert
839        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
840        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')
841
842        # Issue #17980: avoid denials of service by refusing more than one
843        # wildcard per fragment.
844        cert = {'subject': ((('commonName', 'a*b.example.com'),),)}
845        with self.assertRaisesRegex(
846                ssl.CertificateError,
847                "partial wildcards in leftmost label are not supported"):
848            ssl.match_hostname(cert, 'axxb.example.com')
849
850        cert = {'subject': ((('commonName', 'www.*.example.com'),),)}
851        with self.assertRaisesRegex(
852                ssl.CertificateError,
853                "wildcard can only be present in the leftmost label"):
854            ssl.match_hostname(cert, 'www.sub.example.com')
855
856        cert = {'subject': ((('commonName', 'a*b*.example.com'),),)}
857        with self.assertRaisesRegex(
858                ssl.CertificateError,
859                "too many wildcards"):
860            ssl.match_hostname(cert, 'axxbxxc.example.com')
861
862        cert = {'subject': ((('commonName', '*'),),)}
863        with self.assertRaisesRegex(
864                ssl.CertificateError,
865                "sole wildcard without additional labels are not support"):
866            ssl.match_hostname(cert, 'host')
867
868        cert = {'subject': ((('commonName', '*.com'),),)}
869        with self.assertRaisesRegex(
870                ssl.CertificateError,
871                r"hostname 'com' doesn't match '\*.com'"):
872            ssl.match_hostname(cert, 'com')
873
874        # extra checks for _inet_paton()
875        for invalid in ['1', '', '1.2.3', '256.0.0.1', '127.0.0.1/24']:
876            with self.assertRaises(ValueError):
877                ssl._inet_paton(invalid)
878        for ipaddr in ['127.0.0.1', '192.168.0.1']:
879            self.assertTrue(ssl._inet_paton(ipaddr))
880        if support.IPV6_ENABLED:
881            for ipaddr in ['::1', '2001:db8:85a3::8a2e:370:7334']:
882                self.assertTrue(ssl._inet_paton(ipaddr))
883
884    def test_server_side(self):
885        # server_hostname doesn't work for server sockets
886        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
887        with socket.socket() as sock:
888            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
889                              server_hostname="some.hostname")
890
891    def test_unknown_channel_binding(self):
892        # should raise ValueError for unknown type
893        s = socket.create_server(('127.0.0.1', 0))
894        c = socket.socket(socket.AF_INET)
895        c.connect(s.getsockname())
896        with test_wrap_socket(c, do_handshake_on_connect=False) as ss:
897            with self.assertRaises(ValueError):
898                ss.get_channel_binding("unknown-type")
899        s.close()
900
901    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
902                         "'tls-unique' channel binding not available")
903    def test_tls_unique_channel_binding(self):
904        # unconnected should return None for known type
905        s = socket.socket(socket.AF_INET)
906        with test_wrap_socket(s) as ss:
907            self.assertIsNone(ss.get_channel_binding("tls-unique"))
908        # the same for server-side
909        s = socket.socket(socket.AF_INET)
910        with test_wrap_socket(s, server_side=True, certfile=CERTFILE) as ss:
911            self.assertIsNone(ss.get_channel_binding("tls-unique"))
912
913    def test_dealloc_warn(self):
914        ss = test_wrap_socket(socket.socket(socket.AF_INET))
915        r = repr(ss)
916        with self.assertWarns(ResourceWarning) as cm:
917            ss = None
918            support.gc_collect()
919        self.assertIn(r, str(cm.warning.args[0]))
920
921    def test_get_default_verify_paths(self):
922        paths = ssl.get_default_verify_paths()
923        self.assertEqual(len(paths), 6)
924        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
925
926        with support.EnvironmentVarGuard() as env:
927            env["SSL_CERT_DIR"] = CAPATH
928            env["SSL_CERT_FILE"] = CERTFILE
929            paths = ssl.get_default_verify_paths()
930            self.assertEqual(paths.cafile, CERTFILE)
931            self.assertEqual(paths.capath, CAPATH)
932
933    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
934    def test_enum_certificates(self):
935        self.assertTrue(ssl.enum_certificates("CA"))
936        self.assertTrue(ssl.enum_certificates("ROOT"))
937
938        self.assertRaises(TypeError, ssl.enum_certificates)
939        self.assertRaises(WindowsError, ssl.enum_certificates, "")
940
941        trust_oids = set()
942        for storename in ("CA", "ROOT"):
943            store = ssl.enum_certificates(storename)
944            self.assertIsInstance(store, list)
945            for element in store:
946                self.assertIsInstance(element, tuple)
947                self.assertEqual(len(element), 3)
948                cert, enc, trust = element
949                self.assertIsInstance(cert, bytes)
950                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
951                self.assertIsInstance(trust, (frozenset, set, bool))
952                if isinstance(trust, (frozenset, set)):
953                    trust_oids.update(trust)
954
955        serverAuth = "1.3.6.1.5.5.7.3.1"
956        self.assertIn(serverAuth, trust_oids)
957
958    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
959    def test_enum_crls(self):
960        self.assertTrue(ssl.enum_crls("CA"))
961        self.assertRaises(TypeError, ssl.enum_crls)
962        self.assertRaises(WindowsError, ssl.enum_crls, "")
963
964        crls = ssl.enum_crls("CA")
965        self.assertIsInstance(crls, list)
966        for element in crls:
967            self.assertIsInstance(element, tuple)
968            self.assertEqual(len(element), 2)
969            self.assertIsInstance(element[0], bytes)
970            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
971
972
973    def test_asn1object(self):
974        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
975                    '1.3.6.1.5.5.7.3.1')
976
977        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
978        self.assertEqual(val, expected)
979        self.assertEqual(val.nid, 129)
980        self.assertEqual(val.shortname, 'serverAuth')
981        self.assertEqual(val.longname, 'TLS Web Server Authentication')
982        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
983        self.assertIsInstance(val, ssl._ASN1Object)
984        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
985
986        val = ssl._ASN1Object.fromnid(129)
987        self.assertEqual(val, expected)
988        self.assertIsInstance(val, ssl._ASN1Object)
989        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
990        with self.assertRaisesRegex(ValueError, "unknown NID 100000"):
991            ssl._ASN1Object.fromnid(100000)
992        for i in range(1000):
993            try:
994                obj = ssl._ASN1Object.fromnid(i)
995            except ValueError:
996                pass
997            else:
998                self.assertIsInstance(obj.nid, int)
999                self.assertIsInstance(obj.shortname, str)
1000                self.assertIsInstance(obj.longname, str)
1001                self.assertIsInstance(obj.oid, (str, type(None)))
1002
1003        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
1004        self.assertEqual(val, expected)
1005        self.assertIsInstance(val, ssl._ASN1Object)
1006        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
1007        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
1008                         expected)
1009        with self.assertRaisesRegex(ValueError, "unknown object 'serverauth'"):
1010            ssl._ASN1Object.fromname('serverauth')
1011
1012    def test_purpose_enum(self):
1013        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
1014        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
1015        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
1016        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
1017        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
1018        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
1019                              '1.3.6.1.5.5.7.3.1')
1020
1021        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
1022        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
1023        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
1024        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
1025        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
1026        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
1027                              '1.3.6.1.5.5.7.3.2')
1028
1029    def test_unsupported_dtls(self):
1030        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
1031        self.addCleanup(s.close)
1032        with self.assertRaises(NotImplementedError) as cx:
1033            test_wrap_socket(s, cert_reqs=ssl.CERT_NONE)
1034        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1035        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1036        with self.assertRaises(NotImplementedError) as cx:
1037            ctx.wrap_socket(s)
1038        self.assertEqual(str(cx.exception), "only stream sockets are supported")
1039
1040    def cert_time_ok(self, timestring, timestamp):
1041        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
1042
1043    def cert_time_fail(self, timestring):
1044        with self.assertRaises(ValueError):
1045            ssl.cert_time_to_seconds(timestring)
1046
1047    @unittest.skipUnless(utc_offset(),
1048                         'local time needs to be different from UTC')
1049    def test_cert_time_to_seconds_timezone(self):
1050        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
1051        #               results if local timezone is not UTC
1052        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
1053        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
1054
1055    def test_cert_time_to_seconds(self):
1056        timestring = "Jan  5 09:34:43 2018 GMT"
1057        ts = 1515144883.0
1058        self.cert_time_ok(timestring, ts)
1059        # accept keyword parameter, assert its name
1060        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
1061        # accept both %e and %d (space or zero generated by strftime)
1062        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
1063        # case-insensitive
1064        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
1065        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
1066        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
1067        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
1068        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
1069        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
1070        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
1071        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
1072
1073        newyear_ts = 1230768000.0
1074        # leap seconds
1075        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
1076        # same timestamp
1077        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
1078
1079        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
1080        #  allow 60th second (even if it is not a leap second)
1081        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
1082        #  allow 2nd leap second for compatibility with time.strptime()
1083        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
1084        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
1085
1086        # no special treatment for the special value:
1087        #   99991231235959Z (rfc 5280)
1088        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
1089
1090    @support.run_with_locale('LC_ALL', '')
1091    def test_cert_time_to_seconds_locale(self):
1092        # `cert_time_to_seconds()` should be locale independent
1093
1094        def local_february_name():
1095            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
1096
1097        if local_february_name().lower() == 'feb':
1098            self.skipTest("locale-specific month name needs to be "
1099                          "different from C locale")
1100
1101        # locale-independent
1102        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
1103        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
1104
1105    def test_connect_ex_error(self):
1106        server = socket.socket(socket.AF_INET)
1107        self.addCleanup(server.close)
1108        port = support.bind_port(server)  # Reserve port but don't listen
1109        s = test_wrap_socket(socket.socket(socket.AF_INET),
1110                            cert_reqs=ssl.CERT_REQUIRED)
1111        self.addCleanup(s.close)
1112        rc = s.connect_ex((HOST, port))
1113        # Issue #19919: Windows machines or VMs hosted on Windows
1114        # machines sometimes return EWOULDBLOCK.
1115        errors = (
1116            errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
1117            errno.EWOULDBLOCK,
1118        )
1119        self.assertIn(rc, errors)
1120
1121
1122class ContextTests(unittest.TestCase):
1123
1124    def test_constructor(self):
1125        for protocol in PROTOCOLS:
1126            ssl.SSLContext(protocol)
1127        ctx = ssl.SSLContext()
1128        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1129        self.assertRaises(ValueError, ssl.SSLContext, -1)
1130        self.assertRaises(ValueError, ssl.SSLContext, 42)
1131
1132    def test_protocol(self):
1133        for proto in PROTOCOLS:
1134            ctx = ssl.SSLContext(proto)
1135            self.assertEqual(ctx.protocol, proto)
1136
1137    def test_ciphers(self):
1138        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1139        ctx.set_ciphers("ALL")
1140        ctx.set_ciphers("DEFAULT")
1141        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
1142            ctx.set_ciphers("^$:,;?*'dorothyx")
1143
1144    @unittest.skipUnless(PY_SSL_DEFAULT_CIPHERS == 1,
1145                         "Test applies only to Python default ciphers")
1146    def test_python_ciphers(self):
1147        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1148        ciphers = ctx.get_ciphers()
1149        for suite in ciphers:
1150            name = suite['name']
1151            self.assertNotIn("PSK", name)
1152            self.assertNotIn("SRP", name)
1153            self.assertNotIn("MD5", name)
1154            self.assertNotIn("RC4", name)
1155            self.assertNotIn("3DES", name)
1156
1157    @unittest.skipIf(ssl.OPENSSL_VERSION_INFO < (1, 0, 2, 0, 0), 'OpenSSL too old')
1158    def test_get_ciphers(self):
1159        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1160        ctx.set_ciphers('AESGCM')
1161        names = set(d['name'] for d in ctx.get_ciphers())
1162        self.assertIn('AES256-GCM-SHA384', names)
1163        self.assertIn('AES128-GCM-SHA256', names)
1164
    def test_options(self):
        """Check the default SSLContext.options bits and that bits can be
        added, and (when the build allows it) cleared again."""
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
        # SSLContext also enables these by default
        # NOTE(review): the bare OP_* names below are not ssl attributes;
        # presumably they are module-level definitions elsewhere in this file.
        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
                    OP_ENABLE_MIDDLEBOX_COMPAT |
                    OP_IGNORE_UNEXPECTED_EOF)
        self.assertEqual(default, ctx.options)
        # Or-ing in an extra bit keeps every default bit.
        ctx.options |= ssl.OP_NO_TLSv1
        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
        # can_clear_options() is defined elsewhere in this file; it selects
        # whether clearing bits is expected to succeed or raise ValueError.
        if can_clear_options():
            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
            self.assertEqual(default, ctx.options)
            ctx.options = 0
            # Ubuntu has OP_NO_SSLv3 forced on by default
            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
        else:
            with self.assertRaises(ValueError):
                ctx.options = 0
1186
1187    def test_verify_mode_protocol(self):
1188        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
1189        # Default value
1190        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1191        ctx.verify_mode = ssl.CERT_OPTIONAL
1192        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
1193        ctx.verify_mode = ssl.CERT_REQUIRED
1194        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1195        ctx.verify_mode = ssl.CERT_NONE
1196        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1197        with self.assertRaises(TypeError):
1198            ctx.verify_mode = None
1199        with self.assertRaises(ValueError):
1200            ctx.verify_mode = 42
1201
1202        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1203        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1204        self.assertFalse(ctx.check_hostname)
1205
1206        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1207        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1208        self.assertTrue(ctx.check_hostname)
1209
1210    def test_hostname_checks_common_name(self):
1211        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1212        self.assertTrue(ctx.hostname_checks_common_name)
1213        if ssl.HAS_NEVER_CHECK_COMMON_NAME:
1214            ctx.hostname_checks_common_name = True
1215            self.assertTrue(ctx.hostname_checks_common_name)
1216            ctx.hostname_checks_common_name = False
1217            self.assertFalse(ctx.hostname_checks_common_name)
1218            ctx.hostname_checks_common_name = True
1219            self.assertTrue(ctx.hostname_checks_common_name)
1220        else:
1221            with self.assertRaises(AttributeError):
1222                ctx.hostname_checks_common_name = True
1223
    @requires_minimum_version
    @unittest.skipIf(IS_LIBRESSL, "see bpo-34001")
    def test_min_max_version(self):
        """Check minimum_version/maximum_version defaults, round-trips and limits.

        A context created for the fixed PROTOCOL_TLSv1_1 is also checked:
        on it, assigning either bound raises ValueError.
        """
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # OpenSSL default is MINIMUM_SUPPORTED, however some vendors like
        # Fedora override the setting to TLS 1.0.
        minimum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MINIMUM_SUPPORTED,
            # Fedora 29 uses TLS 1.0 by default
            ssl.TLSVersion.TLSv1,
            # RHEL 8 uses TLS 1.2 by default
            ssl.TLSVersion.TLSv1_2
        }
        maximum_range = {
            # stock OpenSSL
            ssl.TLSVersion.MAXIMUM_SUPPORTED,
            # Fedora 32 uses TLS 1.3 by default
            ssl.TLSVersion.TLSv1_3
        }

        # Defaults must be one of the known vendor choices above.
        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertIn(
            ctx.maximum_version, maximum_range
        )

        # Concrete versions read back unchanged.
        ctx.minimum_version = ssl.TLSVersion.TLSv1_1
        ctx.maximum_version = ssl.TLSVersion.TLSv1_2
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.TLSv1_1
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1_2
        )

        # A sentinel assigned to its own bound also reads back unchanged.
        ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        ctx.maximum_version = ssl.TLSVersion.TLSv1
        self.assertEqual(
            ctx.minimum_version, ssl.TLSVersion.MINIMUM_SUPPORTED
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.TLSv1
        )

        ctx.maximum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )

        # A sentinel assigned to the opposite bound reads back as a concrete
        # version, so only a set of plausible values is accepted.
        ctx.maximum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        self.assertIn(
            ctx.maximum_version,
            {ssl.TLSVersion.TLSv1, ssl.TLSVersion.SSLv3}
        )

        ctx.minimum_version = ssl.TLSVersion.MAXIMUM_SUPPORTED
        self.assertIn(
            ctx.minimum_version,
            {ssl.TLSVersion.TLSv1_2, ssl.TLSVersion.TLSv1_3}
        )

        # Values that are not TLSVersion members are rejected.
        with self.assertRaises(ValueError):
            ctx.minimum_version = 42

        # A fixed-version protocol context: defaults are readable, but
        # changing either bound raises ValueError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_1)

        self.assertIn(
            ctx.minimum_version, minimum_range
        )
        self.assertEqual(
            ctx.maximum_version, ssl.TLSVersion.MAXIMUM_SUPPORTED
        )
        with self.assertRaises(ValueError):
            ctx.minimum_version = ssl.TLSVersion.MINIMUM_SUPPORTED
        with self.assertRaises(ValueError):
            ctx.maximum_version = ssl.TLSVersion.TLSv1
1302
1303
1304    @unittest.skipUnless(have_verify_flags(),
1305                         "verify_flags need OpenSSL > 0.9.8")
1306    def test_verify_flags(self):
1307        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1308        # default value
1309        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
1310        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
1311        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
1312        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
1313        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
1314        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
1315        ctx.verify_flags = ssl.VERIFY_DEFAULT
1316        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
1317        # supports any value
1318        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
1319        self.assertEqual(ctx.verify_flags,
1320                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
1321        with self.assertRaises(TypeError):
1322            ctx.verify_flags = None
1323
    def test_load_cert_chain(self):
        """Exercise load_cert_chain() with combined and separate key/cert
        files, invalid inputs, and each accepted form of ``password``."""
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        # certfile is required; passing only keyfile is a TypeError.
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        with self.assertRaises(OSError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        # BADCERT and EMPTYCERT are both rejected with a "PEM lib" error.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A lone cert, a lone key, or the two swapped must all fail.
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert
        # (the password may be given as str, bytes or bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        # Wrong type -> TypeError; wrong password -> SSLError.
        with self.assertRaisesRegex(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        # Each helper below returns (or raises) one kind of callback result.
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Covers a callable instance and a bound method as the callback.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        # Callback results are validated like direct passwords, and an
        # exception raised by the callback propagates to the caller.
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegex(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegex(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        with self.assertRaisesRegex(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)
1406
1407    def test_load_verify_locations(self):
1408        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1409        ctx.load_verify_locations(CERTFILE)
1410        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
1411        ctx.load_verify_locations(BYTES_CERTFILE)
1412        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
1413        self.assertRaises(TypeError, ctx.load_verify_locations)
1414        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
1415        with self.assertRaises(OSError) as cm:
1416            ctx.load_verify_locations(NONEXISTINGCERT)
1417        self.assertEqual(cm.exception.errno, errno.ENOENT)
1418        with self.assertRaisesRegex(ssl.SSLError, "PEM lib"):
1419            ctx.load_verify_locations(BADCERT)
1420        ctx.load_verify_locations(CERTFILE, CAPATH)
1421        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)
1422
1423        # Issue #10989: crash if the second argument type is invalid
1424        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)
1425
    def test_load_verify_cadata(self):
        """Check load_verify_locations(cadata=...) with PEM text and DER bytes.

        The "x509_ca" counter from cert_store_stats() tracks how many CA
        certificates each context holds, so loading a duplicate must not
        change it.
        """
        # test cadata
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read()
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read()
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # (several PEM blocks in one string are all loaded)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs
        # (text outside PEM blocks is ignored; the repeat is not counted twice)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        # (concatenated DER certificates in one bytes object)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        # str cadata is parsed as PEM and bytes as DER, hence the two
        # different error messages for the same junk input.
        with self.assertRaisesRegex(
            ssl.SSLError,
            "no start line: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata="broken")
        with self.assertRaisesRegex(
            ssl.SSLError,
            "not enough data: cadata does not contain a certificate"
        ):
            ctx.load_verify_locations(cadata=b"broken")
1488
1489    def test_load_dh_params(self):
1490        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1491        ctx.load_dh_params(DHFILE)
1492        if os.name != 'nt':
1493            ctx.load_dh_params(BYTES_DHFILE)
1494        self.assertRaises(TypeError, ctx.load_dh_params)
1495        self.assertRaises(TypeError, ctx.load_dh_params, None)
1496        with self.assertRaises(FileNotFoundError) as cm:
1497            ctx.load_dh_params(NONEXISTINGCERT)
1498        self.assertEqual(cm.exception.errno, errno.ENOENT)
1499        with self.assertRaises(ssl.SSLError) as cm:
1500            ctx.load_dh_params(CERTFILE)
1501
1502    def test_session_stats(self):
1503        for proto in PROTOCOLS:
1504            ctx = ssl.SSLContext(proto)
1505            self.assertEqual(ctx.session_stats(), {
1506                'number': 0,
1507                'connect': 0,
1508                'connect_good': 0,
1509                'connect_renegotiate': 0,
1510                'accept': 0,
1511                'accept_good': 0,
1512                'accept_renegotiate': 0,
1513                'hits': 0,
1514                'misses': 0,
1515                'timeouts': 0,
1516                'cache_full': 0,
1517            })
1518
1519    def test_set_default_verify_paths(self):
1520        # There's not much we can do to test that it acts as expected,
1521        # so just check it doesn't crash or raise an exception.
1522        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1523        ctx.set_default_verify_paths()
1524
1525    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
1526    def test_set_ecdh_curve(self):
1527        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1528        ctx.set_ecdh_curve("prime256v1")
1529        ctx.set_ecdh_curve(b"prime256v1")
1530        self.assertRaises(TypeError, ctx.set_ecdh_curve)
1531        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
1532        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
1533        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")
1534
1535    @needs_sni
1536    def test_sni_callback(self):
1537        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1538
1539        # set_servername_callback expects a callable, or None
1540        self.assertRaises(TypeError, ctx.set_servername_callback)
1541        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
1542        self.assertRaises(TypeError, ctx.set_servername_callback, "")
1543        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)
1544
1545        def dummycallback(sock, servername, ctx):
1546            pass
1547        ctx.set_servername_callback(None)
1548        ctx.set_servername_callback(dummycallback)
1549
    @needs_sni
    def test_sni_callback_refcycle(self):
        # Reference cycles through the servername callback are detected
        # and cleared.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        # The default argument makes the callback hold a reference to ctx,
        # and registering it on ctx closes the loop: ctx -> callback -> ctx.
        def dummycallback(sock, servername, ctx, cycle=ctx):
            pass
        ctx.set_servername_callback(dummycallback)
        wr = weakref.ref(ctx)
        # Drop the local strong references so only the cycle keeps ctx alive.
        del ctx, dummycallback
        gc.collect()
        # The cyclic garbage collector must have freed the context.
        self.assertIs(wr(), None)
1562
1563    def test_cert_store_stats(self):
1564        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1565        self.assertEqual(ctx.cert_store_stats(),
1566            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1567        ctx.load_cert_chain(CERTFILE)
1568        self.assertEqual(ctx.cert_store_stats(),
1569            {'x509_ca': 0, 'crl': 0, 'x509': 0})
1570        ctx.load_verify_locations(CERTFILE)
1571        self.assertEqual(ctx.cert_store_stats(),
1572            {'x509_ca': 0, 'crl': 0, 'x509': 1})
1573        ctx.load_verify_locations(CAFILE_CACERT)
1574        self.assertEqual(ctx.cert_store_stats(),
1575            {'x509_ca': 1, 'crl': 0, 'x509': 2})
1576
1577    def test_get_ca_certs(self):
1578        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1579        self.assertEqual(ctx.get_ca_certs(), [])
1580        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
1581        ctx.load_verify_locations(CERTFILE)
1582        self.assertEqual(ctx.get_ca_certs(), [])
1583        # but CAFILE_CACERT is a CA cert
1584        ctx.load_verify_locations(CAFILE_CACERT)
1585        self.assertEqual(ctx.get_ca_certs(),
1586            [{'issuer': ((('organizationName', 'Root CA'),),
1587                         (('organizationalUnitName', 'http://www.cacert.org'),),
1588                         (('commonName', 'CA Cert Signing Authority'),),
1589                         (('emailAddress', 'support@cacert.org'),)),
1590              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
1591              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
1592              'serialNumber': '00',
1593              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
1594              'subject': ((('organizationName', 'Root CA'),),
1595                          (('organizationalUnitName', 'http://www.cacert.org'),),
1596                          (('commonName', 'CA Cert Signing Authority'),),
1597                          (('emailAddress', 'support@cacert.org'),)),
1598              'version': 3}])
1599
1600        with open(CAFILE_CACERT) as f:
1601            pem = f.read()
1602        der = ssl.PEM_cert_to_DER_cert(pem)
1603        self.assertEqual(ctx.get_ca_certs(True), [der])
1604
1605    def test_load_default_certs(self):
1606        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1607        ctx.load_default_certs()
1608
1609        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1610        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
1611        ctx.load_default_certs()
1612
1613        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1614        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)
1615
1616        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1617        self.assertRaises(TypeError, ctx.load_default_certs, None)
1618        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')
1619
    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
    def test_load_default_certs_env(self):
        # load_default_certs() must pick up the SSL_CERT_DIR and
        # SSL_CERT_FILE environment variables (skipped on Windows and
        # LibreSSL, see the decorators above).
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with support.EnvironmentVarGuard() as env:
            env["SSL_CERT_DIR"] = CAPATH
            env["SSL_CERT_FILE"] = CERTFILE
            ctx.load_default_certs()
            # Only CERTFILE's single non-CA certificate is counted; certs in
            # a capath directory are loaded on request, not up front.
            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})
1629
    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
    @unittest.skipIf(hasattr(sys, "gettotalrefcount"), "Debug build does not share environment between CRTs")
    def test_load_default_certs_env_windows(self):
        # On Windows, setting SSL_CERT_DIR / SSL_CERT_FILE must add CERTFILE's
        # certificate on top of what load_default_certs() loads otherwise.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        ctx.load_default_certs()
        # Baseline counts without the env vars (presumably from the system
        # certificate store, which varies per machine).
        stats = ctx.cert_store_stats()

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        with support.EnvironmentVarGuard() as env:
            env["SSL_CERT_DIR"] = CAPATH
            env["SSL_CERT_FILE"] = CERTFILE
            ctx.load_default_certs()
            # Same baseline plus exactly the one certificate in CERTFILE.
            stats["x509"] += 1
            self.assertEqual(ctx.cert_store_stats(), stats)
1644
1645    def _assert_context_options(self, ctx):
1646        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
1647        if OP_NO_COMPRESSION != 0:
1648            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
1649                             OP_NO_COMPRESSION)
1650        if OP_SINGLE_DH_USE != 0:
1651            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
1652                             OP_SINGLE_DH_USE)
1653        if OP_SINGLE_ECDH_USE != 0:
1654            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
1655                             OP_SINGLE_ECDH_USE)
1656        if OP_CIPHER_SERVER_PREFERENCE != 0:
1657            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
1658                             OP_CIPHER_SERVER_PREFERENCE)
1659
1660    def test_create_default_context(self):
1661        ctx = ssl.create_default_context()
1662
1663        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1664        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1665        self.assertTrue(ctx.check_hostname)
1666        self._assert_context_options(ctx)
1667
1668        with open(SIGNING_CA) as f:
1669            cadata = f.read()
1670        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
1671                                         cadata=cadata)
1672        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1673        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1674        self._assert_context_options(ctx)
1675
1676        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
1677        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1678        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1679        self._assert_context_options(ctx)
1680
1681    def test__create_stdlib_context(self):
1682        ctx = ssl._create_stdlib_context()
1683        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1684        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1685        self.assertFalse(ctx.check_hostname)
1686        self._assert_context_options(ctx)
1687
1688        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
1689        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1690        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1691        self._assert_context_options(ctx)
1692
1693        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
1694                                         cert_reqs=ssl.CERT_REQUIRED,
1695                                         check_hostname=True)
1696        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
1697        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1698        self.assertTrue(ctx.check_hostname)
1699        self._assert_context_options(ctx)
1700
1701        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
1702        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLS)
1703        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1704        self._assert_context_options(ctx)
1705
    def test_check_hostname(self):
        # Walks through the interplay between SSLContext.check_hostname and
        # SSLContext.verify_mode; each step depends on the previous state.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        # PROTOCOL_TLS starts with both checks disabled.
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)

        # Auto set CERT_REQUIRED: enabling check_hostname while verify_mode
        # is CERT_NONE upgrades verify_mode to CERT_REQUIRED.
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        # Disabling check_hostname leaves verify_mode untouched.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        # Changing verify_mode does not affect check_hostname
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        # Auto set (again from CERT_NONE)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)

        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        # keep CERT_OPTIONAL: the upgrade only happens from CERT_NONE
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)

        # Cannot set CERT_NONE with check_hostname enabled; it becomes
        # possible again once check_hostname is turned off.
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1748
1749    def test_context_client_server(self):
1750        # PROTOCOL_TLS_CLIENT has sane defaults
1751        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1752        self.assertTrue(ctx.check_hostname)
1753        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
1754
1755        # PROTOCOL_TLS_SERVER has different but also sane defaults
1756        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1757        self.assertFalse(ctx.check_hostname)
1758        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
1759
1760    def test_context_custom_class(self):
1761        class MySSLSocket(ssl.SSLSocket):
1762            pass
1763
1764        class MySSLObject(ssl.SSLObject):
1765            pass
1766
1767        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1768        ctx.sslsocket_class = MySSLSocket
1769        ctx.sslobject_class = MySSLObject
1770
1771        with ctx.wrap_socket(socket.socket(), server_side=True) as sock:
1772            self.assertIsInstance(sock, MySSLSocket)
1773        obj = ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO())
1774        self.assertIsInstance(obj, MySSLObject)
1775
1776    @unittest.skipUnless(IS_OPENSSL_1_1_1, "Test requires OpenSSL 1.1.1")
1777    def test_num_tickest(self):
1778        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
1779        self.assertEqual(ctx.num_tickets, 2)
1780        ctx.num_tickets = 1
1781        self.assertEqual(ctx.num_tickets, 1)
1782        ctx.num_tickets = 0
1783        self.assertEqual(ctx.num_tickets, 0)
1784        with self.assertRaises(ValueError):
1785            ctx.num_tickets = -1
1786        with self.assertRaises(TypeError):
1787            ctx.num_tickets = None
1788
1789        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1790        self.assertEqual(ctx.num_tickets, 2)
1791        with self.assertRaises(ValueError):
1792            ctx.num_tickets = 1
1793
1794
1795class SSLErrorTests(unittest.TestCase):
1796
1797    def test_str(self):
1798        # The str() of a SSLError doesn't include the errno
1799        e = ssl.SSLError(1, "foo")
1800        self.assertEqual(str(e), "foo")
1801        self.assertEqual(e.errno, 1)
1802        # Same for a subclass
1803        e = ssl.SSLZeroReturnError(1, "foo")
1804        self.assertEqual(str(e), "foo")
1805        self.assertEqual(e.errno, 1)
1806
1807    def test_lib_reason(self):
1808        # Test the library and reason attributes
1809        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1810        with self.assertRaises(ssl.SSLError) as cm:
1811            ctx.load_dh_params(CERTFILE)
1812        self.assertEqual(cm.exception.library, 'PEM')
1813        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
1814        s = str(cm.exception)
1815        self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)
1816
1817    def test_subclass(self):
1818        # Check that the appropriate SSLError subclass is raised
1819        # (this only tests one of them)
1820        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
1821        ctx.check_hostname = False
1822        ctx.verify_mode = ssl.CERT_NONE
1823        with socket.create_server(("127.0.0.1", 0)) as s:
1824            c = socket.create_connection(s.getsockname())
1825            c.setblocking(False)
1826            with ctx.wrap_socket(c, False, do_handshake_on_connect=False) as c:
1827                with self.assertRaises(ssl.SSLWantReadError) as cm:
1828                    c.do_handshake()
1829                s = str(cm.exception)
1830                self.assertTrue(s.startswith("The operation did not complete (read)"), s)
1831                # For compatibility
1832                self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
1833
1834
1835    def test_bad_server_hostname(self):
1836        ctx = ssl.create_default_context()
1837        with self.assertRaises(ValueError):
1838            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1839                         server_hostname="")
1840        with self.assertRaises(ValueError):
1841            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1842                         server_hostname=".example.org")
1843        with self.assertRaises(TypeError):
1844            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
1845                         server_hostname="example.org\x00evil.com")
1846
1847
class MemoryBIOTests(unittest.TestCase):
    """Behaviour of the in-memory ssl.MemoryBIO buffer."""

    def test_read_write(self):
        bio = ssl.MemoryBIO()
        # an unsized read drains everything written so far
        bio.write(b'foo')
        self.assertEqual(bio.read(), b'foo')
        self.assertEqual(bio.read(), b'')
        # consecutive writes are concatenated
        for chunk in (b'foo', b'bar'):
            bio.write(chunk)
        self.assertEqual(bio.read(), b'foobar')
        self.assertEqual(bio.read(), b'')
        # sized reads return at most the requested number of bytes
        bio.write(b'baz')
        for size, expected in ((2, b'ba'), (1, b'z'), (1, b'')):
            self.assertEqual(bio.read(size), expected)

    def test_eof(self):
        bio = ssl.MemoryBIO()
        # an empty buffer is not at EOF until write_eof() is called ...
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertFalse(bio.eof)
        bio.write(b'foo')
        self.assertFalse(bio.eof)
        bio.write_eof()
        # ... and even then only once all pending data has been read
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(2), b'fo')
        self.assertFalse(bio.eof)
        self.assertEqual(bio.read(1), b'o')
        self.assertTrue(bio.eof)
        self.assertEqual(bio.read(), b'')
        self.assertTrue(bio.eof)

    def test_pending(self):
        # ``pending`` tracks the number of unread bytes
        bio = ssl.MemoryBIO()
        self.assertEqual(bio.pending, 0)
        bio.write(b'foo')
        self.assertEqual(bio.pending, 3)
        for remaining in (2, 1, 0):
            bio.read(1)
            self.assertEqual(bio.pending, remaining)
        for written in (1, 2, 3):
            bio.write(b'x')
            self.assertEqual(bio.pending, written)
        bio.read()
        self.assertEqual(bio.pending, 0)

    def test_buffer_types(self):
        # any bytes-like object can be written
        bio = ssl.MemoryBIO()
        for data in (b'foo', bytearray(b'bar'), memoryview(b'baz')):
            bio.write(data)
            self.assertEqual(bio.read(), bytes(data))

    def test_error_types(self):
        # str and other non-buffer objects are rejected
        bio = ssl.MemoryBIO()
        for bad in ('foo', None, True, 1):
            self.assertRaises(TypeError, bio.write, bad)
1909
1910
1911class SSLObjectTests(unittest.TestCase):
1912    def test_private_init(self):
1913        bio = ssl.MemoryBIO()
1914        with self.assertRaisesRegex(TypeError, "public constructor"):
1915            ssl.SSLObject(bio, bio)
1916
1917    def test_unwrap(self):
1918        client_ctx, server_ctx, hostname = testing_context()
1919        c_in = ssl.MemoryBIO()
1920        c_out = ssl.MemoryBIO()
1921        s_in = ssl.MemoryBIO()
1922        s_out = ssl.MemoryBIO()
1923        client = client_ctx.wrap_bio(c_in, c_out, server_hostname=hostname)
1924        server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
1925
1926        # Loop on the handshake for a bit to get it settled
1927        for _ in range(5):
1928            try:
1929                client.do_handshake()
1930            except ssl.SSLWantReadError:
1931                pass
1932            if c_out.pending:
1933                s_in.write(c_out.read())
1934            try:
1935                server.do_handshake()
1936            except ssl.SSLWantReadError:
1937                pass
1938            if s_out.pending:
1939                c_in.write(s_out.read())
1940        # Now the handshakes should be complete (don't raise WantReadError)
1941        client.do_handshake()
1942        server.do_handshake()
1943
1944        # Now if we unwrap one side unilaterally, it should send close-notify
1945        # and raise WantReadError:
1946        with self.assertRaises(ssl.SSLWantReadError):
1947            client.unwrap()
1948
1949        # But server.unwrap() does not raise, because it reads the client's
1950        # close-notify:
1951        s_in.write(c_out.read())
1952        server.unwrap()
1953
1954        # And now that the client gets the server's close-notify, it doesn't
1955        # raise either.
1956        c_in.write(s_out.read())
1957        client.unwrap()
1958
1959class SimpleBackgroundTests(unittest.TestCase):
1960    """Tests that connect to a simple server running in the background"""
1961
1962    def setUp(self):
1963        server = ThreadedEchoServer(SIGNED_CERTFILE)
1964        self.server_addr = (HOST, server.port)
1965        server.__enter__()
1966        self.addCleanup(server.__exit__, None, None, None)
1967
1968    def test_connect(self):
1969        with test_wrap_socket(socket.socket(socket.AF_INET),
1970                            cert_reqs=ssl.CERT_NONE) as s:
1971            s.connect(self.server_addr)
1972            self.assertEqual({}, s.getpeercert())
1973            self.assertFalse(s.server_side)
1974
1975        # this should succeed because we specify the root cert
1976        with test_wrap_socket(socket.socket(socket.AF_INET),
1977                            cert_reqs=ssl.CERT_REQUIRED,
1978                            ca_certs=SIGNING_CA) as s:
1979            s.connect(self.server_addr)
1980            self.assertTrue(s.getpeercert())
1981            self.assertFalse(s.server_side)
1982
1983    def test_connect_fail(self):
1984        # This should fail because we have no verification certs. Connection
1985        # failure crashes ThreadedEchoServer, so run this in an independent
1986        # test method.
1987        s = test_wrap_socket(socket.socket(socket.AF_INET),
1988                            cert_reqs=ssl.CERT_REQUIRED)
1989        self.addCleanup(s.close)
1990        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
1991                               s.connect, self.server_addr)
1992
1993    def test_connect_ex(self):
1994        # Issue #11326: check connect_ex() implementation
1995        s = test_wrap_socket(socket.socket(socket.AF_INET),
1996                            cert_reqs=ssl.CERT_REQUIRED,
1997                            ca_certs=SIGNING_CA)
1998        self.addCleanup(s.close)
1999        self.assertEqual(0, s.connect_ex(self.server_addr))
2000        self.assertTrue(s.getpeercert())
2001
2002    def test_non_blocking_connect_ex(self):
2003        # Issue #11326: non-blocking connect_ex() should allow handshake
2004        # to proceed after the socket gets ready.
2005        s = test_wrap_socket(socket.socket(socket.AF_INET),
2006                            cert_reqs=ssl.CERT_REQUIRED,
2007                            ca_certs=SIGNING_CA,
2008                            do_handshake_on_connect=False)
2009        self.addCleanup(s.close)
2010        s.setblocking(False)
2011        rc = s.connect_ex(self.server_addr)
2012        # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
2013        self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
2014        # Wait for connect to finish
2015        select.select([], [s], [], 5.0)
2016        # Non-blocking handshake
2017        while True:
2018            try:
2019                s.do_handshake()
2020                break
2021            except ssl.SSLWantReadError:
2022                select.select([s], [], [], 5.0)
2023            except ssl.SSLWantWriteError:
2024                select.select([], [s], [], 5.0)
2025        # SSL established
2026        self.assertTrue(s.getpeercert())
2027
    def test_connect_with_context(self):
        # Same as test_connect, but with a separately created context
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
        # A fresh PROTOCOL_TLS context uses CERT_NONE, so the connection
        # succeeds and getpeercert() returns an empty dict.
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            self.assertEqual({}, s.getpeercert())
        # Same with a server hostname (check_hostname is off by default for
        # PROTOCOL_TLS, so an arbitrary name is accepted)
        with ctx.wrap_socket(socket.socket(socket.AF_INET),
                            server_hostname="dummy") as s:
            s.connect(self.server_addr)
        ctx.verify_mode = ssl.CERT_REQUIRED
        # This should succeed because we specify the root cert
        ctx.load_verify_locations(SIGNING_CA)
        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
            s.connect(self.server_addr)
            cert = s.getpeercert()
            self.assertTrue(cert)
2045
2046    def test_connect_with_context_fail(self):
2047        # This should fail because we have no verification certs. Connection
2048        # failure crashes ThreadedEchoServer, so run this in an independent
2049        # test method.
2050        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2051        ctx.verify_mode = ssl.CERT_REQUIRED
2052        s = ctx.wrap_socket(socket.socket(socket.AF_INET))
2053        self.addCleanup(s.close)
2054        self.assertRaisesRegex(ssl.SSLError, "certificate verify failed",
2055                                s.connect, self.server_addr)
2056
2057    def test_connect_capath(self):
2058        # Verify server certificates using the `capath` argument
2059        # NOTE: the subject hashing algorithm has been changed between
2060        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
2061        # contain both versions of each certificate (same content, different
2062        # filename) for this test to be portable across OpenSSL releases.
2063        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2064        ctx.verify_mode = ssl.CERT_REQUIRED
2065        ctx.load_verify_locations(capath=CAPATH)
2066        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2067            s.connect(self.server_addr)
2068            cert = s.getpeercert()
2069            self.assertTrue(cert)
2070
2071        # Same with a bytes `capath` argument
2072        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2073        ctx.verify_mode = ssl.CERT_REQUIRED
2074        ctx.load_verify_locations(capath=BYTES_CAPATH)
2075        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2076            s.connect(self.server_addr)
2077            cert = s.getpeercert()
2078            self.assertTrue(cert)
2079
2080    def test_connect_cadata(self):
2081        with open(SIGNING_CA) as f:
2082            pem = f.read()
2083        der = ssl.PEM_cert_to_DER_cert(pem)
2084        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2085        ctx.verify_mode = ssl.CERT_REQUIRED
2086        ctx.load_verify_locations(cadata=pem)
2087        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2088            s.connect(self.server_addr)
2089            cert = s.getpeercert()
2090            self.assertTrue(cert)
2091
2092        # same with DER
2093        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2094        ctx.verify_mode = ssl.CERT_REQUIRED
2095        ctx.load_verify_locations(cadata=der)
2096        with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
2097            s.connect(self.server_addr)
2098            cert = s.getpeercert()
2099            self.assertTrue(cert)
2100
    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        ss = test_wrap_socket(socket.socket(socket.AF_INET))
        ss.connect(self.server_addr)
        fd = ss.fileno()
        f = ss.makefile()
        f.close()
        # The fd is still open (a zero-byte read succeeds)
        os.read(fd, 0)
        # Closing the SSL socket should close the fd too
        ss.close()
        # Collect garbage so no lingering reference could be what closes the
        # fd; presumably guards against a false pass — confirm if changed.
        gc.collect()
        # The fd must now be invalid
        with self.assertRaises(OSError) as e:
            os.read(fd, 0)
        self.assertEqual(e.exception.errno, errno.EBADF)
2119
2120    def test_non_blocking_handshake(self):
2121        s = socket.socket(socket.AF_INET)
2122        s.connect(self.server_addr)
2123        s.setblocking(False)
2124        s = test_wrap_socket(s,
2125                            cert_reqs=ssl.CERT_NONE,
2126                            do_handshake_on_connect=False)
2127        self.addCleanup(s.close)
2128        count = 0
2129        while True:
2130            try:
2131                count += 1
2132                s.do_handshake()
2133                break
2134            except ssl.SSLWantReadError:
2135                select.select([s], [], [])
2136            except ssl.SSLWantWriteError:
2137                select.select([], [s], [])
2138        if support.verbose:
2139            sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
2140
2141    def test_get_server_certificate(self):
2142        _test_get_server_certificate(self, *self.server_addr, cert=SIGNING_CA)
2143
2144    def test_get_server_certificate_fail(self):
2145        # Connection failure crashes ThreadedEchoServer, so run this in an
2146        # independent test method
2147        _test_get_server_certificate_fail(self, *self.server_addr)
2148
2149    def test_ciphers(self):
2150        with test_wrap_socket(socket.socket(socket.AF_INET),
2151                             cert_reqs=ssl.CERT_NONE, ciphers="ALL") as s:
2152            s.connect(self.server_addr)
2153        with test_wrap_socket(socket.socket(socket.AF_INET),
2154                             cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT") as s:
2155            s.connect(self.server_addr)
2156        # Error checking can happen at instantiation or when connecting
2157        with self.assertRaisesRegex(ssl.SSLError, "No cipher can be selected"):
2158            with socket.socket(socket.AF_INET) as sock:
2159                s = test_wrap_socket(sock,
2160                                    cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
2161                s.connect(self.server_addr)
2162
2163    def test_get_ca_certs_capath(self):
2164        # capath certs are loaded on request
2165        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2166        ctx.load_verify_locations(capath=CAPATH)
2167        self.assertEqual(ctx.get_ca_certs(), [])
2168        with ctx.wrap_socket(socket.socket(socket.AF_INET),
2169                             server_hostname='localhost') as s:
2170            s.connect(self.server_addr)
2171            cert = s.getpeercert()
2172            self.assertTrue(cert)
2173        self.assertEqual(len(ctx.get_ca_certs()), 1)
2174
2175    @needs_sni
2176    def test_context_setget(self):
2177        # Check that the context of a connected socket can be replaced.
2178        ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2179        ctx1.load_verify_locations(capath=CAPATH)
2180        ctx2 = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2181        ctx2.load_verify_locations(capath=CAPATH)
2182        s = socket.socket(socket.AF_INET)
2183        with ctx1.wrap_socket(s, server_hostname='localhost') as ss:
2184            ss.connect(self.server_addr)
2185            self.assertIs(ss.context, ctx1)
2186            self.assertIs(ss._sslobj.context, ctx1)
2187            ss.context = ctx2
2188            self.assertIs(ss.context, ctx2)
2189            self.assertIs(ss._sslobj.context, ctx2)
2190
2191    def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
2192        # A simple IO loop. Call func(*args) depending on the error we get
2193        # (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
2194        timeout = kwargs.get('timeout', 10)
2195        deadline = time.monotonic() + timeout
2196        count = 0
2197        while True:
2198            if time.monotonic() > deadline:
2199                self.fail("timeout")
2200            errno = None
2201            count += 1
2202            try:
2203                ret = func(*args)
2204            except ssl.SSLError as e:
2205                if e.errno not in (ssl.SSL_ERROR_WANT_READ,
2206                                   ssl.SSL_ERROR_WANT_WRITE):
2207                    raise
2208                errno = e.errno
2209            # Get any data from the outgoing BIO irrespective of any error, and
2210            # send it to the socket.
2211            buf = outgoing.read()
2212            sock.sendall(buf)
2213            # If there's no error, we're done. For WANT_READ, we need to get
2214            # data from the socket and put it in the incoming BIO.
2215            if errno is None:
2216                break
2217            elif errno == ssl.SSL_ERROR_WANT_READ:
2218                buf = sock.recv(32768)
2219                if buf:
2220                    incoming.write(buf)
2221                else:
2222                    incoming.write_eof()
2223        if support.verbose:
2224            sys.stdout.write("Needed %d calls to complete %s().\n"
2225                             % (count, func.__name__))
2226        return ret
2227
2228    def test_bio_handshake(self):
2229        sock = socket.socket(socket.AF_INET)
2230        self.addCleanup(sock.close)
2231        sock.connect(self.server_addr)
2232        incoming = ssl.MemoryBIO()
2233        outgoing = ssl.MemoryBIO()
2234        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
2235        self.assertTrue(ctx.check_hostname)
2236        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
2237        ctx.load_verify_locations(SIGNING_CA)
2238        sslobj = ctx.wrap_bio(incoming, outgoing, False,
2239                              SIGNED_CERTFILE_HOSTNAME)
2240        self.assertIs(sslobj._sslobj.owner, sslobj)
2241        self.assertIsNone(sslobj.cipher())
2242        self.assertIsNone(sslobj.version())
2243        self.assertIsNotNone(sslobj.shared_ciphers())
2244        self.assertRaises(ValueError, sslobj.getpeercert)
2245        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2246            self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
2247        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2248        self.assertTrue(sslobj.cipher())
2249        self.assertIsNotNone(sslobj.shared_ciphers())
2250        self.assertIsNotNone(sslobj.version())
2251        self.assertTrue(sslobj.getpeercert())
2252        if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
2253            self.assertTrue(sslobj.get_channel_binding('tls-unique'))
2254        try:
2255            self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2256        except ssl.SSLSyscallError:
2257            # If the server shuts down the TCP connection without sending a
2258            # secure shutdown message, this is reported as SSL_ERROR_SYSCALL
2259            pass
2260        self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
2261
2262    def test_bio_read_write_data(self):
2263        sock = socket.socket(socket.AF_INET)
2264        self.addCleanup(sock.close)
2265        sock.connect(self.server_addr)
2266        incoming = ssl.MemoryBIO()
2267        outgoing = ssl.MemoryBIO()
2268        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
2269        ctx.verify_mode = ssl.CERT_NONE
2270        sslobj = ctx.wrap_bio(incoming, outgoing, False)
2271        self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
2272        req = b'FOO\n'
2273        self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
2274        buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
2275        self.assertEqual(buf, b'foo\n')
2276        self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
2277
2278
class NetworkedTests(unittest.TestCase):
    """Tests that need access to real hosts on the Internet."""

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with support.transient_internet(REMOTE_HOST):
            conn = test_wrap_socket(socket.socket(socket.AF_INET),
                                    cert_reqs=ssl.CERT_REQUIRED,
                                    do_handshake_on_connect=False)
            self.addCleanup(conn.close)
            # A tiny timeout so that connect_ex() is expected to time out;
            # a host that still answers in time is skipped below.
            conn.settimeout(0.0000001)
            result = conn.connect_ex((REMOTE_HOST, 443))
            if result == 0:
                self.skipTest("REMOTE_HOST responded too quickly")
            if result == errno.ENETUNREACH:
                self.skipTest("Network unreachable.")
            self.assertIn(result, (errno.EAGAIN, errno.EWOULDBLOCK))

    @unittest.skipUnless(support.IPV6_ENABLED, 'Needs IPv6')
    def test_get_server_certificate_ipv6(self):
        remote = 'ipv6.google.com'
        with support.transient_internet(remote):
            _test_get_server_certificate(self, remote, 443)
            _test_get_server_certificate_fail(self, remote, 443)
2302
2303
def _test_get_server_certificate(test, host, port, cert=None):
    """Check that ssl.get_server_certificate() returns a certificate for
    (host, port), first without a CA bundle and then with *cert* as
    ``ca_certs``.  Fails *test* if either call returns nothing."""
    address = (host, port)
    for ca_bundle in (None, cert):
        pem = ssl.get_server_certificate(address, ca_certs=ca_bundle)
        if not pem:
            test.fail("No server certificate on %s:%s!" % (host, port))
    if support.verbose:
        sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))
2314
def _test_get_server_certificate_fail(test, host, port):
    """Assert that fetching the certificate of (host, port) while verifying
    it against CERTFILE raises ssl.SSLError; fail *test* otherwise."""
    try:
        pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE)
    except ssl.SSLError as exc:
        # Expected outcome: verification against CERTFILE must not succeed.
        if support.verbose:
            sys.stdout.write("%s\n" % exc)
        return
    test.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
2324
2325
2326from test.ssl_servers import make_https_server
2327
class ThreadedEchoServer(threading.Thread):
    """Echo server running in a background thread, used by the threaded tests.

    Connections are served one at a time: each accepted connection is handled
    by a ConnectionHandler thread that is joined before the next accept().
    The handler echoes every message back lower-cased.  It also understands
    a few in-band commands (``over``, ``STARTTLS``, ``ENDTLS``,
    ``CB tls-unique``, ``PHA``, ``HASCERT`` and ``GETCERT``).  Negotiation
    results and connection errors are collected in
    ``selected_npn_protocols``, ``selected_alpn_protocols``,
    ``shared_ciphers`` and ``conn_errors``.

    Use it as a context manager: ``__enter__`` starts the thread and waits
    until the socket is listening; ``__exit__`` stops and joins it.
    """

    class ConnectionHandler(threading.Thread):

        """A mildly complicated class, because we want it to work both
        with and without the SSL wrapper around the socket connection, so
        that we can test the STARTTLS functionality."""

        def __init__(self, server, connsock, addr):
            self.server = server
            self.running = False
            self.sock = connsock
            self.addr = addr
            self.sock.setblocking(1)
            self.sslconn = None
            threading.Thread.__init__(self)
            self.daemon = True

        def wrap_conn(self):
            """Wrap the plain connection in SSL using the server's context.

            Returns True on success.  On failure the error text is appended
            to ``server.conn_errors`` and False is returned; depending on the
            error, the connection is closed and the server is stopped.
            """
            try:
                self.sslconn = self.server.context.wrap_socket(
                    self.sock, server_side=True)
                self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
            except (ConnectionResetError, BrokenPipeError, ConnectionAbortedError) as e:
                # We treat ConnectionResetError as though it were an
                # SSLError - OpenSSL on Ubuntu abruptly closes the
                # connection when asked to use an unsupported protocol.
                #
                # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake.
                # https://github.com/openssl/openssl/issues/6342
                #
                # ConnectionAbortedError is raised in TLS 1.3 mode, when OpenSSL
                # tries to send session tickets after handshake when using WinSock.
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                self.running = False
                self.close()
                return False
            except (ssl.SSLError, OSError) as e:
                # OSError may occur with wrong protocols, e.g. both
                # sides use PROTOCOL_TLS_SERVER.
                #
                # XXX Various errors can have happened here, for example
                # a mismatching protocol version, an invalid certificate,
                # or a low-level bug. This should be made more discriminating.
                #
                # bpo-31323: Store the exception as string to prevent
                # a reference leak: server -> conn_errors -> exception
                # -> traceback -> self (ConnectionHandler) -> server
                self.server.conn_errors.append(str(e))
                if self.server.chatty:
                    handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")

                # bpo-44229, bpo-43855, bpo-44237, and bpo-33450:
                # Ignore spurious EPROTOTYPE returned by write() on macOS.
                # See also http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
                # Only that exact combination is ignored; every other error
                # (on every platform) closes the connection and stops the server.
                spurious_macos_error = (e.errno == errno.EPROTOTYPE
                                        and sys.platform == "darwin")
                if not spurious_macos_error:
                    self.running = False
                    self.server.stop()
                    self.close()
                return False
            else:
                self.server.shared_ciphers.append(self.sslconn.shared_ciphers())
                if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                    cert = self.sslconn.getpeercert()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                    cert_binary = self.sslconn.getpeercert(True)
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                cipher = self.sslconn.cipher()
                if support.verbose and self.server.chatty:
                    sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                    sys.stdout.write(" server: selected protocol is now "
                            + str(self.sslconn.selected_npn_protocol()) + "\n")
                return True

        def read(self):
            """Read from the SSL connection if wrapped, else from the raw socket."""
            if self.sslconn:
                return self.sslconn.read()
            else:
                return self.sock.recv(1024)

        def write(self, bytes):
            """Write to the SSL connection if wrapped, else to the raw socket."""
            if self.sslconn:
                return self.sslconn.write(bytes)
            else:
                return self.sock.send(bytes)

        def close(self):
            """Close the SSL connection if wrapped, else the raw socket."""
            if self.sslconn:
                self.sslconn.close()
            else:
                self.sock.close()

        def run(self):
            """Serve one connection: wrap it (unless in STARTTLS mode), then
            process messages until EOF, ``over`` or an error."""
            self.running = True
            if not self.server.starttls_server:
                if not self.wrap_conn():
                    return
            while self.running:
                try:
                    msg = self.read()
                    stripped = msg.strip()
                    if not stripped:
                        # eof, so quit this handler
                        self.running = False
                        # In STARTTLS mode the connection may still (or
                        # again, after ENDTLS) be unencrypted; there is no
                        # SSL layer to shut down then.
                        if self.sslconn is not None:
                            try:
                                self.sock = self.sslconn.unwrap()
                            except OSError:
                                # Many tests shut the TCP connection down
                                # without an SSL shutdown. This causes
                                # unwrap() to raise OSError with errno=0!
                                pass
                            else:
                                self.sslconn = None
                        self.close()
                    elif stripped == b'over':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: client closed connection\n")
                        self.close()
                        return
                    elif (self.server.starttls_server and
                          stripped == b'STARTTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        if not self.wrap_conn():
                            return
                    elif (self.server.starttls_server and self.sslconn
                          and stripped == b'ENDTLS'):
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                        self.write(b"OK\n")
                        self.sock = self.sslconn.unwrap()
                        self.sslconn = None
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: connection is now unencrypted...\n")
                    elif stripped == b'CB tls-unique':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                        data = self.sslconn.get_channel_binding("tls-unique")
                        self.write(repr(data).encode("us-ascii") + b"\n")
                    elif stripped == b'PHA':
                        if support.verbose and self.server.connectionchatty:
                            sys.stdout.write(" server: initiating post handshake auth\n")
                        try:
                            self.sslconn.verify_client_post_handshake()
                        except ssl.SSLError as e:
                            self.write(repr(e).encode("us-ascii") + b"\n")
                        else:
                            self.write(b"OK\n")
                    elif stripped == b'HASCERT':
                        if self.sslconn.getpeercert() is not None:
                            self.write(b'TRUE\n')
                        else:
                            self.write(b'FALSE\n')
                    elif stripped == b'GETCERT':
                        cert = self.sslconn.getpeercert()
                        self.write(repr(cert).encode("us-ascii") + b"\n")
                    else:
                        if (support.verbose and
                            self.server.connectionchatty):
                            ctype = (self.sslconn and "encrypted") or "unencrypted"
                            sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                             % (msg, ctype, msg.lower(), ctype))
                        self.write(msg.lower())
                except (ConnectionResetError, ConnectionAbortedError):
                    # XXX: OpenSSL 1.1.1 sometimes raises ConnectionResetError
                    # when connection is not shut down gracefully.
                    if self.server.chatty and support.verbose:
                        sys.stdout.write(
                            " Connection reset by peer: {}\n".format(
                                self.addr)
                        )
                    self.close()
                    self.running = False
                except ssl.SSLError as err:
                    # On Windows sometimes test_pha_required_nocert receives the
                    # PEER_DID_NOT_RETURN_A_CERTIFICATE exception
                    # before the 'tlsv13 alert certificate required' exception.
                    # If the server is stopped when PEER_DID_NOT_RETURN_A_CERTIFICATE
                    # is received test_pha_required_nocert fails with ConnectionResetError
                    # because the underlying socket is closed
                    if 'PEER_DID_NOT_RETURN_A_CERTIFICATE' == err.reason:
                        if self.server.chatty and support.verbose:
                            sys.stdout.write(err.args[1])
                        # test_pha_required_nocert is expecting this exception
                        raise ssl.SSLError('tlsv13 alert certificate required')
                except OSError:
                    if self.server.chatty:
                        handle_error("Test server failure:\n")
                    self.close()
                    self.running = False

                    # normally, we'd just stop here, but for the test
                    # harness, we want to stop the server
                    self.server.stop()

    def __init__(self, certificate=None, ssl_version=None,
                 certreqs=None, cacerts=None,
                 chatty=True, connectionchatty=False, starttls_server=False,
                 npn_protocols=None, alpn_protocols=None,
                 ciphers=None, context=None):
        """Create (but do not start) the server.

        If *context* is given, it is used as-is and the other TLS-related
        arguments are ignored.  Otherwise a new SSLContext is built from
        *ssl_version* (default PROTOCOL_TLS_SERVER), *certreqs* (default
        CERT_NONE), *cacerts*, *certificate*, *npn_protocols*,
        *alpn_protocols* and *ciphers*.  The listening socket is bound
        immediately via support.bind_port(); the port is stored in
        ``self.port``.
        """
        if context:
            self.context = context
        else:
            self.context = ssl.SSLContext(ssl_version
                                          if ssl_version is not None
                                          else ssl.PROTOCOL_TLS_SERVER)
            self.context.verify_mode = (certreqs if certreqs is not None
                                        else ssl.CERT_NONE)
            if cacerts:
                self.context.load_verify_locations(cacerts)
            if certificate:
                self.context.load_cert_chain(certificate)
            if npn_protocols:
                self.context.set_npn_protocols(npn_protocols)
            if alpn_protocols:
                self.context.set_alpn_protocols(alpn_protocols)
            if ciphers:
                self.context.set_ciphers(ciphers)
        self.chatty = chatty
        self.connectionchatty = connectionchatty
        self.starttls_server = starttls_server
        self.sock = socket.socket()
        self.port = support.bind_port(self.sock)
        self.flag = None
        self.active = False
        self.selected_npn_protocols = []
        self.selected_alpn_protocols = []
        self.shared_ciphers = []
        self.conn_errors = []
        threading.Thread.__init__(self)
        self.daemon = True

    def __enter__(self):
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        self.stop()
        self.join()

    def start(self, flag=None):
        """Start the thread; *flag* (an Event) is set once it is listening."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        # The short accept() timeout makes the loop re-check self.active
        # regularly so that stop() takes effect promptly.
        self.sock.settimeout(0.05)
        self.sock.listen()
        self.active = True
        if self.flag:
            # signal an event
            self.flag.set()
        while self.active:
            try:
                newconn, connaddr = self.sock.accept()
                if support.verbose and self.chatty:
                    sys.stdout.write(' server:  new connection from '
                                     + repr(connaddr) + '\n')
                handler = self.ConnectionHandler(self, newconn, connaddr)
                handler.start()
                handler.join()
            except socket.timeout:
                pass
            except KeyboardInterrupt:
                self.stop()
            except BaseException as e:
                if support.verbose and self.chatty:
                    sys.stdout.write(
                        ' connection handling failed: ' + repr(e) + '\n')

        self.sock.close()

    def stop(self):
        """Ask the accept loop to exit (checked after each accept timeout)."""
        self.active = False
2609
class AsyncoreEchoServer(threading.Thread):
    """SSL echo server built on asyncore, driven from a background thread.

    Every accepted connection is wrapped server-side using *certfile*.  The
    handshake is completed without blocking, and each chunk received
    afterwards is sent back lower-cased.  Use as a context manager:
    ``__enter__`` starts the thread and waits until it is running;
    ``__exit__`` stops it, joins it and closes all remaining asyncore
    channels.
    """

    # this one's based on asyncore.dispatcher

    class EchoServer(asyncore.dispatcher):
        """Listening dispatcher that creates a ConnectionHandler per client."""

        class ConnectionHandler(asyncore.dispatcher_with_send):
            """Per-connection channel: finishes the SSL handshake, then echoes."""

            def __init__(self, conn, certfile):
                self.socket = test_wrap_socket(conn, server_side=True,
                                              certfile=certfile,
                                              do_handshake_on_connect=False)
                asyncore.dispatcher_with_send.__init__(self, self.socket)
                self._ssl_accepting = True
                self._do_ssl_handshake()

            def readable(self):
                # Data already decrypted and buffered by the SSL layer is not
                # visible to select(), so drain it here before reporting.
                if isinstance(self.socket, ssl.SSLSocket):
                    while self.socket.pending() > 0:
                        self.handle_read_event()
                return True

            def _do_ssl_handshake(self):
                """Advance the non-blocking handshake by one step."""
                try:
                    self.socket.do_handshake()
                except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
                    # Not finished yet; retried on the next read event.
                    return
                except ssl.SSLEOFError:
                    return self.handle_close()
                except ssl.SSLError:
                    # SSLError subclasses OSError: re-raise it before the
                    # generic OSError branch below can swallow it.
                    raise
                except OSError as err:
                    # Other OSErrors leave _ssl_accepting set, so the
                    # handshake is attempted again on the next read event.
                    if err.args[0] == errno.ECONNABORTED:
                        return self.handle_close()
                else:
                    self._ssl_accepting = False

            def handle_read(self):
                if self._ssl_accepting:
                    self._do_ssl_handshake()
                else:
                    data = self.recv(1024)
                    if support.verbose:
                        sys.stdout.write(" server:  read %s from client\n" % repr(data))
                    if not data:
                        self.close()
                    else:
                        self.send(data.lower())

            def handle_close(self):
                self.close()
                if support.verbose:
                    sys.stdout.write(" server:  closed connection %s\n" % self.socket)

            def handle_error(self):
                # Re-raise the exception currently being handled instead of
                # asyncore's default logging.
                raise

        def __init__(self, certfile):
            self.certfile = certfile
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.port = support.bind_port(sock, '')
            asyncore.dispatcher.__init__(self, sock)
            self.listen(5)

        def handle_accepted(self, sock_obj, addr):
            if support.verbose:
                sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
            self.ConnectionHandler(sock_obj, self.certfile)

        def handle_error(self):
            raise

    def __init__(self, certfile):
        self.flag = None
        self.active = False
        self.server = self.EchoServer(certfile)
        self.port = self.server.port
        threading.Thread.__init__(self)
        self.daemon = True

    def __str__(self):
        return "<%s %s>" % (self.__class__.__name__, self.server)

    def __enter__(self):
        self.start(threading.Event())
        self.flag.wait()
        return self

    def __exit__(self, *args):
        if support.verbose:
            sys.stdout.write(" cleanup: stopping server.\n")
        self.stop()
        if support.verbose:
            sys.stdout.write(" cleanup: joining server thread.\n")
        self.join()
        if support.verbose:
            sys.stdout.write(" cleanup: successfully joined.\n")
        # make sure that ConnectionHandler is removed from socket_map
        asyncore.close_all(ignore_all=True)

    def start(self, flag=None):
        """Start the thread; *flag* (an Event) is set once it is running."""
        self.flag = flag
        threading.Thread.start(self)

    def run(self):
        self.active = True
        if self.flag:
            self.flag.set()
        while self.active:
            try:
                asyncore.loop(1)
            except Exception:
                # Errors escaping the asyncore loop (the handle_error()
                # methods re-raise them) are deliberately ignored; the loop
                # keeps going until stop() clears self.active.  Unlike a bare
                # ``except:``, this no longer swallows SystemExit and
                # KeyboardInterrupt.
                pass

    def stop(self):
        self.active = False
        self.server.close()
2727
def server_params_test(client_context, server_context, indata=b"FOO\n",
                       chatty=True, connectionchatty=False, sni_name=None,
                       session=None):
    """
    Launch a server, connect a client to it and try various reads
    and writes.

    *indata* is sent three times (as bytes, bytearray and memoryview), and
    each reply must equal ``indata.lower()``; AssertionError is raised
    otherwise.  *connectionchatty* only controls client-side progress output
    (and only when running verbosely); the server itself is always created
    with ``connectionchatty=False``.

    Returns a dict with the client's view of the connection (compression,
    cipher, peer certificate, ALPN/NPN protocol, TLS version, session and
    whether it was reused), plus the server-side lists of selected ALPN/NPN
    protocols and shared ciphers.
    """
    def trace(message):
        # Client-side progress output, only in chatty + verbose mode.
        if connectionchatty and support.verbose:
            sys.stdout.write(message)

    server = ThreadedEchoServer(context=server_context,
                                chatty=chatty,
                                connectionchatty=False)
    with server:
        raw_sock = socket.socket()
        with client_context.wrap_socket(raw_sock, server_hostname=sni_name,
                                        session=session) as s:
            s.connect((HOST, server.port))
            expected = indata.lower()
            for payload in (indata, bytearray(indata), memoryview(indata)):
                trace(" client:  sending %r...\n" % indata)
                s.write(payload)
                outdata = s.read()
                trace(" client:  read %r\n" % outdata)
                if outdata != expected:
                    raise AssertionError(
                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
                        % (outdata[:20], len(outdata),
                           indata[:20].lower(), len(indata)))
            s.write(b"over\n")
            trace(" client:  closing connection.\n")
            stats = {
                'compression': s.compression(),
                'cipher': s.cipher(),
                'peercert': s.getpeercert(),
                'client_alpn_protocol': s.selected_alpn_protocol(),
                'client_npn_protocol': s.selected_npn_protocol(),
                'version': s.version(),
                'session_reused': s.session_reused,
                'session': s.session,
            }
            s.close()
        stats['server_alpn_protocols'] = server.selected_alpn_protocols
        stats['server_npn_protocols'] = server.selected_npn_protocols
        stats['server_shared_ciphers'] = server.shared_ciphers
    return stats
2777
def try_protocol_combo(server_protocol, client_protocol, expect_success,
                       certsreqs=None, server_options=0, client_options=0):
    """
    Try to SSL-connect using *client_protocol* to *server_protocol*.
    If *expect_success* is true, assert that the connection succeeds,
    if it's false, assert that the connection fails.
    Also, if *expect_success* is a string, assert that it is the protocol
    version actually used by the connection.

    *certsreqs* (default CERT_NONE) is applied as verify_mode on both sides;
    *server_options* / *client_options* are OR-ed into the contexts' options.
    """
    if certsreqs is None:
        certsreqs = ssl.CERT_NONE
    certtype = {
        ssl.CERT_NONE: "CERT_NONE",
        ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
        ssl.CERT_REQUIRED: "CERT_REQUIRED",
    }[certsreqs]
    if support.verbose:
        formatstr = " %s->%s %s\n" if expect_success else " {%s->%s} %s\n"
        client_name = ssl.get_protocol_name(client_protocol)
        server_name = ssl.get_protocol_name(server_protocol)
        sys.stdout.write(formatstr % (client_name, server_name, certtype))

    client_context = ssl.SSLContext(client_protocol)
    client_context.options |= client_options
    server_context = ssl.SSLContext(server_protocol)
    server_context.options |= server_options

    # If the OpenSSL configuration is strict and requires a more recent TLS
    # version, lower the server's minimum so old TLS versions can be tested.
    # SSLContext.minimum_version is only available on recent OpenSSL
    # (setter added in OpenSSL 1.1.0, getter added in OpenSSL 1.1.1).
    min_version = PROTOCOL_TO_TLS_VERSION.get(client_protocol, None)
    must_lower_minimum = (
        min_version is not None
        and hasattr(server_context, 'minimum_version')
        and server_protocol == ssl.PROTOCOL_TLS
        and server_context.minimum_version > min_version
    )
    if must_lower_minimum:
        server_context.minimum_version = min_version

    # NOTE: we must enable "ALL" ciphers on the client, otherwise an
    # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
    # starting from OpenSSL 1.0.0 (see issue #8322).
    if client_context.protocol == ssl.PROTOCOL_TLS:
        client_context.set_ciphers("ALL")

    seclevel_workaround(server_context, client_context)

    for context in (client_context, server_context):
        context.verify_mode = certsreqs
        context.load_cert_chain(SIGNED_CERTFILE)
        context.load_verify_locations(SIGNING_CA)
    try:
        stats = server_params_test(client_context, server_context,
                                   chatty=False, connectionchatty=False)
    # Protocol mismatch can result in either an SSLError, or a
    # "Connection reset by peer" error.
    except ssl.SSLError:
        if expect_success:
            raise
    except OSError as e:
        if expect_success or e.errno != errno.ECONNRESET:
            raise
    else:
        if not expect_success:
            raise AssertionError(
                "Client protocol %s succeeded with server protocol %s!"
                % (ssl.get_protocol_name(client_protocol),
                   ssl.get_protocol_name(server_protocol)))
        negotiated = stats['version']
        if expect_success is not True and expect_success != negotiated:
            raise AssertionError("version mismatch: expected %r, got %r"
                                 % (expect_success, negotiated))
2849
2850
2851class ThreadedTests(unittest.TestCase):
2852
2853    def test_echo(self):
2854        """Basic test of an SSL client connecting to a server"""
2855        if support.verbose:
2856            sys.stdout.write("\n")
2857        for protocol in PROTOCOLS:
2858            if protocol in {ssl.PROTOCOL_TLS_CLIENT, ssl.PROTOCOL_TLS_SERVER}:
2859                continue
2860            if not has_tls_protocol(protocol):
2861                continue
2862            with self.subTest(protocol=ssl._PROTOCOL_NAMES[protocol]):
2863                context = ssl.SSLContext(protocol)
2864                context.load_cert_chain(CERTFILE)
2865                seclevel_workaround(context)
2866                server_params_test(context, context,
2867                                   chatty=True, connectionchatty=True)
2868
2869        client_context, server_context, hostname = testing_context()
2870
2871        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_SERVER):
2872            server_params_test(client_context=client_context,
2873                               server_context=server_context,
2874                               chatty=True, connectionchatty=True,
2875                               sni_name=hostname)
2876
2877        client_context.check_hostname = False
2878        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_CLIENT):
2879            with self.assertRaises(ssl.SSLError) as e:
2880                server_params_test(client_context=server_context,
2881                                   server_context=client_context,
2882                                   chatty=True, connectionchatty=True,
2883                                   sni_name=hostname)
2884            self.assertIn('called a function you should not call',
2885                          str(e.exception))
2886
2887        with self.subTest(client=ssl.PROTOCOL_TLS_SERVER, server=ssl.PROTOCOL_TLS_SERVER):
2888            with self.assertRaises(ssl.SSLError) as e:
2889                server_params_test(client_context=server_context,
2890                                   server_context=server_context,
2891                                   chatty=True, connectionchatty=True)
2892            self.assertIn('called a function you should not call',
2893                          str(e.exception))
2894
2895        with self.subTest(client=ssl.PROTOCOL_TLS_CLIENT, server=ssl.PROTOCOL_TLS_CLIENT):
2896            with self.assertRaises(ssl.SSLError) as e:
2897                server_params_test(client_context=server_context,
2898                                   server_context=client_context,
2899                                   chatty=True, connectionchatty=True)
2900            self.assertIn('called a function you should not call',
2901                          str(e.exception))
2902
2903    def test_getpeercert(self):
2904        if support.verbose:
2905            sys.stdout.write("\n")
2906
2907        client_context, server_context, hostname = testing_context()
2908        server = ThreadedEchoServer(context=server_context, chatty=False)
2909        with server:
2910            with client_context.wrap_socket(socket.socket(),
2911                                            do_handshake_on_connect=False,
2912                                            server_hostname=hostname) as s:
2913                s.connect((HOST, server.port))
2914                # getpeercert() raise ValueError while the handshake isn't
2915                # done.
2916                with self.assertRaises(ValueError):
2917                    s.getpeercert()
2918                s.do_handshake()
2919                cert = s.getpeercert()
2920                self.assertTrue(cert, "Can't get peer certificate.")
2921                cipher = s.cipher()
2922                if support.verbose:
2923                    sys.stdout.write(pprint.pformat(cert) + '\n')
2924                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2925                if 'subject' not in cert:
2926                    self.fail("No subject field in certificate: %s." %
2927                              pprint.pformat(cert))
2928                if ((('organizationName', 'Python Software Foundation'),)
2929                    not in cert['subject']):
2930                    self.fail(
2931                        "Missing or invalid 'organizationName' field in certificate subject; "
2932                        "should be 'Python Software Foundation'.")
2933                self.assertIn('notBefore', cert)
2934                self.assertIn('notAfter', cert)
2935                before = ssl.cert_time_to_seconds(cert['notBefore'])
2936                after = ssl.cert_time_to_seconds(cert['notAfter'])
2937                self.assertLess(before, after)
2938
2939    @unittest.skipUnless(have_verify_flags(),
2940                        "verify_flags need OpenSSL > 0.9.8")
2941    def test_crl_check(self):
2942        if support.verbose:
2943            sys.stdout.write("\n")
2944
2945        client_context, server_context, hostname = testing_context()
2946
2947        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2948        self.assertEqual(client_context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2949
2950        # VERIFY_DEFAULT should pass
2951        server = ThreadedEchoServer(context=server_context, chatty=True)
2952        with server:
2953            with client_context.wrap_socket(socket.socket(),
2954                                            server_hostname=hostname) as s:
2955                s.connect((HOST, server.port))
2956                cert = s.getpeercert()
2957                self.assertTrue(cert, "Can't get peer certificate.")
2958
2959        # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2960        client_context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2961
2962        server = ThreadedEchoServer(context=server_context, chatty=True)
2963        with server:
2964            with client_context.wrap_socket(socket.socket(),
2965                                            server_hostname=hostname) as s:
2966                with self.assertRaisesRegex(ssl.SSLError,
2967                                            "certificate verify failed"):
2968                    s.connect((HOST, server.port))
2969
2970        # now load a CRL file. The CRL file is signed by the CA.
2971        client_context.load_verify_locations(CRLFILE)
2972
2973        server = ThreadedEchoServer(context=server_context, chatty=True)
2974        with server:
2975            with client_context.wrap_socket(socket.socket(),
2976                                            server_hostname=hostname) as s:
2977                s.connect((HOST, server.port))
2978                cert = s.getpeercert()
2979                self.assertTrue(cert, "Can't get peer certificate.")
2980
2981    def test_check_hostname(self):
2982        if support.verbose:
2983            sys.stdout.write("\n")
2984
2985        client_context, server_context, hostname = testing_context()
2986
2987        # correct hostname should verify
2988        server = ThreadedEchoServer(context=server_context, chatty=True)
2989        with server:
2990            with client_context.wrap_socket(socket.socket(),
2991                                            server_hostname=hostname) as s:
2992                s.connect((HOST, server.port))
2993                cert = s.getpeercert()
2994                self.assertTrue(cert, "Can't get peer certificate.")
2995
2996        # incorrect hostname should raise an exception
2997        server = ThreadedEchoServer(context=server_context, chatty=True)
2998        with server:
2999            with client_context.wrap_socket(socket.socket(),
3000                                            server_hostname="invalid") as s:
3001                with self.assertRaisesRegex(
3002                        ssl.CertificateError,
3003                        "Hostname mismatch, certificate is not valid for 'invalid'."):
3004                    s.connect((HOST, server.port))
3005
3006        # missing server_hostname arg should cause an exception, too
3007        server = ThreadedEchoServer(context=server_context, chatty=True)
3008        with server:
3009            with socket.socket() as s:
3010                with self.assertRaisesRegex(ValueError,
3011                                            "check_hostname requires server_hostname"):
3012                    client_context.wrap_socket(s)
3013
3014    @unittest.skipUnless(
3015        ssl.HAS_NEVER_CHECK_COMMON_NAME, "test requires hostname_checks_common_name"
3016    )
3017    def test_hostname_checks_common_name(self):
3018        client_context, server_context, hostname = testing_context()
3019        assert client_context.hostname_checks_common_name
3020        client_context.hostname_checks_common_name = False
3021
3022        # default cert has a SAN
3023        server = ThreadedEchoServer(context=server_context, chatty=True)
3024        with server:
3025            with client_context.wrap_socket(socket.socket(),
3026                                            server_hostname=hostname) as s:
3027                s.connect((HOST, server.port))
3028
3029        client_context, server_context, hostname = testing_context(NOSANFILE)
3030        client_context.hostname_checks_common_name = False
3031        server = ThreadedEchoServer(context=server_context, chatty=True)
3032        with server:
3033            with client_context.wrap_socket(socket.socket(),
3034                                            server_hostname=hostname) as s:
3035                with self.assertRaises(ssl.SSLCertVerificationError):
3036                    s.connect((HOST, server.port))
3037
3038    def test_ecc_cert(self):
3039        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3040        client_context.load_verify_locations(SIGNING_CA)
3041        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3042        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3043
3044        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3045        # load ECC cert
3046        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3047
3048        # correct hostname should verify
3049        server = ThreadedEchoServer(context=server_context, chatty=True)
3050        with server:
3051            with client_context.wrap_socket(socket.socket(),
3052                                            server_hostname=hostname) as s:
3053                s.connect((HOST, server.port))
3054                cert = s.getpeercert()
3055                self.assertTrue(cert, "Can't get peer certificate.")
3056                cipher = s.cipher()[0].split('-')
3057                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3058
3059    def test_dual_rsa_ecc(self):
3060        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3061        client_context.load_verify_locations(SIGNING_CA)
3062        # TODO: fix TLSv1.3 once SSLContext can restrict signature
3063        #       algorithms.
3064        client_context.options |= ssl.OP_NO_TLSv1_3
3065        # only ECDSA certs
3066        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
3067        hostname = SIGNED_CERTFILE_ECC_HOSTNAME
3068
3069        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3070        # load ECC and RSA key/cert pairs
3071        server_context.load_cert_chain(SIGNED_CERTFILE_ECC)
3072        server_context.load_cert_chain(SIGNED_CERTFILE)
3073
3074        # correct hostname should verify
3075        server = ThreadedEchoServer(context=server_context, chatty=True)
3076        with server:
3077            with client_context.wrap_socket(socket.socket(),
3078                                            server_hostname=hostname) as s:
3079                s.connect((HOST, server.port))
3080                cert = s.getpeercert()
3081                self.assertTrue(cert, "Can't get peer certificate.")
3082                cipher = s.cipher()[0].split('-')
3083                self.assertTrue(cipher[:2], ('ECDHE', 'ECDSA'))
3084
3085    def test_check_hostname_idn(self):
3086        if support.verbose:
3087            sys.stdout.write("\n")
3088
3089        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3090        server_context.load_cert_chain(IDNSANSFILE)
3091
3092        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3093        context.verify_mode = ssl.CERT_REQUIRED
3094        context.check_hostname = True
3095        context.load_verify_locations(SIGNING_CA)
3096
3097        # correct hostname should verify, when specified in several
3098        # different ways
3099        idn_hostnames = [
3100            ('könig.idn.pythontest.net',
3101             'xn--knig-5qa.idn.pythontest.net'),
3102            ('xn--knig-5qa.idn.pythontest.net',
3103             'xn--knig-5qa.idn.pythontest.net'),
3104            (b'xn--knig-5qa.idn.pythontest.net',
3105             'xn--knig-5qa.idn.pythontest.net'),
3106
3107            ('königsgäßchen.idna2003.pythontest.net',
3108             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3109            ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3110             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3111            (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
3112             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
3113
3114            # ('königsgäßchen.idna2008.pythontest.net',
3115            #  'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3116            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3117             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3118            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
3119             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
3120
3121        ]
3122        for server_hostname, expected_hostname in idn_hostnames:
3123            server = ThreadedEchoServer(context=server_context, chatty=True)
3124            with server:
3125                with context.wrap_socket(socket.socket(),
3126                                         server_hostname=server_hostname) as s:
3127                    self.assertEqual(s.server_hostname, expected_hostname)
3128                    s.connect((HOST, server.port))
3129                    cert = s.getpeercert()
3130                    self.assertEqual(s.server_hostname, expected_hostname)
3131                    self.assertTrue(cert, "Can't get peer certificate.")
3132
3133        # incorrect hostname should raise an exception
3134        server = ThreadedEchoServer(context=server_context, chatty=True)
3135        with server:
3136            with context.wrap_socket(socket.socket(),
3137                                     server_hostname="python.example.org") as s:
3138                with self.assertRaises(ssl.CertificateError):
3139                    s.connect((HOST, server.port))
3140
3141    def test_wrong_cert_tls12(self):
3142        """Connecting when the server rejects the client's certificate
3143
3144        Launch a server with CERT_REQUIRED, and check that trying to
3145        connect to it with a wrong client certificate fails.
3146        """
3147        client_context, server_context, hostname = testing_context()
3148        # load client cert that is not signed by trusted CA
3149        client_context.load_cert_chain(CERTFILE)
3150        # require TLS client authentication
3151        server_context.verify_mode = ssl.CERT_REQUIRED
3152        # TLS 1.3 has different handshake
3153        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3154
3155        server = ThreadedEchoServer(
3156            context=server_context, chatty=True, connectionchatty=True,
3157        )
3158
3159        with server, \
3160                client_context.wrap_socket(socket.socket(),
3161                                           server_hostname=hostname) as s:
3162            try:
3163                # Expect either an SSL error about the server rejecting
3164                # the connection, or a low-level connection reset (which
3165                # sometimes happens on Windows)
3166                s.connect((HOST, server.port))
3167            except ssl.SSLError as e:
3168                if support.verbose:
3169                    sys.stdout.write("\nSSLError is %r\n" % e)
3170            except OSError as e:
3171                if e.errno != errno.ECONNRESET:
3172                    raise
3173                if support.verbose:
3174                    sys.stdout.write("\nsocket.error is %r\n" % e)
3175            else:
3176                self.fail("Use of invalid cert should have failed!")
3177
3178    @requires_tls_version('TLSv1_3')
3179    def test_wrong_cert_tls13(self):
3180        client_context, server_context, hostname = testing_context()
3181        # load client cert that is not signed by trusted CA
3182        client_context.load_cert_chain(CERTFILE)
3183        server_context.verify_mode = ssl.CERT_REQUIRED
3184        server_context.minimum_version = ssl.TLSVersion.TLSv1_3
3185        client_context.minimum_version = ssl.TLSVersion.TLSv1_3
3186
3187        server = ThreadedEchoServer(
3188            context=server_context, chatty=True, connectionchatty=True,
3189        )
3190        with server, \
3191             client_context.wrap_socket(socket.socket(),
3192                                        server_hostname=hostname) as s:
3193            # TLS 1.3 perform client cert exchange after handshake
3194            s.connect((HOST, server.port))
3195            try:
3196                s.write(b'data')
3197                s.read(4)
3198            except ssl.SSLError as e:
3199                if support.verbose:
3200                    sys.stdout.write("\nSSLError is %r\n" % e)
3201            except OSError as e:
3202                if e.errno != errno.ECONNRESET:
3203                    raise
3204                if support.verbose:
3205                    sys.stdout.write("\nsocket.error is %r\n" % e)
3206            else:
3207                self.fail("Use of invalid cert should have failed!")
3208
3209    def test_rude_shutdown(self):
3210        """A brutal shutdown of an SSL server should raise an OSError
3211        in the client when attempting handshake.
3212        """
3213        listener_ready = threading.Event()
3214        listener_gone = threading.Event()
3215
3216        s = socket.socket()
3217        port = support.bind_port(s, HOST)
3218
3219        # `listener` runs in a thread.  It sits in an accept() until
3220        # the main thread connects.  Then it rudely closes the socket,
3221        # and sets Event `listener_gone` to let the main thread know
3222        # the socket is gone.
3223        def listener():
3224            s.listen()
3225            listener_ready.set()
3226            newsock, addr = s.accept()
3227            newsock.close()
3228            s.close()
3229            listener_gone.set()
3230
3231        def connector():
3232            listener_ready.wait()
3233            with socket.socket() as c:
3234                c.connect((HOST, port))
3235                listener_gone.wait()
3236                try:
3237                    ssl_sock = test_wrap_socket(c)
3238                except OSError:
3239                    pass
3240                else:
3241                    self.fail('connecting to closed SSL socket should have failed')
3242
3243        t = threading.Thread(target=listener)
3244        t.start()
3245        try:
3246            connector()
3247        finally:
3248            t.join()
3249
3250    def test_ssl_cert_verify_error(self):
3251        if support.verbose:
3252            sys.stdout.write("\n")
3253
3254        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
3255        server_context.load_cert_chain(SIGNED_CERTFILE)
3256
3257        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3258
3259        server = ThreadedEchoServer(context=server_context, chatty=True)
3260        with server:
3261            with context.wrap_socket(socket.socket(),
3262                                     server_hostname=SIGNED_CERTFILE_HOSTNAME) as s:
3263                try:
3264                    s.connect((HOST, server.port))
3265                except ssl.SSLError as e:
3266                    msg = 'unable to get local issuer certificate'
3267                    self.assertIsInstance(e, ssl.SSLCertVerificationError)
3268                    self.assertEqual(e.verify_code, 20)
3269                    self.assertEqual(e.verify_message, msg)
3270                    self.assertIn(msg, repr(e))
3271                    self.assertIn('certificate verify failed', repr(e))
3272
3273    @requires_tls_version('SSLv2')
3274    def test_protocol_sslv2(self):
3275        """Connecting to an SSLv2 server with various client options"""
3276        if support.verbose:
3277            sys.stdout.write("\n")
3278        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
3279        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
3280        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
3281        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False)
3282        if has_tls_version('SSLv3'):
3283            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
3284        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
3285        # SSLv23 client with specific SSL options
3286        if no_sslv2_implies_sslv3_hello():
3287            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3288            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3289                               client_options=ssl.OP_NO_SSLv2)
3290        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3291                           client_options=ssl.OP_NO_SSLv3)
3292        try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLS, False,
3293                           client_options=ssl.OP_NO_TLSv1)
3294
3295    def test_PROTOCOL_TLS(self):
3296        """Connecting to an SSLv23 server with various client options"""
3297        if support.verbose:
3298            sys.stdout.write("\n")
3299        if has_tls_version('SSLv2'):
3300            try:
3301                try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv2, True)
3302            except OSError as x:
3303                # this fails on some older versions of OpenSSL (0.9.7l, for instance)
3304                if support.verbose:
3305                    sys.stdout.write(
3306                        " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
3307                        % str(x))
3308        if has_tls_version('SSLv3'):
3309            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False)
3310        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True)
3311        if has_tls_version('TLSv1'):
3312            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1')
3313
3314        if has_tls_version('SSLv3'):
3315            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
3316        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_OPTIONAL)
3317        if has_tls_version('TLSv1'):
3318            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3319
3320        if has_tls_version('SSLv3'):
3321            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
3322        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True, ssl.CERT_REQUIRED)
3323        if has_tls_version('TLSv1'):
3324            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3325
3326        # Server with specific SSL options
3327        if has_tls_version('SSLv3'):
3328            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_SSLv3, False,
3329                           server_options=ssl.OP_NO_SSLv3)
3330        # Will choose TLSv1
3331        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS, True,
3332                           server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
3333        if has_tls_version('TLSv1'):
3334            try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1, False,
3335                               server_options=ssl.OP_NO_TLSv1)
3336
3337    @requires_tls_version('SSLv3')
3338    def test_protocol_sslv3(self):
3339        """Connecting to an SSLv3 server with various client options"""
3340        if support.verbose:
3341            sys.stdout.write("\n")
3342        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
3343        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
3344        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
3345        if has_tls_version('SSLv2'):
3346            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
3347        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS, False,
3348                           client_options=ssl.OP_NO_SSLv3)
3349        try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
3350        if no_sslv2_implies_sslv3_hello():
3351            # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
3352            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLS,
3353                               False, client_options=ssl.OP_NO_SSLv2)
3354
3355    @requires_tls_version('TLSv1')
3356    def test_protocol_tlsv1(self):
3357        """Connecting to a TLSv1 server with various client options"""
3358        if support.verbose:
3359            sys.stdout.write("\n")
3360        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
3361        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
3362        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
3363        if has_tls_version('SSLv2'):
3364            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
3365        if has_tls_version('SSLv3'):
3366            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
3367        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLS, False,
3368                           client_options=ssl.OP_NO_TLSv1)
3369
3370    @requires_tls_version('TLSv1_1')
3371    def test_protocol_tlsv1_1(self):
3372        """Connecting to a TLSv1.1 server with various client options.
3373           Testing against older TLS versions."""
3374        if support.verbose:
3375            sys.stdout.write("\n")
3376        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3377        if has_tls_version('SSLv2'):
3378            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
3379        if has_tls_version('SSLv3'):
3380            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
3381        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLS, False,
3382                           client_options=ssl.OP_NO_TLSv1_1)
3383
3384        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
3385        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3386        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3387
3388    @requires_tls_version('TLSv1_2')
3389    def test_protocol_tlsv1_2(self):
3390        """Connecting to a TLSv1.2 server with various client options.
3391           Testing against older TLS versions."""
3392        if support.verbose:
3393            sys.stdout.write("\n")
3394        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
3395                           server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
3396                           client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
3397        if has_tls_version('SSLv2'):
3398            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
3399        if has_tls_version('SSLv3'):
3400            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
3401        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLS, False,
3402                           client_options=ssl.OP_NO_TLSv1_2)
3403
3404        try_protocol_combo(ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
3405        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
3406        try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
3407        try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
3408        try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
3409
    def test_starttls(self):
        """Switching from clear text to encrypted and back again."""
        # Sent in order.  b"STARTTLS" / b"ENDTLS" request an upgrade to / a
        # downgrade from TLS; the client only switches sockets once the
        # server's reply starts with b"ok".
        msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")

        server = ThreadedEchoServer(CERTFILE,
                                    starttls_server=True,
                                    chatty=True,
                                    connectionchatty=True)
        # True while traffic goes through the TLS-wrapped socket `conn`
        # rather than the plain socket `s`.
        wrapped = False
        with server:
            s = socket.socket()
            s.setblocking(1)
            s.connect((HOST, server.port))
            if support.verbose:
                sys.stdout.write("\n")
            for indata in msgs:
                if support.verbose:
                    sys.stdout.write(
                        " client:  sending %r...\n" % indata)
                if wrapped:
                    conn.write(indata)
                    outdata = conn.read()
                else:
                    s.send(indata)
                    outdata = s.recv(1024)
                msg = outdata.strip().lower()
                if indata == b"STARTTLS" and msg.startswith(b"ok"):
                    # STARTTLS ok, switch to secure mode
                    if support.verbose:
                        sys.stdout.write(
                            " client:  read %r from server, starting TLS...\n"
                            % msg)
                    conn = test_wrap_socket(s)
                    wrapped = True
                elif indata == b"ENDTLS" and msg.startswith(b"ok"):
                    # ENDTLS ok, switch back to clear text.  unwrap() shuts
                    # down the TLS layer and returns the underlying socket.
                    if support.verbose:
                        sys.stdout.write(
                            " client:  read %r from server, ending TLS...\n"
                            % msg)
                    s = conn.unwrap()
                    wrapped = False
                else:
                    if support.verbose:
                        sys.stdout.write(
                            " client:  read %r from server\n" % msg)
            if support.verbose:
                sys.stdout.write(" client:  closing connection.\n")
            # Say goodbye over whichever channel is currently active.
            # NOTE(review): if an exception escapes the loop above, neither
            # socket is closed; a try/finally would avoid a ResourceWarning.
            if wrapped:
                conn.write(b"over\n")
            else:
                s.send(b"over\n")
            if wrapped:
                conn.close()
            else:
                s.close()
3466
3467    def test_socketserver(self):
3468        """Using socketserver to create and manage SSL connections."""
3469        server = make_https_server(self, certfile=SIGNED_CERTFILE)
3470        # try to connect
3471        if support.verbose:
3472            sys.stdout.write('\n')
3473        with open(CERTFILE, 'rb') as f:
3474            d1 = f.read()
3475        d2 = ''
3476        # now fetch the same data from the HTTPS server
3477        url = 'https://localhost:%d/%s' % (
3478            server.port, os.path.split(CERTFILE)[1])
3479        context = ssl.create_default_context(cafile=SIGNING_CA)
3480        f = urllib.request.urlopen(url, context=context)
3481        try:
3482            dlen = f.info().get("content-length")
3483            if dlen and (int(dlen) > 0):
3484                d2 = f.read(int(dlen))
3485                if support.verbose:
3486                    sys.stdout.write(
3487                        " client: read %d bytes from remote server '%s'\n"
3488                        % (len(d2), server))
3489        finally:
3490            f.close()
3491        self.assertEqual(d1, d2)
3492
3493    def test_asyncore_server(self):
3494        """Check the example asyncore integration."""
3495        if support.verbose:
3496            sys.stdout.write("\n")
3497
3498        indata = b"FOO\n"
3499        server = AsyncoreEchoServer(CERTFILE)
3500        with server:
3501            s = test_wrap_socket(socket.socket())
3502            s.connect(('127.0.0.1', server.port))
3503            if support.verbose:
3504                sys.stdout.write(
3505                    " client:  sending %r...\n" % indata)
3506            s.write(indata)
3507            outdata = s.read()
3508            if support.verbose:
3509                sys.stdout.write(" client:  read %r\n" % outdata)
3510            if outdata != indata.lower():
3511                self.fail(
3512                    "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
3513                    % (outdata[:20], len(outdata),
3514                       indata[:20].lower(), len(indata)))
3515            s.write(b"over\n")
3516            if support.verbose:
3517                sys.stdout.write(" client:  closing connection.\n")
3518            s.close()
3519            if support.verbose:
3520                sys.stdout.write(" client:  connection closed.\n")
3521
3522    def test_recv_send(self):
3523        """Test recv(), send() and friends."""
3524        if support.verbose:
3525            sys.stdout.write("\n")
3526
3527        server = ThreadedEchoServer(CERTFILE,
3528                                    certreqs=ssl.CERT_NONE,
3529                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3530                                    cacerts=CERTFILE,
3531                                    chatty=True,
3532                                    connectionchatty=False)
3533        with server:
3534            s = test_wrap_socket(socket.socket(),
3535                                server_side=False,
3536                                certfile=CERTFILE,
3537                                ca_certs=CERTFILE,
3538                                cert_reqs=ssl.CERT_NONE,
3539                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3540            s.connect((HOST, server.port))
3541            # helper methods for standardising recv* method signatures
3542            def _recv_into():
3543                b = bytearray(b"\0"*100)
3544                count = s.recv_into(b)
3545                return b[:count]
3546
3547            def _recvfrom_into():
3548                b = bytearray(b"\0"*100)
3549                count, addr = s.recvfrom_into(b)
3550                return b[:count]
3551
3552            # (name, method, expect success?, *args, return value func)
3553            send_methods = [
3554                ('send', s.send, True, [], len),
3555                ('sendto', s.sendto, False, ["some.address"], len),
3556                ('sendall', s.sendall, True, [], lambda x: None),
3557            ]
3558            # (name, method, whether to expect success, *args)
3559            recv_methods = [
3560                ('recv', s.recv, True, []),
3561                ('recvfrom', s.recvfrom, False, ["some.address"]),
3562                ('recv_into', _recv_into, True, []),
3563                ('recvfrom_into', _recvfrom_into, False, []),
3564            ]
3565            data_prefix = "PREFIX_"
3566
3567            for (meth_name, send_meth, expect_success, args,
3568                    ret_val_meth) in send_methods:
3569                indata = (data_prefix + meth_name).encode('ascii')
3570                try:
3571                    ret = send_meth(indata, *args)
3572                    msg = "sending with {}".format(meth_name)
3573                    self.assertEqual(ret, ret_val_meth(indata), msg=msg)
3574                    outdata = s.read()
3575                    if outdata != indata.lower():
3576                        self.fail(
3577                            "While sending with <<{name:s}>> bad data "
3578                            "<<{outdata:r}>> ({nout:d}) received; "
3579                            "expected <<{indata:r}>> ({nin:d})\n".format(
3580                                name=meth_name, outdata=outdata[:20],
3581                                nout=len(outdata),
3582                                indata=indata[:20], nin=len(indata)
3583                            )
3584                        )
3585                except ValueError as e:
3586                    if expect_success:
3587                        self.fail(
3588                            "Failed to send with method <<{name:s}>>; "
3589                            "expected to succeed.\n".format(name=meth_name)
3590                        )
3591                    if not str(e).startswith(meth_name):
3592                        self.fail(
3593                            "Method <<{name:s}>> failed with unexpected "
3594                            "exception message: {exp:s}\n".format(
3595                                name=meth_name, exp=e
3596                            )
3597                        )
3598
3599            for meth_name, recv_meth, expect_success, args in recv_methods:
3600                indata = (data_prefix + meth_name).encode('ascii')
3601                try:
3602                    s.send(indata)
3603                    outdata = recv_meth(*args)
3604                    if outdata != indata.lower():
3605                        self.fail(
3606                            "While receiving with <<{name:s}>> bad data "
3607                            "<<{outdata:r}>> ({nout:d}) received; "
3608                            "expected <<{indata:r}>> ({nin:d})\n".format(
3609                                name=meth_name, outdata=outdata[:20],
3610                                nout=len(outdata),
3611                                indata=indata[:20], nin=len(indata)
3612                            )
3613                        )
3614                except ValueError as e:
3615                    if expect_success:
3616                        self.fail(
3617                            "Failed to receive with method <<{name:s}>>; "
3618                            "expected to succeed.\n".format(name=meth_name)
3619                        )
3620                    if not str(e).startswith(meth_name):
3621                        self.fail(
3622                            "Method <<{name:s}>> failed with unexpected "
3623                            "exception message: {exp:s}\n".format(
3624                                name=meth_name, exp=e
3625                            )
3626                        )
3627                    # consume data
3628                    s.read()
3629
3630            # read(-1, buffer) is supported, even though read(-1) is not
3631            data = b"data"
3632            s.send(data)
3633            buffer = bytearray(len(data))
3634            self.assertEqual(s.read(-1, buffer), len(data))
3635            self.assertEqual(buffer, data)
3636
3637            # sendall accepts bytes-like objects
3638            if ctypes is not None:
3639                ubyte = ctypes.c_ubyte * len(data)
3640                byteslike = ubyte.from_buffer_copy(data)
3641                s.sendall(byteslike)
3642                self.assertEqual(s.read(), data)
3643
3644            # Make sure sendmsg et al are disallowed to avoid
3645            # inadvertent disclosure of data and/or corruption
3646            # of the encrypted data stream
3647            self.assertRaises(NotImplementedError, s.dup)
3648            self.assertRaises(NotImplementedError, s.sendmsg, [b"data"])
3649            self.assertRaises(NotImplementedError, s.recvmsg, 100)
3650            self.assertRaises(NotImplementedError,
3651                              s.recvmsg_into, [bytearray(100)])
3652            s.write(b"over\n")
3653
3654            self.assertRaises(ValueError, s.recv, -1)
3655            self.assertRaises(ValueError, s.read, -1)
3656
3657            s.close()
3658
3659    def test_recv_zero(self):
3660        server = ThreadedEchoServer(CERTFILE)
3661        server.__enter__()
3662        self.addCleanup(server.__exit__, None, None)
3663        s = socket.create_connection((HOST, server.port))
3664        self.addCleanup(s.close)
3665        s = test_wrap_socket(s, suppress_ragged_eofs=False)
3666        self.addCleanup(s.close)
3667
3668        # recv/read(0) should return no data
3669        s.send(b"data")
3670        self.assertEqual(s.recv(0), b"")
3671        self.assertEqual(s.read(0), b"")
3672        self.assertEqual(s.read(), b"data")
3673
3674        # Should not block if the other end sends no data
3675        s.setblocking(False)
3676        self.assertEqual(s.recv(0), b"")
3677        self.assertEqual(s.recv_into(bytearray()), 0)
3678
3679    def test_nonblocking_send(self):
3680        server = ThreadedEchoServer(CERTFILE,
3681                                    certreqs=ssl.CERT_NONE,
3682                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
3683                                    cacerts=CERTFILE,
3684                                    chatty=True,
3685                                    connectionchatty=False)
3686        with server:
3687            s = test_wrap_socket(socket.socket(),
3688                                server_side=False,
3689                                certfile=CERTFILE,
3690                                ca_certs=CERTFILE,
3691                                cert_reqs=ssl.CERT_NONE,
3692                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
3693            s.connect((HOST, server.port))
3694            s.setblocking(False)
3695
3696            # If we keep sending data, at some point the buffers
3697            # will be full and the call will block
3698            buf = bytearray(8192)
3699            def fill_buffer():
3700                while True:
3701                    s.send(buf)
3702            self.assertRaises((ssl.SSLWantWriteError,
3703                               ssl.SSLWantReadError), fill_buffer)
3704
3705            # Now read all the output and discard it
3706            s.setblocking(True)
3707            s.close()
3708
3709    def test_handshake_timeout(self):
3710        # Issue #5103: SSL handshake must respect the socket timeout
3711        server = socket.socket(socket.AF_INET)
3712        host = "127.0.0.1"
3713        port = support.bind_port(server)
3714        started = threading.Event()
3715        finish = False
3716
3717        def serve():
3718            server.listen()
3719            started.set()
3720            conns = []
3721            while not finish:
3722                r, w, e = select.select([server], [], [], 0.1)
3723                if server in r:
3724                    # Let the socket hang around rather than having
3725                    # it closed by garbage collection.
3726                    conns.append(server.accept()[0])
3727            for sock in conns:
3728                sock.close()
3729
3730        t = threading.Thread(target=serve)
3731        t.start()
3732        started.wait()
3733
3734        try:
3735            try:
3736                c = socket.socket(socket.AF_INET)
3737                c.settimeout(0.2)
3738                c.connect((host, port))
3739                # Will attempt handshake and time out
3740                self.assertRaisesRegex(socket.timeout, "timed out",
3741                                       test_wrap_socket, c)
3742            finally:
3743                c.close()
3744            try:
3745                c = socket.socket(socket.AF_INET)
3746                c = test_wrap_socket(c)
3747                c.settimeout(0.2)
3748                # Will attempt handshake and time out
3749                self.assertRaisesRegex(socket.timeout, "timed out",
3750                                       c.connect, (host, port))
3751            finally:
3752                c.close()
3753        finally:
3754            finish = True
3755            t.join()
3756            server.close()
3757
3758    def test_server_accept(self):
3759        # Issue #16357: accept() on a SSLSocket created through
3760        # SSLContext.wrap_socket().
3761        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3762        context.verify_mode = ssl.CERT_REQUIRED
3763        context.load_verify_locations(SIGNING_CA)
3764        context.load_cert_chain(SIGNED_CERTFILE)
3765        server = socket.socket(socket.AF_INET)
3766        host = "127.0.0.1"
3767        port = support.bind_port(server)
3768        server = context.wrap_socket(server, server_side=True)
3769        self.assertTrue(server.server_side)
3770
3771        evt = threading.Event()
3772        remote = None
3773        peer = None
3774        def serve():
3775            nonlocal remote, peer
3776            server.listen()
3777            # Block on the accept and wait on the connection to close.
3778            evt.set()
3779            remote, peer = server.accept()
3780            remote.send(remote.recv(4))
3781
3782        t = threading.Thread(target=serve)
3783        t.start()
3784        # Client wait until server setup and perform a connect.
3785        evt.wait()
3786        client = context.wrap_socket(socket.socket())
3787        client.connect((host, port))
3788        client.send(b'data')
3789        client.recv()
3790        client_addr = client.getsockname()
3791        client.close()
3792        t.join()
3793        remote.close()
3794        server.close()
3795        # Sanity checks.
3796        self.assertIsInstance(remote, ssl.SSLSocket)
3797        self.assertEqual(peer, client_addr)
3798
3799    def test_getpeercert_enotconn(self):
3800        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3801        with context.wrap_socket(socket.socket()) as sock:
3802            with self.assertRaises(OSError) as cm:
3803                sock.getpeercert()
3804            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3805
3806    def test_do_handshake_enotconn(self):
3807        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3808        with context.wrap_socket(socket.socket()) as sock:
3809            with self.assertRaises(OSError) as cm:
3810                sock.do_handshake()
3811            self.assertEqual(cm.exception.errno, errno.ENOTCONN)
3812
3813    def test_no_shared_ciphers(self):
3814        client_context, server_context, hostname = testing_context()
3815        # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
3816        client_context.options |= ssl.OP_NO_TLSv1_3
3817        # Force different suites on client and server
3818        client_context.set_ciphers("AES128")
3819        server_context.set_ciphers("AES256")
3820        with ThreadedEchoServer(context=server_context) as server:
3821            with client_context.wrap_socket(socket.socket(),
3822                                            server_hostname=hostname) as s:
3823                with self.assertRaises(OSError):
3824                    s.connect((HOST, server.port))
3825        self.assertIn("no shared cipher", server.conn_errors[0])
3826
3827    def test_version_basic(self):
3828        """
3829        Basic tests for SSLSocket.version().
3830        More tests are done in the test_protocol_*() methods.
3831        """
3832        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
3833        context.check_hostname = False
3834        context.verify_mode = ssl.CERT_NONE
3835        with ThreadedEchoServer(CERTFILE,
3836                                ssl_version=ssl.PROTOCOL_TLS_SERVER,
3837                                chatty=False) as server:
3838            with context.wrap_socket(socket.socket()) as s:
3839                self.assertIs(s.version(), None)
3840                self.assertIs(s._sslobj, None)
3841                s.connect((HOST, server.port))
3842                if IS_OPENSSL_1_1_1 and has_tls_version('TLSv1_3'):
3843                    self.assertEqual(s.version(), 'TLSv1.3')
3844                elif ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
3845                    self.assertEqual(s.version(), 'TLSv1.2')
3846                else:  # 0.9.8 to 1.0.1
3847                    self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
3848            self.assertIs(s._sslobj, None)
3849            self.assertIs(s.version(), None)
3850
3851    @requires_tls_version('TLSv1_3')
3852    def test_tls1_3(self):
3853        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3854        context.load_cert_chain(CERTFILE)
3855        context.options |= (
3856            ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
3857        )
3858        with ThreadedEchoServer(context=context) as server:
3859            with context.wrap_socket(socket.socket()) as s:
3860                s.connect((HOST, server.port))
3861                self.assertIn(s.cipher()[0], {
3862                    'TLS_AES_256_GCM_SHA384',
3863                    'TLS_CHACHA20_POLY1305_SHA256',
3864                    'TLS_AES_128_GCM_SHA256',
3865                })
3866                self.assertEqual(s.version(), 'TLSv1.3')
3867
3868    @requires_minimum_version
3869    @requires_tls_version('TLSv1_2')
3870    def test_min_max_version_tlsv1_2(self):
3871        client_context, server_context, hostname = testing_context()
3872        # client TLSv1.0 to 1.2
3873        client_context.minimum_version = ssl.TLSVersion.TLSv1
3874        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3875        # server only TLSv1.2
3876        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3877        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3878
3879        with ThreadedEchoServer(context=server_context) as server:
3880            with client_context.wrap_socket(socket.socket(),
3881                                            server_hostname=hostname) as s:
3882                s.connect((HOST, server.port))
3883                self.assertEqual(s.version(), 'TLSv1.2')
3884
3885    @requires_minimum_version
3886    @requires_tls_version('TLSv1_1')
3887    def test_min_max_version_tlsv1_1(self):
3888        client_context, server_context, hostname = testing_context()
3889        # client 1.0 to 1.2, server 1.0 to 1.1
3890        client_context.minimum_version = ssl.TLSVersion.TLSv1
3891        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
3892        server_context.minimum_version = ssl.TLSVersion.TLSv1
3893        server_context.maximum_version = ssl.TLSVersion.TLSv1_1
3894        seclevel_workaround(client_context, server_context)
3895
3896        with ThreadedEchoServer(context=server_context) as server:
3897            with client_context.wrap_socket(socket.socket(),
3898                                            server_hostname=hostname) as s:
3899                s.connect((HOST, server.port))
3900                self.assertEqual(s.version(), 'TLSv1.1')
3901
3902    @requires_minimum_version
3903    @requires_tls_version('TLSv1_2')
3904    @requires_tls_version('TLSv1')
3905    def test_min_max_version_mismatch(self):
3906        client_context, server_context, hostname = testing_context()
3907        # client 1.0, server 1.2 (mismatch)
3908        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
3909        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
3910        client_context.maximum_version = ssl.TLSVersion.TLSv1
3911        client_context.minimum_version = ssl.TLSVersion.TLSv1
3912        seclevel_workaround(client_context, server_context)
3913
3914        with ThreadedEchoServer(context=server_context) as server:
3915            with client_context.wrap_socket(socket.socket(),
3916                                            server_hostname=hostname) as s:
3917                with self.assertRaises(ssl.SSLError) as e:
3918                    s.connect((HOST, server.port))
3919                self.assertIn("alert", str(e.exception))
3920
3921    @requires_minimum_version
3922    @requires_tls_version('SSLv3')
3923    def test_min_max_version_sslv3(self):
3924        client_context, server_context, hostname = testing_context()
3925        server_context.minimum_version = ssl.TLSVersion.SSLv3
3926        client_context.minimum_version = ssl.TLSVersion.SSLv3
3927        client_context.maximum_version = ssl.TLSVersion.SSLv3
3928        seclevel_workaround(client_context, server_context)
3929
3930        with ThreadedEchoServer(context=server_context) as server:
3931            with client_context.wrap_socket(socket.socket(),
3932                                            server_hostname=hostname) as s:
3933                s.connect((HOST, server.port))
3934                self.assertEqual(s.version(), 'SSLv3')
3935
3936    @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
3937    def test_default_ecdh_curve(self):
3938        # Issue #21015: elliptic curve-based Diffie Hellman key exchange
3939        # should be enabled by default on SSL contexts.
3940        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
3941        context.load_cert_chain(CERTFILE)
3942        # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
3943        # cipher name.
3944        context.options |= ssl.OP_NO_TLSv1_3
3945        # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
3946        # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
3947        # our default cipher list should prefer ECDH-based ciphers
3948        # automatically.
3949        if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
3950            context.set_ciphers("ECCdraft:ECDH")
3951        with ThreadedEchoServer(context=context) as server:
3952            with context.wrap_socket(socket.socket()) as s:
3953                s.connect((HOST, server.port))
3954                self.assertIn("ECDH", s.cipher()[0])
3955
3956    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
3957                         "'tls-unique' channel binding not available")
3958    def test_tls_unique_channel_binding(self):
3959        """Test tls-unique channel binding."""
3960        if support.verbose:
3961            sys.stdout.write("\n")
3962
3963        client_context, server_context, hostname = testing_context()
3964
3965        server = ThreadedEchoServer(context=server_context,
3966                                    chatty=True,
3967                                    connectionchatty=False)
3968
3969        with server:
3970            with client_context.wrap_socket(
3971                    socket.socket(),
3972                    server_hostname=hostname) as s:
3973                s.connect((HOST, server.port))
3974                # get the data
3975                cb_data = s.get_channel_binding("tls-unique")
3976                if support.verbose:
3977                    sys.stdout.write(
3978                        " got channel binding data: {0!r}\n".format(cb_data))
3979
3980                # check if it is sane
3981                self.assertIsNotNone(cb_data)
3982                if s.version() == 'TLSv1.3':
3983                    self.assertEqual(len(cb_data), 48)
3984                else:
3985                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
3986
3987                # and compare with the peers version
3988                s.write(b"CB tls-unique\n")
3989                peer_data_repr = s.read().strip()
3990                self.assertEqual(peer_data_repr,
3991                                 repr(cb_data).encode("us-ascii"))
3992
3993            # now, again
3994            with client_context.wrap_socket(
3995                    socket.socket(),
3996                    server_hostname=hostname) as s:
3997                s.connect((HOST, server.port))
3998                new_cb_data = s.get_channel_binding("tls-unique")
3999                if support.verbose:
4000                    sys.stdout.write(
4001                        "got another channel binding data: {0!r}\n".format(
4002                            new_cb_data)
4003                    )
4004                # is it really unique
4005                self.assertNotEqual(cb_data, new_cb_data)
4006                self.assertIsNotNone(cb_data)
4007                if s.version() == 'TLSv1.3':
4008                    self.assertEqual(len(cb_data), 48)
4009                else:
4010                    self.assertEqual(len(cb_data), 12)  # True for TLSv1
4011                s.write(b"CB tls-unique\n")
4012                peer_data_repr = s.read().strip()
4013                self.assertEqual(peer_data_repr,
4014                                 repr(new_cb_data).encode("us-ascii"))
4015
4016    def test_compression(self):
4017        client_context, server_context, hostname = testing_context()
4018        stats = server_params_test(client_context, server_context,
4019                                   chatty=True, connectionchatty=True,
4020                                   sni_name=hostname)
4021        if support.verbose:
4022            sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
4023        self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
4024
4025    @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
4026                         "ssl.OP_NO_COMPRESSION needed for this test")
4027    def test_compression_disabled(self):
4028        client_context, server_context, hostname = testing_context()
4029        client_context.options |= ssl.OP_NO_COMPRESSION
4030        server_context.options |= ssl.OP_NO_COMPRESSION
4031        stats = server_params_test(client_context, server_context,
4032                                   chatty=True, connectionchatty=True,
4033                                   sni_name=hostname)
4034        self.assertIs(stats['compression'], None)
4035
4036    def test_dh_params(self):
4037        # Check we can get a connection with ephemeral Diffie-Hellman
4038        client_context, server_context, hostname = testing_context()
4039        # test scenario needs TLS <= 1.2
4040        client_context.options |= ssl.OP_NO_TLSv1_3
4041        server_context.load_dh_params(DHFILE)
4042        server_context.set_ciphers("kEDH")
4043        server_context.options |= ssl.OP_NO_TLSv1_3
4044        stats = server_params_test(client_context, server_context,
4045                                   chatty=True, connectionchatty=True,
4046                                   sni_name=hostname)
4047        cipher = stats["cipher"][0]
4048        parts = cipher.split("-")
4049        if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
4050            self.fail("Non-DH cipher: " + cipher[0])
4051
4052    @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
4053    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
4054    def test_ecdh_curve(self):
4055        # server secp384r1, client auto
4056        client_context, server_context, hostname = testing_context()
4057
4058        server_context.set_ecdh_curve("secp384r1")
4059        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4060        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
4061        stats = server_params_test(client_context, server_context,
4062                                   chatty=True, connectionchatty=True,
4063                                   sni_name=hostname)
4064
4065        # server auto, client secp384r1
4066        client_context, server_context, hostname = testing_context()
4067        client_context.set_ecdh_curve("secp384r1")
4068        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4069        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
4070        stats = server_params_test(client_context, server_context,
4071                                   chatty=True, connectionchatty=True,
4072                                   sni_name=hostname)
4073
4074        # server / client curve mismatch
4075        client_context, server_context, hostname = testing_context()
4076        client_context.set_ecdh_curve("prime256v1")
4077        server_context.set_ecdh_curve("secp384r1")
4078        server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
4079        server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
4080        try:
4081            stats = server_params_test(client_context, server_context,
4082                                       chatty=True, connectionchatty=True,
4083                                       sni_name=hostname)
4084        except ssl.SSLError:
4085            pass
4086        else:
4087            # OpenSSL 1.0.2 does not fail although it should.
4088            if IS_OPENSSL_1_1_0:
4089                self.fail("mismatch curve did not fail")
4090
4091    def test_selected_alpn_protocol(self):
4092        # selected_alpn_protocol() is None unless ALPN is used.
4093        client_context, server_context, hostname = testing_context()
4094        stats = server_params_test(client_context, server_context,
4095                                   chatty=True, connectionchatty=True,
4096                                   sni_name=hostname)
4097        self.assertIs(stats['client_alpn_protocol'], None)
4098
4099    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
4100    def test_selected_alpn_protocol_if_server_uses_alpn(self):
4101        # selected_alpn_protocol() is None unless ALPN is used by the client.
4102        client_context, server_context, hostname = testing_context()
4103        server_context.set_alpn_protocols(['foo', 'bar'])
4104        stats = server_params_test(client_context, server_context,
4105                                   chatty=True, connectionchatty=True,
4106                                   sni_name=hostname)
4107        self.assertIs(stats['client_alpn_protocol'], None)
4108
4109    @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
4110    def test_alpn_protocols(self):
4111        server_protocols = ['foo', 'bar', 'milkshake']
4112        protocol_tests = [
4113            (['foo', 'bar'], 'foo'),
4114            (['bar', 'foo'], 'foo'),
4115            (['milkshake'], 'milkshake'),
4116            (['http/3.0', 'http/4.0'], None)
4117        ]
4118        for client_protocols, expected in protocol_tests:
4119            client_context, server_context, hostname = testing_context()
4120            server_context.set_alpn_protocols(server_protocols)
4121            client_context.set_alpn_protocols(client_protocols)
4122
4123            try:
4124                stats = server_params_test(client_context,
4125                                           server_context,
4126                                           chatty=True,
4127                                           connectionchatty=True,
4128                                           sni_name=hostname)
4129            except ssl.SSLError as e:
4130                stats = e
4131
4132            if (expected is None and IS_OPENSSL_1_1_0
4133                    and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
4134                # OpenSSL 1.1.0 to 1.1.0e raises handshake error
4135                self.assertIsInstance(stats, ssl.SSLError)
4136            else:
4137                msg = "failed trying %s (s) and %s (c).\n" \
4138                    "was expecting %s, but got %%s from the %%s" \
4139                        % (str(server_protocols), str(client_protocols),
4140                            str(expected))
4141                client_result = stats['client_alpn_protocol']
4142                self.assertEqual(client_result, expected,
4143                                 msg % (client_result, "client"))
4144                server_result = stats['server_alpn_protocols'][-1] \
4145                    if len(stats['server_alpn_protocols']) else 'nothing'
4146                self.assertEqual(server_result, expected,
4147                                 msg % (server_result, "server"))
4148
4149    def test_selected_npn_protocol(self):
4150        # selected_npn_protocol() is None unless NPN is used
4151        client_context, server_context, hostname = testing_context()
4152        stats = server_params_test(client_context, server_context,
4153                                   chatty=True, connectionchatty=True,
4154                                   sni_name=hostname)
4155        self.assertIs(stats['client_npn_protocol'], None)
4156
4157    @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
4158    def test_npn_protocols(self):
4159        server_protocols = ['http/1.1', 'spdy/2']
4160        protocol_tests = [
4161            (['http/1.1', 'spdy/2'], 'http/1.1'),
4162            (['spdy/2', 'http/1.1'], 'http/1.1'),
4163            (['spdy/2', 'test'], 'spdy/2'),
4164            (['abc', 'def'], 'abc')
4165        ]
4166        for client_protocols, expected in protocol_tests:
4167            client_context, server_context, hostname = testing_context()
4168            server_context.set_npn_protocols(server_protocols)
4169            client_context.set_npn_protocols(client_protocols)
4170            stats = server_params_test(client_context, server_context,
4171                                       chatty=True, connectionchatty=True,
4172                                       sni_name=hostname)
4173            msg = "failed trying %s (s) and %s (c).\n" \
4174                  "was expecting %s, but got %%s from the %%s" \
4175                      % (str(server_protocols), str(client_protocols),
4176                         str(expected))
4177            client_result = stats['client_npn_protocol']
4178            self.assertEqual(client_result, expected, msg % (client_result, "client"))
4179            server_result = stats['server_npn_protocols'][-1] \
4180                if len(stats['server_npn_protocols']) else 'nothing'
4181            self.assertEqual(server_result, expected, msg % (server_result, "server"))
4182
4183    def sni_contexts(self):
4184        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4185        server_context.load_cert_chain(SIGNED_CERTFILE)
4186        other_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
4187        other_context.load_cert_chain(SIGNED_CERTFILE2)
4188        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
4189        client_context.load_verify_locations(SIGNING_CA)
4190        return server_context, other_context, client_context
4191
4192    def check_common_name(self, stats, name):
4193        cert = stats['peercert']
4194        self.assertIn((('commonName', name),), cert['subject'])
4195
4196    @needs_sni
4197    def test_sni_callback(self):
4198        calls = []
4199        server_context, other_context, client_context = self.sni_contexts()
4200
4201        client_context.check_hostname = False
4202
4203        def servername_cb(ssl_sock, server_name, initial_context):
4204            calls.append((server_name, initial_context))
4205            if server_name is not None:
4206                ssl_sock.context = other_context
4207        server_context.set_servername_callback(servername_cb)
4208
4209        stats = server_params_test(client_context, server_context,
4210                                   chatty=True,
4211                                   sni_name='supermessage')
4212        # The hostname was fetched properly, and the certificate was
4213        # changed for the connection.
4214        self.assertEqual(calls, [("supermessage", server_context)])
4215        # CERTFILE4 was selected
4216        self.check_common_name(stats, 'fakehostname')
4217
4218        calls = []
4219        # The callback is called with server_name=None
4220        stats = server_params_test(client_context, server_context,
4221                                   chatty=True,
4222                                   sni_name=None)
4223        self.assertEqual(calls, [(None, server_context)])
4224        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4225
4226        # Check disabling the callback
4227        calls = []
4228        server_context.set_servername_callback(None)
4229
4230        stats = server_params_test(client_context, server_context,
4231                                   chatty=True,
4232                                   sni_name='notfunny')
4233        # Certificate didn't change
4234        self.check_common_name(stats, SIGNED_CERTFILE_HOSTNAME)
4235        self.assertEqual(calls, [])
4236
4237    @needs_sni
4238    def test_sni_callback_alert(self):
4239        # Returning a TLS alert is reflected to the connecting client
4240        server_context, other_context, client_context = self.sni_contexts()
4241
4242        def cb_returning_alert(ssl_sock, server_name, initial_context):
4243            return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
4244        server_context.set_servername_callback(cb_returning_alert)
4245        with self.assertRaises(ssl.SSLError) as cm:
4246            stats = server_params_test(client_context, server_context,
4247                                       chatty=False,
4248                                       sni_name='supermessage')
4249        self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
4250
4251    @needs_sni
4252    def test_sni_callback_raising(self):
4253        # Raising fails the connection with a TLS handshake failure alert.
4254        server_context, other_context, client_context = self.sni_contexts()
4255
4256        def cb_raising(ssl_sock, server_name, initial_context):
4257            1/0
4258        server_context.set_servername_callback(cb_raising)
4259
4260        with support.catch_unraisable_exception() as catch:
4261            with self.assertRaises(ssl.SSLError) as cm:
4262                stats = server_params_test(client_context, server_context,
4263                                           chatty=False,
4264                                           sni_name='supermessage')
4265
4266            self.assertEqual(cm.exception.reason,
4267                             'SSLV3_ALERT_HANDSHAKE_FAILURE')
4268            self.assertEqual(catch.unraisable.exc_type, ZeroDivisionError)
4269
4270    @needs_sni
4271    def test_sni_callback_wrong_return_type(self):
4272        # Returning the wrong return type terminates the TLS connection
4273        # with an internal error alert.
4274        server_context, other_context, client_context = self.sni_contexts()
4275
4276        def cb_wrong_return_type(ssl_sock, server_name, initial_context):
4277            return "foo"
4278        server_context.set_servername_callback(cb_wrong_return_type)
4279
4280        with support.catch_unraisable_exception() as catch:
4281            with self.assertRaises(ssl.SSLError) as cm:
4282                stats = server_params_test(client_context, server_context,
4283                                           chatty=False,
4284                                           sni_name='supermessage')
4285
4286
4287            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
4288            self.assertEqual(catch.unraisable.exc_type, TypeError)
4289
4290    def test_shared_ciphers(self):
4291        client_context, server_context, hostname = testing_context()
4292        client_context.set_ciphers("AES128:AES256")
4293        server_context.set_ciphers("AES256")
4294        expected_algs = [
4295            "AES256", "AES-256",
4296            # TLS 1.3 ciphers are always enabled
4297            "TLS_CHACHA20", "TLS_AES",
4298        ]
4299
4300        stats = server_params_test(client_context, server_context,
4301                                   sni_name=hostname)
4302        ciphers = stats['server_shared_ciphers'][0]
4303        self.assertGreater(len(ciphers), 0)
4304        for name, tls_version, bits in ciphers:
4305            if not any(alg in name for alg in expected_algs):
4306                self.fail(name)
4307
4308    def test_read_write_after_close_raises_valuerror(self):
4309        client_context, server_context, hostname = testing_context()
4310        server = ThreadedEchoServer(context=server_context, chatty=False)
4311
4312        with server:
4313            s = client_context.wrap_socket(socket.socket(),
4314                                           server_hostname=hostname)
4315            s.connect((HOST, server.port))
4316            s.close()
4317
4318            self.assertRaises(ValueError, s.read, 1024)
4319            self.assertRaises(ValueError, s.write, b'hello')
4320
4321    def test_sendfile(self):
4322        TEST_DATA = b"x" * 512
4323        with open(support.TESTFN, 'wb') as f:
4324            f.write(TEST_DATA)
4325        self.addCleanup(support.unlink, support.TESTFN)
4326        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
4327        context.verify_mode = ssl.CERT_REQUIRED
4328        context.load_verify_locations(SIGNING_CA)
4329        context.load_cert_chain(SIGNED_CERTFILE)
4330        server = ThreadedEchoServer(context=context, chatty=False)
4331        with server:
4332            with context.wrap_socket(socket.socket()) as s:
4333                s.connect((HOST, server.port))
4334                with open(support.TESTFN, 'rb') as file:
4335                    s.sendfile(file)
4336                    self.assertEqual(s.recv(1024), TEST_DATA)
4337
4338    def test_session(self):
4339        client_context, server_context, hostname = testing_context()
4340        # TODO: sessions aren't compatible with TLSv1.3 yet
4341        client_context.options |= ssl.OP_NO_TLSv1_3
4342
4343        # first connection without session
4344        stats = server_params_test(client_context, server_context,
4345                                   sni_name=hostname)
4346        session = stats['session']
4347        self.assertTrue(session.id)
4348        self.assertGreater(session.time, 0)
4349        self.assertGreater(session.timeout, 0)
4350        self.assertTrue(session.has_ticket)
4351        if ssl.OPENSSL_VERSION_INFO > (1, 0, 1):
4352            self.assertGreater(session.ticket_lifetime_hint, 0)
4353        self.assertFalse(stats['session_reused'])
4354        sess_stat = server_context.session_stats()
4355        self.assertEqual(sess_stat['accept'], 1)
4356        self.assertEqual(sess_stat['hits'], 0)
4357
4358        # reuse session
4359        stats = server_params_test(client_context, server_context,
4360                                   session=session, sni_name=hostname)
4361        sess_stat = server_context.session_stats()
4362        self.assertEqual(sess_stat['accept'], 2)
4363        self.assertEqual(sess_stat['hits'], 1)
4364        self.assertTrue(stats['session_reused'])
4365        session2 = stats['session']
4366        self.assertEqual(session2.id, session.id)
4367        self.assertEqual(session2, session)
4368        self.assertIsNot(session2, session)
4369        self.assertGreaterEqual(session2.time, session.time)
4370        self.assertGreaterEqual(session2.timeout, session.timeout)
4371
4372        # another one without session
4373        stats = server_params_test(client_context, server_context,
4374                                   sni_name=hostname)
4375        self.assertFalse(stats['session_reused'])
4376        session3 = stats['session']
4377        self.assertNotEqual(session3.id, session.id)
4378        self.assertNotEqual(session3, session)
4379        sess_stat = server_context.session_stats()
4380        self.assertEqual(sess_stat['accept'], 3)
4381        self.assertEqual(sess_stat['hits'], 1)
4382
4383        # reuse session again
4384        stats = server_params_test(client_context, server_context,
4385                                   session=session, sni_name=hostname)
4386        self.assertTrue(stats['session_reused'])
4387        session4 = stats['session']
4388        self.assertEqual(session4.id, session.id)
4389        self.assertEqual(session4, session)
4390        self.assertGreaterEqual(session4.time, session.time)
4391        self.assertGreaterEqual(session4.timeout, session.timeout)
4392        sess_stat = server_context.session_stats()
4393        self.assertEqual(sess_stat['accept'], 4)
4394        self.assertEqual(sess_stat['hits'], 2)
4395
4396    def test_session_handling(self):
4397        client_context, server_context, hostname = testing_context()
4398        client_context2, _, _ = testing_context()
4399
4400        # TODO: session reuse does not work with TLSv1.3
4401        client_context.options |= ssl.OP_NO_TLSv1_3
4402        client_context2.options |= ssl.OP_NO_TLSv1_3
4403
4404        server = ThreadedEchoServer(context=server_context, chatty=False)
4405        with server:
4406            with client_context.wrap_socket(socket.socket(),
4407                                            server_hostname=hostname) as s:
4408                # session is None before handshake
4409                self.assertEqual(s.session, None)
4410                self.assertEqual(s.session_reused, None)
4411                s.connect((HOST, server.port))
4412                session = s.session
4413                self.assertTrue(session)
4414                with self.assertRaises(TypeError) as e:
4415                    s.session = object
4416                self.assertEqual(str(e.exception), 'Value is not a SSLSession.')
4417
4418            with client_context.wrap_socket(socket.socket(),
4419                                            server_hostname=hostname) as s:
4420                s.connect((HOST, server.port))
4421                # cannot set session after handshake
4422                with self.assertRaises(ValueError) as e:
4423                    s.session = session
4424                self.assertEqual(str(e.exception),
4425                                 'Cannot set session after handshake.')
4426
4427            with client_context.wrap_socket(socket.socket(),
4428                                            server_hostname=hostname) as s:
4429                # can set session before handshake and before the
4430                # connection was established
4431                s.session = session
4432                s.connect((HOST, server.port))
4433                self.assertEqual(s.session.id, session.id)
4434                self.assertEqual(s.session, session)
4435                self.assertEqual(s.session_reused, True)
4436
4437            with client_context2.wrap_socket(socket.socket(),
4438                                             server_hostname=hostname) as s:
4439                # cannot re-use session with a different SSLContext
4440                with self.assertRaises(ValueError) as e:
4441                    s.session = session
4442                    s.connect((HOST, server.port))
4443                self.assertEqual(str(e.exception),
4444                                 'Session refers to a different SSLContext.')
4445
4446
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):
    # Post-handshake client authentication (PHA) tests; PHA requires
    # TLS 1.3, hence the class-level skip.  The client writes command
    # strings (HASCERT, PHA, GETCERT) to the ThreadedEchoServer and checks
    # its replies; the server-side handling of these commands lives outside
    # this class (presumably in ThreadedEchoServer's connection handler).

    def test_pha_setter(self):
        # post_handshake_auth can be toggled independently of verify_mode,
        # for every protocol constant.
        protocols = [
            ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
        ]
        for protocol in protocols:
            ctx = ssl.SSLContext(protocol)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.post_handshake_auth = True
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.verify_mode = ssl.CERT_REQUIRED
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, True)

            ctx.post_handshake_auth = False
            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
            self.assertEqual(ctx.post_handshake_auth, False)

            ctx.verify_mode = ssl.CERT_OPTIONAL
            ctx.post_handshake_auth = True
            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
            self.assertEqual(ctx.post_handshake_auth, True)

    def test_pha_required(self):
        # With PHA on both sides, the client cert is only available after
        # the server requests it post-handshake.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA method just returns true when cert is already available
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'GETCERT')
                cert_text = s.recv(4096).decode('us-ascii')
                self.assertIn('Python Software Foundation CA', cert_text)

    def test_pha_required_nocert(self):
        # A client without a certificate is rejected by a server that
        # requires one.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True

        # Ignore expected SSLError in ConnectionHandler of ThreadedEchoServer
        # (it is only raised sometimes on Windows)
        with support.catch_threading_exception():
            server = ThreadedEchoServer(context=server_context, chatty=False)
            with server:
                with client_context.wrap_socket(socket.socket(),
                                                server_hostname=hostname) as s:
                    s.connect((HOST, server.port))
                    s.write(b'PHA')
                    # receive CertificateRequest
                    self.assertEqual(s.recv(1024), b'OK\n')
                    # send empty Certificate + Finish
                    s.write(b'HASCERT')
                    # receive alert
                    with self.assertRaisesRegex(
                            ssl.SSLError,
                            'tlsv13 alert certificate required'):
                        s.recv(1024)

    def test_pha_optional(self):
        # With CERT_OPTIONAL, a client that has a cert provides it on request.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        # check CERT_OPTIONAL
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_optional_nocert(self):
        # With CERT_OPTIONAL, a client without a cert is still accepted.
        if support.verbose:
            sys.stdout.write("\n")

        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_OPTIONAL
        client_context.post_handshake_auth = True

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                # optional doesn't fail when client does not have a cert
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')

    def test_pha_no_pha_client(self):
        # PHA fails when the client did not enable post_handshake_auth,
        # and verify_client_post_handshake() is server-only.
        client_context, server_context, hostname = testing_context()
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
                    s.verify_client_post_handshake()
                s.write(b'PHA')
                self.assertIn(b'extension not received', s.recv(1024))

    def test_pha_no_pha_server(self):
        # server doesn't have PHA enabled, cert is requested in handshake
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # PHA doesn't fail if there is already a cert
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')

    def test_pha_not_tls13(self):
        # TLS 1.2
        client_context, server_context, hostname = testing_context()
        server_context.verify_mode = ssl.CERT_REQUIRED
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                # PHA fails for TLS != 1.3
                s.write(b'PHA')
                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))

    def test_bpo37428_pha_cert_none(self):
        # verify that post_handshake_auth does not implicitly enable cert
        # validation.
        hostname = SIGNED_CERTFILE_HOSTNAME
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.post_handshake_auth = True
        client_context.load_cert_chain(SIGNED_CERTFILE)
        # no cert validation and CA on client side
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.load_cert_chain(SIGNED_CERTFILE)
        server_context.load_verify_locations(SIGNING_CA)
        server_context.post_handshake_auth = True
        server_context.verify_mode = ssl.CERT_REQUIRED

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'FALSE\n')
                s.write(b'PHA')
                self.assertEqual(s.recv(1024), b'OK\n')
                s.write(b'HASCERT')
                self.assertEqual(s.recv(1024), b'TRUE\n')
                # server cert has not been validated
                self.assertEqual(s.getpeercert(), {})
4652
4653
# True when this ssl build exposes SSLContext.keylog_filename.
HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
# Decorator: skip tests that need keylog support when it is unavailable.
requires_keylog = unittest.skipUnless(
    HAS_KEYLOG,
    'test requires OpenSSL 1.1.1 with keylog callback',
)
4657
class TestSSLDebug(unittest.TestCase):
    # Tests for SSLContext debugging hooks: keylog_filename (TLS key
    # logging, including the SSLKEYLOGFILE environment variable) and the
    # private _msg_callback protocol-message hook.

    def keylog_lines(self, fname=support.TESTFN):
        # Return the number of lines currently in the keylog file *fname*.
        # Counted lazily so the file is never materialized as a list.
        with open(fname) as f:
            return sum(1 for _ in f)

    @requires_keylog
    def test_keylog_defaults(self):
        self.addCleanup(support.unlink, support.TESTFN)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        self.assertEqual(ctx.keylog_filename, None)

        self.assertFalse(os.path.isfile(support.TESTFN))
        ctx.keylog_filename = support.TESTFN
        self.assertEqual(ctx.keylog_filename, support.TESTFN)
        # Setting the attribute creates the file with a single header line.
        self.assertTrue(os.path.isfile(support.TESTFN))
        self.assertEqual(self.keylog_lines(), 1)

        ctx.keylog_filename = None
        self.assertEqual(ctx.keylog_filename, None)

        with self.assertRaises((IsADirectoryError, PermissionError)):
            # Windows raises PermissionError
            ctx.keylog_filename = os.path.dirname(
                os.path.abspath(support.TESTFN))

        with self.assertRaises(TypeError):
            ctx.keylog_filename = 1

    @requires_keylog
    def test_keylog_filename(self):
        # Keys are appended to the same file whether the client, the
        # server, or both log to it.
        self.addCleanup(support.unlink, support.TESTFN)
        client_context, server_context, hostname = testing_context()

        client_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        # header, 5 lines for TLS 1.3
        self.assertEqual(self.keylog_lines(), 6)

        client_context.keylog_filename = None
        server_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        self.assertGreaterEqual(self.keylog_lines(), 11)

        client_context.keylog_filename = support.TESTFN
        server_context.keylog_filename = support.TESTFN
        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
        self.assertGreaterEqual(self.keylog_lines(), 21)

        client_context.keylog_filename = None
        server_context.keylog_filename = None

    @requires_keylog
    @unittest.skipIf(sys.flags.ignore_environment,
                     "test is not compatible with ignore_environment")
    def test_keylog_env(self):
        # SSLKEYLOGFILE is honoured by the default-context factories but
        # not by a plain SSLContext().
        self.addCleanup(support.unlink, support.TESTFN)
        with unittest.mock.patch.dict(os.environ):
            os.environ['SSLKEYLOGFILE'] = support.TESTFN
            self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)

            ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            self.assertEqual(ctx.keylog_filename, None)

            ctx = ssl.create_default_context()
            self.assertEqual(ctx.keylog_filename, support.TESTFN)

            ctx = ssl._create_stdlib_context()
            self.assertEqual(ctx.keylog_filename, support.TESTFN)

    def test_msg_callback(self):
        # _msg_callback round-trips a callable and rejects non-callables.
        client_context, server_context, hostname = testing_context()

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        self.assertIs(client_context._msg_callback, None)
        client_context._msg_callback = msg_cb
        self.assertIs(client_context._msg_callback, msg_cb)
        with self.assertRaises(TypeError):
            client_context._msg_callback = object()

    def test_msg_callback_tls12(self):
        # Record the messages seen by the client during a TLS 1.2
        # handshake and check that expected ones were reported.
        client_context, server_context, hostname = testing_context()
        client_context.options |= ssl.OP_NO_TLSv1_3

        msg = []

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            self.assertIsInstance(conn, ssl.SSLSocket)
            self.assertIsInstance(data, bytes)
            self.assertIn(direction, {'read', 'write'})
            msg.append((direction, version, content_type, msg_type))

        client_context._msg_callback = msg_cb

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))

        self.assertIn(
            ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
             _TLSMessageType.SERVER_KEY_EXCHANGE),
            msg
        )
        self.assertIn(
            ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
             _TLSMessageType.CHANGE_CIPHER_SPEC),
            msg
        )

    def test_msg_callback_deadlock_bpo43577(self):
        # Regression test for bpo-43577 (deadlock): a server with both a
        # message callback and an SNI callback that switches the socket to
        # another context must handle two consecutive connections.
        client_context, server_context, hostname = testing_context()
        server_context2 = testing_context()[1]

        def msg_cb(conn, direction, version, content_type, msg_type, data):
            pass

        def sni_cb(sock, servername, ctx):
            sock.context = server_context2

        server_context._msg_callback = msg_cb
        server_context.sni_callback = sni_cb

        server = ThreadedEchoServer(context=server_context, chatty=False)
        with server:
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
            with client_context.wrap_socket(socket.socket(),
                                            server_hostname=hostname) as s:
                s.connect((HOST, server.port))
4804
4805
def test_main(verbose=False):
    """Run the SSL test suite through support.run_unittest.

    When support.verbose is set, first print the SSL library version, the
    platform, HAS_SNI and some option flags.  Raise support.TestFailed if
    any required certificate/key data file is missing.  NetworkedTests is
    added only when the 'network' resource is enabled.  Thread state is
    captured before the run and cleaned up afterwards, even on failure.

    *verbose* is accepted for backward compatibility but is not consulted;
    output verbosity is controlled by support.verbose.
    """
    if support.verbose:
        plats = {
            'Mac': platform.mac_ver,
            'Windows': platform.win32_ver,
        }
        for name, func in plats.items():
            plat = func()
            # Per the platform docs, undeterminable entries are empty
            # strings, so an empty first item means "not this platform".
            if plat and plat[0]:
                plat = '%s %r' % (name, plat)
                break
        else:
            plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
            (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        # Zero-pad the hex so small flag values print as 0x00000004
        # rather than space-padded "0x       4".
        print("          OP_ALL = 0x%08x" % ssl.OP_ALL)
        try:
            print("          OP_NO_TLSv1_1 = 0x%08x" % ssl.OP_NO_TLSv1_1)
        except AttributeError:
            # The constant may be absent in some ssl builds.
            pass

    # Fail early with a clear message if a required data file is missing.
    for filename in [
        CERTFILE, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT]:
        if not os.path.exists(filename):
            raise support.TestFailed("Can't read certificate file %r" % filename)

    tests = [
        ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
        SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
        TestPostHandshakeAuth, TestSSLDebug
    ]

    if support.is_resource_enabled('network'):
        tests.append(NetworkedTests)

    thread_info = support.threading_setup()
    try:
        support.run_unittest(*tests)
    finally:
        support.threading_cleanup(*thread_info)
4851
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main(verbose=False)
4854