1# -*- coding: utf-8 -*-
2# Test the support for SSL and sockets
3
4import sys
5import unittest
6from test import test_support as support
7from test.script_helper import assert_python_ok
8import asyncore
9import socket
10import select
11import time
12import datetime
13import gc
14import os
15import errno
16import pprint
17import shutil
18import urllib2
19import traceback
20import weakref
21import platform
22import re
23import functools
24from contextlib import closing
25
# Import ssl through the test support helper rather than a plain import
# (presumably so a build without ssl skips this module; see support docs).
ssl = support.import_module("ssl")

# Sorted protocol identifiers taken from ssl._PROTOCOL_NAMES.
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
# LibreSSL is recognised by its version banner.
IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
# True only for genuine OpenSSL >= 1.1.0; always False under LibreSSL,
# whatever version tuple it reports.
IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
32
33
def data_file(*name):
    """Return the path of a test data file that lives beside this module.

    The components in *name* are joined onto this module's directory, so
    ``data_file("capath", "x.0")`` names ``<dir>/capath/x.0``.
    """
    here = os.path.dirname(__file__)
    return os.path.join(here, *name)
36
37# The custom key and certificate files used in test_ssl are generated
38# using Lib/test/make_ssl_certs.py.
39# Other certificates are simply fetched from the Internet servers they
40# are meant to authenticate.
41
# Test data files, resolved relative to this module's directory.  The
# BYTES_* names hold the same paths encoded with the filesystem encoding
# (presumably for exercising APIs with bytes paths; confirm with their users).
CERTFILE = data_file("keycert.pem")
BYTES_CERTFILE = CERTFILE.encode(sys.getfilesystemencoding())
ONLYCERT = data_file("ssl_cert.pem")
ONLYKEY = data_file("ssl_key.pem")
BYTES_ONLYCERT = ONLYCERT.encode(sys.getfilesystemencoding())
BYTES_ONLYKEY = ONLYKEY.encode(sys.getfilesystemencoding())
CERTFILE_PROTECTED = data_file("keycert.passwd.pem")
ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem")
# Presumably the passphrase of the *.passwd.pem files above.
KEY_PASSWORD = "somepass"
# Directory of CA certificates; the two CAFILE_* files live inside it.
CAPATH = data_file("capath")
BYTES_CAPATH = CAPATH.encode(sys.getfilesystemencoding())
CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
CAFILE_CACERT = data_file("capath", "5ed36f99.0")


# empty CRL
CRLFILE = data_file("revocation.crl")

# Two keys and certs signed by the same CA (for SNI tests)
SIGNED_CERTFILE = data_file("keycert3.pem")
SIGNED_CERTFILE2 = data_file("keycert4.pem")
SIGNING_CA = data_file("pycacert.pem")
# cert with all kinds of subject alt names
ALLSANFILE = data_file("allsans.pem")

# Network-test host and its self-signed root certificate (per the file name).
REMOTE_HOST = "self-signed.pythontest.net"
REMOTE_ROOT_CERT = data_file("selfsigned_pythontestdotnet.pem")

# Deliberately empty, malformed, missing or otherwise unusual cert/key files.
EMPTYCERT = data_file("nullcert.pem")
BADCERT = data_file("badcert.pem")
NONEXISTINGCERT = data_file("XXXnonexisting.pem")
BADKEY = data_file("badkey.pem")
NOKIACERT = data_file("nokia.pem")
NULLBYTECERT = data_file("nullbytecert.pem")
TALOS_INVALID_CRLDP = data_file("talos-2019-0758.pem")

# Diffie-Hellman parameters (3072-bit FFDH, per the file name).
DHFILE = data_file("ffdh3072.pem")
BYTES_DHFILE = DHFILE.encode(sys.getfilesystemencoding())

# Not defined in all versions of OpenSSL; fall back to 0 when this ssl
# build lacks the constant.
OP_NO_COMPRESSION = getattr(ssl, "OP_NO_COMPRESSION", 0)
OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
87
88
def handle_error(prefix):
    """Echo the exception currently being handled to stdout, after *prefix*.

    Nothing is written unless support.verbose is set.  The formatted
    traceback entries are joined with a single space, so each entry after
    the first is prefixed by one space.
    """
    entries = traceback.format_exception(*sys.exc_info())
    exc_format = ' '.join(entries)
    if not support.verbose:
        return
    sys.stdout.write(prefix + exc_format)
93
94
class BasicTests(unittest.TestCase):

    def test_sslwrap_simple(self):
        # A crude test for the legacy API: sslwrap_simple() must accept both
        # a socket.socket object and the raw _socket object it wraps.
        def check(sock):
            try:
                ssl.sslwrap_simple(sock)
            except IOError as e:
                # A broken pipe from ssl_sock.do_handshake() is irrelevant
                # to this test; any other I/O error is a genuine failure.
                if e.errno != errno.EPIPE:
                    raise

        # Close each socket afterwards instead of leaking it.
        with closing(socket.socket(socket.AF_INET)) as sock:
            check(sock)
        with closing(socket.socket(socket.AF_INET)._sock) as raw_sock:
            check(raw_sock)
113
114
def can_clear_options():
    """Return True when the OpenSSL API version allows clearing options."""
    # 0.9.8m or higher
    minimum = (0, 9, 8, 13, 15)
    return ssl._OPENSSL_API_VERSION >= minimum
118
def no_sslv2_implies_sslv3_hello():
    """Return True when the linked OpenSSL is at least version 0.9.7h."""
    # 0.9.7h or higher
    minimum = (0, 9, 7, 8, 15)
    return ssl.OPENSSL_VERSION_INFO >= minimum
122
def have_verify_flags():
    """Return True when the linked OpenSSL is at least version 0.9.8."""
    # 0.9.8 or higher
    minimum = (0, 9, 8, 0, 15)
    return ssl.OPENSSL_VERSION_INFO >= minimum
126
def utc_offset(): #NOTE: ignore issues like #1647654
    """Return the local offset from UTC in seconds (local = UTC + offset).

    Uses time.altzone while daylight saving time is in effect, otherwise
    time.timezone; both are negated because they count seconds *west* of
    UTC.
    """
    dst_in_effect = time.daylight and time.localtime().tm_isdst > 0
    return -(time.altzone if dst_in_effect else time.timezone)
132
def asn1time(cert_time):
    """Return *cert_time* as the linked OpenSSL is expected to render it.

    OpenSSL 0.9.8i ignores the seconds field (issue #18207).  For that
    version the seconds are zeroed and a leading zero in the day of month
    is replaced by a space, as ASN1_TIME_print() does.  For every other
    version the string is returned unchanged.
    """
    # Some versions of OpenSSL ignore seconds, see #18207
    # 0.9.8.i
    if ssl._OPENSSL_API_VERSION != (0, 9, 8, 9, 15):
        return cert_time

    fmt = "%b %d %H:%M:%S %Y GMT"
    truncated = datetime.datetime.strptime(cert_time, fmt).replace(second=0)
    cert_time = truncated.strftime(fmt)
    # %d adds leading zero but ASN1_TIME_print() uses leading space
    if cert_time[4] == "0":
        cert_time = "%s %s" % (cert_time[:4], cert_time[5:])
    return cert_time
146
147# Issue #9415: Ubuntu hijacks their OpenSSL and forcefully disables SSLv2
def skip_if_broken_ubuntu_ssl(func):
    """Decorator skipping *func* on the patched Debian/Ubuntu OpenSSL 0.9.8o.

    When the ssl module has no PROTOCOL_SSLv2, *func* is returned as-is.
    Otherwise the wrapper tries to build an SSLv2 context.  If that fails
    and the platform is the known 0.9.8o 'debian squeeze/sid' combination,
    the test is skipped.
    """
    if not hasattr(ssl, 'PROTOCOL_SSLv2'):
        return func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            ssl.SSLContext(ssl.PROTOCOL_SSLv2)
        except ssl.SSLError:
            patched_distro = (
                ssl.OPENSSL_VERSION_INFO == (0, 9, 8, 15, 15) and
                platform.linux_distribution() == ('debian', 'squeeze/sid', ''))
            if patched_distro:
                raise unittest.SkipTest("Patched Ubuntu OpenSSL breaks behaviour")
        return func(*args, **kwargs)
    return wrapper
162
def skip_if_openssl_cnf_minprotocol_gt_tls1(func):
    """Decorator: skip *func* if the OpenSSL config's MinProtocol is > TLSv1.

    The config file is $OPENSSL_CONF, defaulting to /etc/ssl/openssl.cnf.
    Distributions that ship such a file with MinProtocol set (Debian Buster,
    for example) often require TLSv1.2 or newer.  Under such a config, some
    tests of older protocol versions fail.  A config that cannot be read
    does not cause a skip.  Alternative workaround: run the test process
    with OPENSSL_CONF=/dev/null in the environment.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        openssl_cnf = os.environ.get("OPENSSL_CONF", "/etc/ssl/openssl.cnf")
        try:
            with open(openssl_cnf, "r") as cnf_file:
                for line in cnf_file:
                    found = re.match(r"MinProtocol\s*=\s*(TLSv\d+\S*)", line)
                    if found is None:
                        continue
                    min_protocol = found.group(1)
                    # Plain string ordering: "TLSv1.1", "TLSv1.2", ... all
                    # sort after "TLSv1".
                    if min_protocol <= "TLSv1":
                        continue
                    raise unittest.SkipTest(
                        "%s has MinProtocol = %s which is > TLSv1." %
                        (openssl_cnf, min_protocol))
        except (EnvironmentError, UnicodeDecodeError) as err:
            # no config file found, etc.
            if support.verbose:
                sys.stdout.write("\n Could not scan %s for MinProtocol: %s\n"
                                 % (openssl_cnf, err))
        return func(*args, **kwargs)
    return wrapper
191
192
# Test decorator: skip unless this ssl build reports SNI support (ssl.HAS_SNI).
needs_sni = unittest.skipUnless(ssl.HAS_SNI, "SNI support needed for this test")
194
195
196class BasicSocketTests(unittest.TestCase):
197
198    def test_constants(self):
199        ssl.CERT_NONE
200        ssl.CERT_OPTIONAL
201        ssl.CERT_REQUIRED
202        ssl.OP_CIPHER_SERVER_PREFERENCE
203        ssl.OP_SINGLE_DH_USE
204        if ssl.HAS_ECDH:
205            ssl.OP_SINGLE_ECDH_USE
206        if ssl.OPENSSL_VERSION_INFO >= (1, 0):
207            ssl.OP_NO_COMPRESSION
208        self.assertIn(ssl.HAS_SNI, {True, False})
209        self.assertIn(ssl.HAS_ECDH, {True, False})
210        ssl.OP_NO_SSLv2
211        ssl.OP_NO_SSLv3
212        ssl.OP_NO_TLSv1
213        ssl.OP_NO_TLSv1_3
214        if ssl.OPENSSL_VERSION_INFO >= (1, 0, 1):
215            ssl.OP_NO_TLSv1_1
216            ssl.OP_NO_TLSv1_2
217
218    def test_random(self):
219        v = ssl.RAND_status()
220        if support.verbose:
221            sys.stdout.write("\n RAND_status is %d (%s)\n"
222                             % (v, (v and "sufficient randomness") or
223                                "insufficient randomness"))
224        if hasattr(ssl, 'RAND_egd'):
225            self.assertRaises(TypeError, ssl.RAND_egd, 1)
226            self.assertRaises(TypeError, ssl.RAND_egd, 'foo', 1)
227        ssl.RAND_add("this is a random string", 75.0)
228
229    def test_parse_cert(self):
230        # note that this uses an 'unofficial' function in _ssl.c,
231        # provided solely for this test, to exercise the certificate
232        # parsing code
233        p = ssl._ssl._test_decode_cert(CERTFILE)
234        if support.verbose:
235            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
236        self.assertEqual(p['issuer'],
237                         ((('countryName', 'XY'),),
238                          (('localityName', 'Castle Anthrax'),),
239                          (('organizationName', 'Python Software Foundation'),),
240                          (('commonName', 'localhost'),))
241                        )
242        # Note the next three asserts will fail if the keys are regenerated
243        self.assertEqual(p['notAfter'], asn1time('Aug 26 14:23:15 2028 GMT'))
244        self.assertEqual(p['notBefore'], asn1time('Aug 29 14:23:15 2018 GMT'))
245        self.assertEqual(p['serialNumber'], '98A7CF88C74A32ED')
246        self.assertEqual(p['subject'],
247                         ((('countryName', 'XY'),),
248                          (('localityName', 'Castle Anthrax'),),
249                          (('organizationName', 'Python Software Foundation'),),
250                          (('commonName', 'localhost'),))
251                        )
252        self.assertEqual(p['subjectAltName'], (('DNS', 'localhost'),))
253        # Issue #13034: the subjectAltName in some certificates
254        # (notably projects.developer.nokia.com:443) wasn't parsed
255        p = ssl._ssl._test_decode_cert(NOKIACERT)
256        if support.verbose:
257            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
258        self.assertEqual(p['subjectAltName'],
259                         (('DNS', 'projects.developer.nokia.com'),
260                          ('DNS', 'projects.forum.nokia.com'))
261                        )
262        # extra OCSP and AIA fields
263        self.assertEqual(p['OCSP'], ('http://ocsp.verisign.com',))
264        self.assertEqual(p['caIssuers'],
265                         ('http://SVRIntl-G3-aia.verisign.com/SVRIntlG3.cer',))
266        self.assertEqual(p['crlDistributionPoints'],
267                         ('http://SVRIntl-G3-crl.verisign.com/SVRIntlG3.crl',))
268
269    def test_parse_cert_CVE_2019_5010(self):
270        p = ssl._ssl._test_decode_cert(TALOS_INVALID_CRLDP)
271        if support.verbose:
272            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
273        self.assertEqual(
274            p,
275            {
276                'issuer': (
277                    (('countryName', 'UK'),), (('commonName', 'cody-ca'),)),
278                'notAfter': 'Jun 14 18:00:58 2028 GMT',
279                'notBefore': 'Jun 18 18:00:58 2018 GMT',
280                'serialNumber': '02',
281                'subject': ((('countryName', 'UK'),),
282                            (('commonName',
283                              'codenomicon-vm-2.test.lal.cisco.com'),)),
284                'subjectAltName': (
285                    ('DNS', 'codenomicon-vm-2.test.lal.cisco.com'),),
286                'version': 3
287            }
288        )
289
290    def test_parse_cert_CVE_2013_4238(self):
291        p = ssl._ssl._test_decode_cert(NULLBYTECERT)
292        if support.verbose:
293            sys.stdout.write("\n" + pprint.pformat(p) + "\n")
294        subject = ((('countryName', 'US'),),
295                   (('stateOrProvinceName', 'Oregon'),),
296                   (('localityName', 'Beaverton'),),
297                   (('organizationName', 'Python Software Foundation'),),
298                   (('organizationalUnitName', 'Python Core Development'),),
299                   (('commonName', 'null.python.org\x00example.org'),),
300                   (('emailAddress', 'python-dev@python.org'),))
301        self.assertEqual(p['subject'], subject)
302        self.assertEqual(p['issuer'], subject)
303        if ssl._OPENSSL_API_VERSION >= (0, 9, 8):
304            san = (('DNS', 'altnull.python.org\x00example.com'),
305                   ('email', 'null@python.org\x00user@example.org'),
306                   ('URI', 'http://null.python.org\x00http://example.org'),
307                   ('IP Address', '192.0.2.1'),
308                   ('IP Address', '2001:DB8:0:0:0:0:0:1\n'))
309        else:
310            # OpenSSL 0.9.7 doesn't support IPv6 addresses in subjectAltName
311            san = (('DNS', 'altnull.python.org\x00example.com'),
312                   ('email', 'null@python.org\x00user@example.org'),
313                   ('URI', 'http://null.python.org\x00http://example.org'),
314                   ('IP Address', '192.0.2.1'),
315                   ('IP Address', '<invalid>'))
316
317        self.assertEqual(p['subjectAltName'], san)
318
319    def test_parse_all_sans(self):
320        p = ssl._ssl._test_decode_cert(ALLSANFILE)
321        self.assertEqual(p['subjectAltName'],
322            (
323                ('DNS', 'allsans'),
324                ('othername', '<unsupported>'),
325                ('othername', '<unsupported>'),
326                ('email', 'user@example.org'),
327                ('DNS', 'www.example.org'),
328                ('DirName',
329                    ((('countryName', 'XY'),),
330                    (('localityName', 'Castle Anthrax'),),
331                    (('organizationName', 'Python Software Foundation'),),
332                    (('commonName', 'dirname example'),))),
333                ('URI', 'https://www.python.org/'),
334                ('IP Address', '127.0.0.1'),
335                ('IP Address', '0:0:0:0:0:0:0:1\n'),
336                ('Registered ID', '1.2.3.4.5')
337            )
338        )
339
340    def test_DER_to_PEM(self):
341        with open(CAFILE_CACERT, 'r') as f:
342            pem = f.read()
343        d1 = ssl.PEM_cert_to_DER_cert(pem)
344        p2 = ssl.DER_cert_to_PEM_cert(d1)
345        d2 = ssl.PEM_cert_to_DER_cert(p2)
346        self.assertEqual(d1, d2)
347        if not p2.startswith(ssl.PEM_HEADER + '\n'):
348            self.fail("DER-to-PEM didn't include correct header:\n%r\n" % p2)
349        if not p2.endswith('\n' + ssl.PEM_FOOTER + '\n'):
350            self.fail("DER-to-PEM didn't include correct footer:\n%r\n" % p2)
351
352    def test_openssl_version(self):
353        n = ssl.OPENSSL_VERSION_NUMBER
354        t = ssl.OPENSSL_VERSION_INFO
355        s = ssl.OPENSSL_VERSION
356        self.assertIsInstance(n, (int, long))
357        self.assertIsInstance(t, tuple)
358        self.assertIsInstance(s, str)
359        # Some sanity checks follow
360        # >= 0.9
361        self.assertGreaterEqual(n, 0x900000)
362        # < 3.0
363        self.assertLess(n, 0x30000000)
364        major, minor, fix, patch, status = t
365        self.assertGreaterEqual(major, 0)
366        self.assertLess(major, 3)
367        self.assertGreaterEqual(minor, 0)
368        self.assertLess(minor, 256)
369        self.assertGreaterEqual(fix, 0)
370        self.assertLess(fix, 256)
371        self.assertGreaterEqual(patch, 0)
372        self.assertLessEqual(patch, 63)
373        self.assertGreaterEqual(status, 0)
374        self.assertLessEqual(status, 15)
375        # Version string as returned by {Open,Libre}SSL, the format might change
376        if IS_LIBRESSL:
377            self.assertTrue(s.startswith("LibreSSL {:d}".format(major)),
378                            (s, t, hex(n)))
379        else:
380            self.assertTrue(s.startswith("OpenSSL {:d}.{:d}.{:d}".format(major, minor, fix)),
381                            (s, t))
382
383    @support.cpython_only
384    def test_refcycle(self):
385        # Issue #7943: an SSL object doesn't create reference cycles with
386        # itself.
387        s = socket.socket(socket.AF_INET)
388        ss = ssl.wrap_socket(s)
389        wr = weakref.ref(ss)
390        del ss
391        self.assertEqual(wr(), None)
392
393    def test_wrapped_unconnected(self):
394        # Methods on an unconnected SSLSocket propagate the original
395        # socket.error raise by the underlying socket object.
396        s = socket.socket(socket.AF_INET)
397        with closing(ssl.wrap_socket(s)) as ss:
398            self.assertRaises(socket.error, ss.recv, 1)
399            self.assertRaises(socket.error, ss.recv_into, bytearray(b'x'))
400            self.assertRaises(socket.error, ss.recvfrom, 1)
401            self.assertRaises(socket.error, ss.recvfrom_into, bytearray(b'x'), 1)
402            self.assertRaises(socket.error, ss.send, b'x')
403            self.assertRaises(socket.error, ss.sendto, b'x', ('0.0.0.0', 0))
404            self.assertRaises(NotImplementedError, ss.dup)
405
406    def test_timeout(self):
407        # Issue #8524: when creating an SSL socket, the timeout of the
408        # original socket should be retained.
409        for timeout in (None, 0.0, 5.0):
410            s = socket.socket(socket.AF_INET)
411            s.settimeout(timeout)
412            with closing(ssl.wrap_socket(s)) as ss:
413                self.assertEqual(timeout, ss.gettimeout())
414
415    def test_errors(self):
416        sock = socket.socket()
417        self.assertRaisesRegexp(ValueError,
418                        "certfile must be specified",
419                        ssl.wrap_socket, sock, keyfile=CERTFILE)
420        self.assertRaisesRegexp(ValueError,
421                        "certfile must be specified for server-side operations",
422                        ssl.wrap_socket, sock, server_side=True)
423        self.assertRaisesRegexp(ValueError,
424                        "certfile must be specified for server-side operations",
425                        ssl.wrap_socket, sock, server_side=True, certfile="")
426        with closing(ssl.wrap_socket(sock, server_side=True, certfile=CERTFILE)) as s:
427            self.assertRaisesRegexp(ValueError, "can't connect in server-side mode",
428                                    s.connect, (HOST, 8080))
429        with self.assertRaises(IOError) as cm:
430            with closing(socket.socket()) as sock:
431                ssl.wrap_socket(sock, certfile=NONEXISTINGCERT)
432        self.assertEqual(cm.exception.errno, errno.ENOENT)
433        with self.assertRaises(IOError) as cm:
434            with closing(socket.socket()) as sock:
435                ssl.wrap_socket(sock,
436                    certfile=CERTFILE, keyfile=NONEXISTINGCERT)
437        self.assertEqual(cm.exception.errno, errno.ENOENT)
438        with self.assertRaises(IOError) as cm:
439            with closing(socket.socket()) as sock:
440                ssl.wrap_socket(sock,
441                    certfile=NONEXISTINGCERT, keyfile=NONEXISTINGCERT)
442        self.assertEqual(cm.exception.errno, errno.ENOENT)
443
444    def bad_cert_test(self, certfile):
445        """Check that trying to use the given client certificate fails"""
446        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
447                                   certfile)
448        sock = socket.socket()
449        self.addCleanup(sock.close)
450        with self.assertRaises(ssl.SSLError):
451            ssl.wrap_socket(sock,
452                            certfile=certfile,
453                            ssl_version=ssl.PROTOCOL_TLSv1)
454
455    def test_empty_cert(self):
456        """Wrapping with an empty cert file"""
457        self.bad_cert_test("nullcert.pem")
458
459    def test_malformed_cert(self):
460        """Wrapping with a badly formatted certificate (syntax error)"""
461        self.bad_cert_test("badcert.pem")
462
463    def test_malformed_key(self):
464        """Wrapping with a badly formatted key (syntax error)"""
465        self.bad_cert_test("badkey.pem")
466
    def test_match_hostname(self):
        # ok() expects ssl.match_hostname() to accept *hostname* for the
        # decoded-certificate dict *cert* (any exception errors the test);
        # fail() expects it to raise ssl.CertificateError.
        def ok(cert, hostname):
            ssl.match_hostname(cert, hostname)
        def fail(cert, hostname):
            self.assertRaises(ssl.CertificateError,
                              ssl.match_hostname, cert, hostname)

        # Plain commonName: case-insensitive, whole-name match only.
        cert = {'subject': ((('commonName', 'example.com'),),)}
        ok(cert, 'example.com')
        ok(cert, 'ExAmple.cOm')
        fail(cert, 'www.example.com')
        fail(cert, '.example.com')
        fail(cert, 'example.org')
        fail(cert, 'exampleXcom')

        # A left-most '*' label stands for exactly one non-empty label.
        cert = {'subject': ((('commonName', '*.a.com'),),)}
        ok(cert, 'foo.a.com')
        fail(cert, 'bar.foo.a.com')
        fail(cert, 'a.com')
        fail(cert, 'Xa.com')
        fail(cert, '.a.com')

        # only match one left-most wildcard
        cert = {'subject': ((('commonName', 'f*.com'),),)}
        ok(cert, 'foo.com')
        ok(cert, 'f.com')
        fail(cert, 'bar.com')
        fail(cert, 'foo.a.com')
        fail(cert, 'bar.foo.com')

        # NULL bytes are bad, CVE-2013-4073
        cert = {'subject': ((('commonName',
                              'null.python.org\x00example.org'),),)}
        ok(cert, 'null.python.org\x00example.org') # or raise an error?
        fail(cert, 'example.org')
        fail(cert, 'null.python.org')

        # error cases with wildcards
        cert = {'subject': ((('commonName', '*.*.a.com'),),)}
        fail(cert, 'bar.foo.a.com')
        fail(cert, 'a.com')
        fail(cert, 'Xa.com')
        fail(cert, '.a.com')

        cert = {'subject': ((('commonName', 'a.*.com'),),)}
        fail(cert, 'a.foo.com')
        fail(cert, 'a..com')
        fail(cert, 'a.com')

        # wildcard doesn't match IDNA prefix 'xn--'
        idna = u'püthon.python.org'.encode("idna").decode("ascii")
        cert = {'subject': ((('commonName', idna),),)}
        ok(cert, idna)
        cert = {'subject': ((('commonName', 'x*.python.org'),),)}
        fail(cert, idna)
        cert = {'subject': ((('commonName', 'xn--p*.python.org'),),)}
        fail(cert, idna)

        # A wildcard in the first fragment combined with IDNA A-labels in
        # subsequent fragments is supported.
        idna = u'www*.pythön.org'.encode("idna").decode("ascii")
        cert = {'subject': ((('commonName', idna),),)}
        ok(cert, u'www.pythön.org'.encode("idna").decode("ascii"))
        ok(cert, u'www1.pythön.org'.encode("idna").decode("ascii"))
        fail(cert, u'ftp.pythön.org'.encode("idna").decode("ascii"))
        fail(cert, u'pythön.org'.encode("idna").decode("ascii"))

        # Slightly fake real-world example
        cert = {'notAfter': 'Jun 26 21:41:46 2011 GMT',
                'subject': ((('commonName', 'linuxfrz.org'),),),
                'subjectAltName': (('DNS', 'linuxfr.org'),
                                   ('DNS', 'linuxfr.com'),
                                   ('othername', '<unsupported>'))}
        ok(cert, 'linuxfr.org')
        ok(cert, 'linuxfr.com')
        # Not a "DNS" entry
        fail(cert, '<unsupported>')
        # When there is a subjectAltName, commonName isn't used
        fail(cert, 'linuxfrz.org')

        # A pristine real-world example
        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),),
                            (('commonName', 'mail.google.com'),))}
        ok(cert, 'mail.google.com')
        fail(cert, 'gmail.com')
        # Only commonName is considered
        fail(cert, 'California')

        # Neither commonName nor subjectAltName
        cert = {'notAfter': 'Dec 18 23:59:59 2011 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),))}
        fail(cert, 'mail.google.com')

        # No DNS entry in subjectAltName but a commonName
        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('commonName', 'mail.google.com'),)),
                'subjectAltName': (('othername', 'blabla'), )}
        ok(cert, 'mail.google.com')

        # No DNS entry subjectAltName and no commonName
        cert = {'notAfter': 'Dec 18 23:59:59 2099 GMT',
                'subject': ((('countryName', 'US'),),
                            (('stateOrProvinceName', 'California'),),
                            (('localityName', 'Mountain View'),),
                            (('organizationName', 'Google Inc'),)),
                'subjectAltName': (('othername', 'blabla'),)}
        fail(cert, 'google.com')

        # Empty cert / no cert
        self.assertRaises(ValueError, ssl.match_hostname, None, 'example.com')
        self.assertRaises(ValueError, ssl.match_hostname, {}, 'example.com')

        # Issue #17980: avoid denials of service by refusing more than one
        # wildcard per fragment.
        cert = {'subject': ((('commonName', 'a*b.com'),),)}
        ok(cert, 'axxb.com')
        cert = {'subject': ((('commonName', 'a*b.co*'),),)}
        fail(cert, 'axxb.com')
        cert = {'subject': ((('commonName', 'a*b*.com'),),)}
        with self.assertRaises(ssl.CertificateError) as cm:
            ssl.match_hostname(cert, 'axxbxxc.com')
        self.assertIn("too many wildcards", str(cm.exception))
599
600    def test_server_side(self):
601        # server_hostname doesn't work for server sockets
602        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
603        with closing(socket.socket()) as sock:
604            self.assertRaises(ValueError, ctx.wrap_socket, sock, True,
605                              server_hostname="some.hostname")
606
607    def test_unknown_channel_binding(self):
608        # should raise ValueError for unknown type
609        s = socket.socket(socket.AF_INET)
610        with closing(ssl.wrap_socket(s)) as ss:
611            with self.assertRaises(ValueError):
612                ss.get_channel_binding("unknown-type")
613
614    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
615                         "'tls-unique' channel binding not available")
616    def test_tls_unique_channel_binding(self):
617        # unconnected should return None for known type
618        s = socket.socket(socket.AF_INET)
619        with closing(ssl.wrap_socket(s)) as ss:
620            self.assertIsNone(ss.get_channel_binding("tls-unique"))
621        # the same for server-side
622        s = socket.socket(socket.AF_INET)
623        with closing(ssl.wrap_socket(s, server_side=True, certfile=CERTFILE)) as ss:
624            self.assertIsNone(ss.get_channel_binding("tls-unique"))
625
626    def test_get_default_verify_paths(self):
627        paths = ssl.get_default_verify_paths()
628        self.assertEqual(len(paths), 6)
629        self.assertIsInstance(paths, ssl.DefaultVerifyPaths)
630
631        with support.EnvironmentVarGuard() as env:
632            env["SSL_CERT_DIR"] = CAPATH
633            env["SSL_CERT_FILE"] = CERTFILE
634            paths = ssl.get_default_verify_paths()
635            self.assertEqual(paths.cafile, CERTFILE)
636            self.assertEqual(paths.capath, CAPATH)
637
638    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
639    def test_enum_certificates(self):
640        self.assertTrue(ssl.enum_certificates("CA"))
641        self.assertTrue(ssl.enum_certificates("ROOT"))
642
643        self.assertRaises(TypeError, ssl.enum_certificates)
644        self.assertRaises(WindowsError, ssl.enum_certificates, "")
645
646        trust_oids = set()
647        for storename in ("CA", "ROOT"):
648            store = ssl.enum_certificates(storename)
649            self.assertIsInstance(store, list)
650            for element in store:
651                self.assertIsInstance(element, tuple)
652                self.assertEqual(len(element), 3)
653                cert, enc, trust = element
654                self.assertIsInstance(cert, bytes)
655                self.assertIn(enc, {"x509_asn", "pkcs_7_asn"})
656                self.assertIsInstance(trust, (set, bool))
657                if isinstance(trust, set):
658                    trust_oids.update(trust)
659
660        serverAuth = "1.3.6.1.5.5.7.3.1"
661        self.assertIn(serverAuth, trust_oids)
662
663    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
664    def test_enum_crls(self):
665        self.assertTrue(ssl.enum_crls("CA"))
666        self.assertRaises(TypeError, ssl.enum_crls)
667        self.assertRaises(WindowsError, ssl.enum_crls, "")
668
669        crls = ssl.enum_crls("CA")
670        self.assertIsInstance(crls, list)
671        for element in crls:
672            self.assertIsInstance(element, tuple)
673            self.assertEqual(len(element), 2)
674            self.assertIsInstance(element[0], bytes)
675            self.assertIn(element[1], {"x509_asn", "pkcs_7_asn"})
676
677
678    def test_asn1object(self):
679        expected = (129, 'serverAuth', 'TLS Web Server Authentication',
680                    '1.3.6.1.5.5.7.3.1')
681
682        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
683        self.assertEqual(val, expected)
684        self.assertEqual(val.nid, 129)
685        self.assertEqual(val.shortname, 'serverAuth')
686        self.assertEqual(val.longname, 'TLS Web Server Authentication')
687        self.assertEqual(val.oid, '1.3.6.1.5.5.7.3.1')
688        self.assertIsInstance(val, ssl._ASN1Object)
689        self.assertRaises(ValueError, ssl._ASN1Object, 'serverAuth')
690
691        val = ssl._ASN1Object.fromnid(129)
692        self.assertEqual(val, expected)
693        self.assertIsInstance(val, ssl._ASN1Object)
694        self.assertRaises(ValueError, ssl._ASN1Object.fromnid, -1)
695        with self.assertRaisesRegexp(ValueError, "unknown NID 100000"):
696            ssl._ASN1Object.fromnid(100000)
697        for i in range(1000):
698            try:
699                obj = ssl._ASN1Object.fromnid(i)
700            except ValueError:
701                pass
702            else:
703                self.assertIsInstance(obj.nid, int)
704                self.assertIsInstance(obj.shortname, str)
705                self.assertIsInstance(obj.longname, str)
706                self.assertIsInstance(obj.oid, (str, type(None)))
707
708        val = ssl._ASN1Object.fromname('TLS Web Server Authentication')
709        self.assertEqual(val, expected)
710        self.assertIsInstance(val, ssl._ASN1Object)
711        self.assertEqual(ssl._ASN1Object.fromname('serverAuth'), expected)
712        self.assertEqual(ssl._ASN1Object.fromname('1.3.6.1.5.5.7.3.1'),
713                         expected)
714        with self.assertRaisesRegexp(ValueError, "unknown object 'serverauth'"):
715            ssl._ASN1Object.fromname('serverauth')
716
717    def test_purpose_enum(self):
718        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.1')
719        self.assertIsInstance(ssl.Purpose.SERVER_AUTH, ssl._ASN1Object)
720        self.assertEqual(ssl.Purpose.SERVER_AUTH, val)
721        self.assertEqual(ssl.Purpose.SERVER_AUTH.nid, 129)
722        self.assertEqual(ssl.Purpose.SERVER_AUTH.shortname, 'serverAuth')
723        self.assertEqual(ssl.Purpose.SERVER_AUTH.oid,
724                              '1.3.6.1.5.5.7.3.1')
725
726        val = ssl._ASN1Object('1.3.6.1.5.5.7.3.2')
727        self.assertIsInstance(ssl.Purpose.CLIENT_AUTH, ssl._ASN1Object)
728        self.assertEqual(ssl.Purpose.CLIENT_AUTH, val)
729        self.assertEqual(ssl.Purpose.CLIENT_AUTH.nid, 130)
730        self.assertEqual(ssl.Purpose.CLIENT_AUTH.shortname, 'clientAuth')
731        self.assertEqual(ssl.Purpose.CLIENT_AUTH.oid,
732                              '1.3.6.1.5.5.7.3.2')
733
734    def test_unsupported_dtls(self):
735        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
736        self.addCleanup(s.close)
737        with self.assertRaises(NotImplementedError) as cx:
738            ssl.wrap_socket(s, cert_reqs=ssl.CERT_NONE)
739        self.assertEqual(str(cx.exception), "only stream sockets are supported")
740        ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
741        with self.assertRaises(NotImplementedError) as cx:
742            ctx.wrap_socket(s)
743        self.assertEqual(str(cx.exception), "only stream sockets are supported")
744
745    def cert_time_ok(self, timestring, timestamp):
746        self.assertEqual(ssl.cert_time_to_seconds(timestring), timestamp)
747
748    def cert_time_fail(self, timestring):
749        with self.assertRaises(ValueError):
750            ssl.cert_time_to_seconds(timestring)
751
752    @unittest.skipUnless(utc_offset(),
753                         'local time needs to be different from UTC')
754    def test_cert_time_to_seconds_timezone(self):
755        # Issue #19940: ssl.cert_time_to_seconds() returns wrong
756        #               results if local timezone is not UTC
757        self.cert_time_ok("May  9 00:00:00 2007 GMT", 1178668800.0)
758        self.cert_time_ok("Jan  5 09:34:43 2018 GMT", 1515144883.0)
759
760    def test_cert_time_to_seconds(self):
761        timestring = "Jan  5 09:34:43 2018 GMT"
762        ts = 1515144883.0
763        self.cert_time_ok(timestring, ts)
764        # accept keyword parameter, assert its name
765        self.assertEqual(ssl.cert_time_to_seconds(cert_time=timestring), ts)
766        # accept both %e and %d (space or zero generated by strftime)
767        self.cert_time_ok("Jan 05 09:34:43 2018 GMT", ts)
768        # case-insensitive
769        self.cert_time_ok("JaN  5 09:34:43 2018 GmT", ts)
770        self.cert_time_fail("Jan  5 09:34 2018 GMT")     # no seconds
771        self.cert_time_fail("Jan  5 09:34:43 2018")      # no GMT
772        self.cert_time_fail("Jan  5 09:34:43 2018 UTC")  # not GMT timezone
773        self.cert_time_fail("Jan 35 09:34:43 2018 GMT")  # invalid day
774        self.cert_time_fail("Jon  5 09:34:43 2018 GMT")  # invalid month
775        self.cert_time_fail("Jan  5 24:00:00 2018 GMT")  # invalid hour
776        self.cert_time_fail("Jan  5 09:60:43 2018 GMT")  # invalid minute
777
778        newyear_ts = 1230768000.0
779        # leap seconds
780        self.cert_time_ok("Dec 31 23:59:60 2008 GMT", newyear_ts)
781        # same timestamp
782        self.cert_time_ok("Jan  1 00:00:00 2009 GMT", newyear_ts)
783
784        self.cert_time_ok("Jan  5 09:34:59 2018 GMT", 1515144899)
785        #  allow 60th second (even if it is not a leap second)
786        self.cert_time_ok("Jan  5 09:34:60 2018 GMT", 1515144900)
787        #  allow 2nd leap second for compatibility with time.strptime()
788        self.cert_time_ok("Jan  5 09:34:61 2018 GMT", 1515144901)
789        self.cert_time_fail("Jan  5 09:34:62 2018 GMT")  # invalid seconds
790
791        # no special treatement for the special value:
792        #   99991231235959Z (rfc 5280)
793        self.cert_time_ok("Dec 31 23:59:59 9999 GMT", 253402300799.0)
794
795    @support.run_with_locale('LC_ALL', '')
796    def test_cert_time_to_seconds_locale(self):
797        # `cert_time_to_seconds()` should be locale independent
798
799        def local_february_name():
800            return time.strftime('%b', (1, 2, 3, 4, 5, 6, 0, 0, 0))
801
802        if local_february_name().lower() == 'feb':
803            self.skipTest("locale-specific month name needs to be "
804                          "different from C locale")
805
806        # locale-independent
807        self.cert_time_ok("Feb  9 00:00:00 2007 GMT", 1170979200.0)
808        self.cert_time_fail(local_february_name() + "  9 00:00:00 2007 GMT")
809
810
class ContextTests(unittest.TestCase):
    """Tests for ssl.SSLContext: construction, options, verification
    settings, and loading of certificates, keys and DH parameters."""

    @skip_if_broken_ubuntu_ssl
    def test_constructor(self):
        # Every protocol in PROTOCOLS must be accepted; a missing or unknown
        # protocol argument must be rejected.
        for protocol in PROTOCOLS:
            ssl.SSLContext(protocol)
        self.assertRaises(TypeError, ssl.SSLContext)
        self.assertRaises(ValueError, ssl.SSLContext, -1)
        self.assertRaises(ValueError, ssl.SSLContext, 42)

    @skip_if_broken_ubuntu_ssl
    def test_protocol(self):
        # ctx.protocol reports the protocol the context was created with.
        for proto in PROTOCOLS:
            ctx = ssl.SSLContext(proto)
            self.assertEqual(ctx.protocol, proto)

    def test_ciphers(self):
        # "ALL" and "DEFAULT" are valid cipher lists; a string that selects
        # no cipher at all raises SSLError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.set_ciphers("ALL")
        ctx.set_ciphers("DEFAULT")
        with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
            ctx.set_ciphers("^$:,;?*'dorothyx")

    @skip_if_broken_ubuntu_ssl
    def test_options(self):
        # Checks the default option bits, OR-ing in an extra bit and, where
        # the OpenSSL build allows it, clearing bits again.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        # OP_ALL | OP_NO_SSLv2 | OP_NO_SSLv3 is the default value
        default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
        # SSLContext also enables these by default.
        # NOTE(review): these bare OP_* names are module-level values defined
        # elsewhere in this file; _assert_context_options below treats them as
        # possibly 0 — presumably when the linked OpenSSL lacks the flag.
        default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
                    OP_ENABLE_MIDDLEBOX_COMPAT)
        self.assertEqual(default, ctx.options)
        ctx.options |= ssl.OP_NO_TLSv1
        self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
        # can_clear_options() is a helper defined elsewhere in this file.
        # When it is false, assigning 0 to ctx.options is expected to raise
        # ValueError.
        if can_clear_options():
            ctx.options = (ctx.options & ~ssl.OP_NO_TLSv1)
            self.assertEqual(default, ctx.options)
            ctx.options = 0
            # Ubuntu has OP_NO_SSLv3 forced on by default
            self.assertEqual(0, ctx.options & ~ssl.OP_NO_SSLv3)
        else:
            with self.assertRaises(ValueError):
                ctx.options = 0

    def test_verify_mode(self):
        # verify_mode defaults to CERT_NONE, round-trips each CERT_* constant
        # and rejects None (TypeError) and unknown integers (ValueError).
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        # Default value
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        ctx.verify_mode = ssl.CERT_OPTIONAL
        self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        ctx.verify_mode = ssl.CERT_NONE
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        with self.assertRaises(TypeError):
            ctx.verify_mode = None
        with self.assertRaises(ValueError):
            ctx.verify_mode = 42

    @unittest.skipUnless(have_verify_flags(),
                         "verify_flags need OpenSSL > 0.9.8")
    def test_verify_flags(self):
        # verify_flags round-trips any combination of VERIFY_* bits and
        # rejects None.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        # default value
        # VERIFY_X509_TRUSTED_FIRST is only present on some builds; when the
        # ssl module does not expose it, getattr() falls back to 0.
        tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT | tf)
        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF
        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_LEAF)
        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_CHAIN
        self.assertEqual(ctx.verify_flags, ssl.VERIFY_CRL_CHECK_CHAIN)
        ctx.verify_flags = ssl.VERIFY_DEFAULT
        self.assertEqual(ctx.verify_flags, ssl.VERIFY_DEFAULT)
        # supports any value
        ctx.verify_flags = ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT
        self.assertEqual(ctx.verify_flags,
                         ssl.VERIFY_CRL_CHECK_LEAF | ssl.VERIFY_X509_STRICT)
        with self.assertRaises(TypeError):
            ctx.verify_flags = None

    def test_load_cert_chain(self):
        # Exercises load_cert_chain() with combined and separate key/cert
        # files, str and bytes paths, invalid inputs, and every password form
        # checked below (str, bytes, bytearray and callables returning those).
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        # Combined key and cert in a single file
        ctx.load_cert_chain(CERTFILE, keyfile=None)
        ctx.load_cert_chain(CERTFILE, keyfile=CERTFILE)
        self.assertRaises(TypeError, ctx.load_cert_chain, keyfile=CERTFILE)
        with self.assertRaises(IOError) as cm:
            ctx.load_cert_chain(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(BADCERT)
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(EMPTYCERT)
        # Separate key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY)
        ctx.load_cert_chain(certfile=ONLYCERT, keyfile=ONLYKEY)
        ctx.load_cert_chain(certfile=BYTES_ONLYCERT, keyfile=BYTES_ONLYKEY)
        # A cert without its key (or a key in place of a cert) is rejected.
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYCERT)
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(ONLYKEY)
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_cert_chain(certfile=ONLYKEY, keyfile=ONLYCERT)
        # Mismatching key and cert
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        with self.assertRaisesRegexp(ssl.SSLError, "key values mismatch"):
            ctx.load_cert_chain(CAFILE_CACERT, ONLYKEY)
        # Password protected key and cert
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=bytearray(KEY_PASSWORD.encode()))
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD)
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode())
        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED,
                            bytearray(KEY_PASSWORD.encode()))
        with self.assertRaisesRegexp(TypeError, "should be a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass")
        with self.assertRaisesRegexp(ValueError, "cannot be longer"):
            # openssl has a fixed limit on the password buffer.
            # PEM_BUFSIZE is generally set to 1kb.
            # Return a string larger than this.
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400)
        # Password callback
        def getpass_unicode():
            return KEY_PASSWORD
        def getpass_bytes():
            return KEY_PASSWORD.encode()
        def getpass_bytearray():
            return bytearray(KEY_PASSWORD.encode())
        def getpass_badpass():
            return "badpass"
        def getpass_huge():
            return b'a' * (1024 * 1024)
        def getpass_bad_type():
            return 9
        def getpass_exception():
            raise Exception('getpass error')
        # Instances and bound methods are accepted as callbacks as well.
        class GetPassCallable:
            def __call__(self):
                return KEY_PASSWORD
            def getpass(self):
                return KEY_PASSWORD
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray)
        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable())
        ctx.load_cert_chain(CERTFILE_PROTECTED,
                            password=GetPassCallable().getpass)
        with self.assertRaises(ssl.SSLError):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass)
        with self.assertRaisesRegexp(ValueError, "cannot be longer"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge)
        with self.assertRaisesRegexp(TypeError, "must return a string"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type)
        # Exceptions raised by the callback propagate to the caller.
        with self.assertRaisesRegexp(Exception, "getpass error"):
            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception)
        # Make sure the password function isn't called if it isn't needed
        ctx.load_cert_chain(CERTFILE, password=getpass_exception)

    def test_load_verify_locations(self):
        # cafile/capath accept str and bytes paths; missing, empty or
        # malformed files raise IOError/SSLError; bad argument types raise
        # TypeError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_verify_locations(CERTFILE)
        ctx.load_verify_locations(cafile=CERTFILE, capath=None)
        ctx.load_verify_locations(BYTES_CERTFILE)
        ctx.load_verify_locations(cafile=BYTES_CERTFILE, capath=None)
        # NOTE(review): BYTES_CERTFILE is encoded with the filesystem
        # encoding but decoded here as UTF-8; this only round-trips when the
        # two agree or the path is ASCII.
        ctx.load_verify_locations(cafile=BYTES_CERTFILE.decode('utf-8'))
        self.assertRaises(TypeError, ctx.load_verify_locations)
        self.assertRaises(TypeError, ctx.load_verify_locations, None, None, None)
        with self.assertRaises(IOError) as cm:
            ctx.load_verify_locations(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        with self.assertRaises(IOError):
            ctx.load_verify_locations(u'')
        with self.assertRaisesRegexp(ssl.SSLError, "PEM lib"):
            ctx.load_verify_locations(BADCERT)
        ctx.load_verify_locations(CERTFILE, CAPATH)
        ctx.load_verify_locations(CERTFILE, capath=BYTES_CAPATH)

        # Issue #10989: crash if the second argument type is invalid
        self.assertRaises(TypeError, ctx.load_verify_locations, None, True)

    def test_load_verify_cadata(self):
        # cadata accepts PEM text (one or several certs, even with junk
        # between them) and DER bytes; loading a certificate that is already
        # in the store does not increase the x509_ca count.
        # test cadata
        with open(CAFILE_CACERT) as f:
            cacert_pem = f.read().decode("ascii")
        cacert_der = ssl.PEM_cert_to_DER_cert(cacert_pem)
        with open(CAFILE_NEURONIO) as f:
            neuronio_pem = f.read().decode("ascii")
        neuronio_der = ssl.PEM_cert_to_DER_cert(neuronio_pem)

        # test PEM
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 0)
        ctx.load_verify_locations(cadata=cacert_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 1)
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=neuronio_pem)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        combined = "\n".join((cacert_pem, neuronio_pem))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # with junk around the certs
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        combined = ["head", cacert_pem, "other", neuronio_pem, "again",
                    neuronio_pem, "tail"]
        ctx.load_verify_locations(cadata="\n".join(combined))
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # test DER
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_verify_locations(cadata=cacert_der)
        ctx.load_verify_locations(cadata=neuronio_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)
        # cert already in hash table
        ctx.load_verify_locations(cadata=cacert_der)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # combined
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        combined = b"".join((cacert_der, neuronio_der))
        ctx.load_verify_locations(cadata=combined)
        self.assertEqual(ctx.cert_store_stats()["x509_ca"], 2)

        # error cases
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertRaises(TypeError, ctx.load_verify_locations, cadata=object)

        # Text cadata is parsed as PEM, bytes cadata as DER, hence the
        # different error messages.
        with self.assertRaisesRegexp(ssl.SSLError, "no start line"):
            ctx.load_verify_locations(cadata=u"broken")
        with self.assertRaisesRegexp(ssl.SSLError, "not enough data"):
            ctx.load_verify_locations(cadata=b"broken")


    def test_load_dh_params(self):
        # Loads DH parameters from str, bytes and non-ASCII paths; missing
        # files raise IOError and non-DH PEM files raise SSLError.
        filename = u'dhpäräm.pem'
        fs_encoding = sys.getfilesystemencoding()
        # The non-ASCII filename is only usable if the filesystem encoding
        # can represent it; otherwise the whole test is skipped.
        try:
            filename.encode(fs_encoding)
        except UnicodeEncodeError:
            self.skipTest("filename %r cannot be encoded to the filesystem encoding %r" % (filename, fs_encoding))

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_dh_params(DHFILE)
        # NOTE(review): the bytes path is not exercised on Windows
        # (presumably bytes paths are unsupported there — confirm).
        if os.name != 'nt':
            ctx.load_dh_params(BYTES_DHFILE)
        self.assertRaises(TypeError, ctx.load_dh_params)
        self.assertRaises(TypeError, ctx.load_dh_params, None)
        with self.assertRaises(IOError) as cm:
            ctx.load_dh_params(NONEXISTINGCERT)
        self.assertEqual(cm.exception.errno, errno.ENOENT)
        with self.assertRaises(ssl.SSLError) as cm:
            ctx.load_dh_params(CERTFILE)
        # Copy the DH file to a non-ASCII name in a temporary directory.
        with support.temp_dir() as d:
            fname = os.path.join(d, filename)
            shutil.copy(DHFILE, fname)
            ctx.load_dh_params(fname)

    @skip_if_broken_ubuntu_ssl
    def test_session_stats(self):
        # A freshly created context reports all session counters as zero.
        for proto in PROTOCOLS:
            ctx = ssl.SSLContext(proto)
            self.assertEqual(ctx.session_stats(), {
                'number': 0,
                'connect': 0,
                'connect_good': 0,
                'connect_renegotiate': 0,
                'accept': 0,
                'accept_good': 0,
                'accept_renegotiate': 0,
                'hits': 0,
                'misses': 0,
                'timeouts': 0,
                'cache_full': 0,
            })

    def test_set_default_verify_paths(self):
        # There's not much we can do to test that it acts as expected,
        # so just check it doesn't crash or raise an exception.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.set_default_verify_paths()

    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build")
    def test_set_ecdh_curve(self):
        # Curve names may be str or bytes; a missing/None argument raises
        # TypeError and an unknown curve name raises ValueError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.set_ecdh_curve("prime256v1")
        ctx.set_ecdh_curve(b"prime256v1")
        self.assertRaises(TypeError, ctx.set_ecdh_curve)
        self.assertRaises(TypeError, ctx.set_ecdh_curve, None)
        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo")
        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo")

    @needs_sni
    def test_sni_callback(self):
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)

        # set_servername_callback expects a callable, or None
        self.assertRaises(TypeError, ctx.set_servername_callback)
        self.assertRaises(TypeError, ctx.set_servername_callback, 4)
        self.assertRaises(TypeError, ctx.set_servername_callback, "")
        self.assertRaises(TypeError, ctx.set_servername_callback, ctx)

        def dummycallback(sock, servername, ctx):
            pass
        ctx.set_servername_callback(None)
        ctx.set_servername_callback(dummycallback)

    @needs_sni
    def test_sni_callback_refcycle(self):
        # Reference cycles through the servername callback are detected
        # and cleared.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        # The default argument makes the callback reference the context,
        # which in turn holds the callback: a deliberate cycle.
        def dummycallback(sock, servername, ctx, cycle=ctx):
            pass
        ctx.set_servername_callback(dummycallback)
        wr = weakref.ref(ctx)
        del ctx, dummycallback
        gc.collect()
        self.assertIs(wr(), None)

    def test_cert_store_stats(self):
        # load_cert_chain() does not change the verification store counts;
        # CERTFILE counts as a plain x509 cert, while CAFILE_CACERT
        # increases both the x509 and x509_ca counts.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.cert_store_stats(),
            {'x509_ca': 0, 'crl': 0, 'x509': 0})
        ctx.load_cert_chain(CERTFILE)
        self.assertEqual(ctx.cert_store_stats(),
            {'x509_ca': 0, 'crl': 0, 'x509': 0})
        ctx.load_verify_locations(CERTFILE)
        self.assertEqual(ctx.cert_store_stats(),
            {'x509_ca': 0, 'crl': 0, 'x509': 1})
        ctx.load_verify_locations(CAFILE_CACERT)
        self.assertEqual(ctx.cert_store_stats(),
            {'x509_ca': 1, 'crl': 0, 'x509': 2})

    def test_get_ca_certs(self):
        # get_ca_certs() lists only CA certificates: decoded dicts by
        # default, DER-encoded bytes when passed a true argument.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.get_ca_certs(), [])
        # CERTFILE is not flagged as X509v3 Basic Constraints: CA:TRUE
        ctx.load_verify_locations(CERTFILE)
        self.assertEqual(ctx.get_ca_certs(), [])
        # but CAFILE_CACERT is a CA cert
        ctx.load_verify_locations(CAFILE_CACERT)
        self.assertEqual(ctx.get_ca_certs(),
            [{'issuer': ((('organizationName', 'Root CA'),),
                         (('organizationalUnitName', 'http://www.cacert.org'),),
                         (('commonName', 'CA Cert Signing Authority'),),
                         (('emailAddress', 'support@cacert.org'),)),
              'notAfter': asn1time('Mar 29 12:29:49 2033 GMT'),
              'notBefore': asn1time('Mar 30 12:29:49 2003 GMT'),
              'serialNumber': '00',
              'crlDistributionPoints': ('https://www.cacert.org/revoke.crl',),
              'subject': ((('organizationName', 'Root CA'),),
                          (('organizationalUnitName', 'http://www.cacert.org'),),
                          (('commonName', 'CA Cert Signing Authority'),),
                          (('emailAddress', 'support@cacert.org'),)),
              'version': 3}])

        with open(CAFILE_CACERT) as f:
            pem = f.read()
        der = ssl.PEM_cert_to_DER_cert(pem)
        self.assertEqual(ctx.get_ca_certs(True), [der])

    def test_load_default_certs(self):
        # load_default_certs() accepts no argument or an ssl.Purpose member;
        # None or a plain string raises TypeError.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_default_certs()

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_default_certs(ssl.Purpose.SERVER_AUTH)
        ctx.load_default_certs()

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_default_certs(ssl.Purpose.CLIENT_AUTH)

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertRaises(TypeError, ctx.load_default_certs, None)
        self.assertRaises(TypeError, ctx.load_default_certs, 'SERVER_AUTH')

    @unittest.skipIf(sys.platform == "win32", "not-Windows specific")
    @unittest.skipIf(IS_LIBRESSL, "LibreSSL doesn't support env vars")
    def test_load_default_certs_env(self):
        # SSL_CERT_DIR / SSL_CERT_FILE are honoured by load_default_certs();
        # the expected stats show exactly CERTFILE was loaded.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        with support.EnvironmentVarGuard() as env:
            env["SSL_CERT_DIR"] = CAPATH
            env["SSL_CERT_FILE"] = CERTFILE
            ctx.load_default_certs()
            self.assertEqual(ctx.cert_store_stats(), {"crl": 0, "x509": 1, "x509_ca": 0})

    @unittest.skipUnless(sys.platform == "win32", "Windows specific")
    def test_load_default_certs_env_windows(self):
        # On Windows the system store is loaded too, so compare against a
        # baseline taken without the env vars and expect one extra x509 cert.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        ctx.load_default_certs()
        stats = ctx.cert_store_stats()

        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        with support.EnvironmentVarGuard() as env:
            env["SSL_CERT_DIR"] = CAPATH
            env["SSL_CERT_FILE"] = CERTFILE
            ctx.load_default_certs()
            stats["x509"] += 1
            self.assertEqual(ctx.cert_store_stats(), stats)

    def _assert_context_options(self, ctx):
        # Helper: assert that ctx has OP_NO_SSLv2 set, plus each of the
        # module-level OP_* flags that is non-zero (zero flags are skipped).
        self.assertEqual(ctx.options & ssl.OP_NO_SSLv2, ssl.OP_NO_SSLv2)
        if OP_NO_COMPRESSION != 0:
            self.assertEqual(ctx.options & OP_NO_COMPRESSION,
                             OP_NO_COMPRESSION)
        if OP_SINGLE_DH_USE != 0:
            self.assertEqual(ctx.options & OP_SINGLE_DH_USE,
                             OP_SINGLE_DH_USE)
        if OP_SINGLE_ECDH_USE != 0:
            self.assertEqual(ctx.options & OP_SINGLE_ECDH_USE,
                             OP_SINGLE_ECDH_USE)
        if OP_CIPHER_SERVER_PREFERENCE != 0:
            self.assertEqual(ctx.options & OP_CIPHER_SERVER_PREFERENCE,
                             OP_CIPHER_SERVER_PREFERENCE)

    def test_create_default_context(self):
        # The default (server-auth) context verifies certificates and
        # hostnames; the CLIENT_AUTH purpose yields CERT_NONE.
        ctx = ssl.create_default_context()

        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        self.assertTrue(ctx.check_hostname)
        self._assert_context_options(ctx)


        with open(SIGNING_CA) as f:
            cadata = f.read().decode("ascii")
        ctx = ssl.create_default_context(cafile=SIGNING_CA, capath=CAPATH,
                                         cadata=cadata)
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        self._assert_context_options(ctx)

        ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        self._assert_context_options(ctx)

    def test__create_stdlib_context(self):
        # The private stdlib factory defaults to no verification and no
        # hostname checking; protocol, cert_reqs and check_hostname can be
        # overridden.
        ctx = ssl._create_stdlib_context()
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        self.assertFalse(ctx.check_hostname)
        self._assert_context_options(ctx)

        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        self._assert_context_options(ctx)

        ctx = ssl._create_stdlib_context(ssl.PROTOCOL_TLSv1,
                                         cert_reqs=ssl.CERT_REQUIRED,
                                         check_hostname=True)
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_TLSv1)
        self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
        self.assertTrue(ctx.check_hostname)
        self._assert_context_options(ctx)

        ctx = ssl._create_stdlib_context(purpose=ssl.Purpose.CLIENT_AUTH)
        self.assertEqual(ctx.protocol, ssl.PROTOCOL_SSLv23)
        self.assertEqual(ctx.verify_mode, ssl.CERT_NONE)
        self._assert_context_options(ctx)

    def test__https_verify_certificates(self):
        # Unit test to check the contect factory mapping
        # The factories themselves are tested above
        # This test will fail by design if run under PYTHONHTTPSVERIFY=0
        # (as will various test_httplib tests)

        # Uses a fresh SSL module to avoid affecting the real one
        local_ssl = support.import_fresh_module("ssl")
        # Certificate verification is enabled by default
        self.assertIs(local_ssl._create_default_https_context,
                      local_ssl.create_default_context)
        # Turn default verification off
        local_ssl._https_verify_certificates(enable=False)
        self.assertIs(local_ssl._create_default_https_context,
                      local_ssl._create_unverified_context)
        # And back on
        local_ssl._https_verify_certificates(enable=True)
        self.assertIs(local_ssl._create_default_https_context,
                      local_ssl.create_default_context)
        # The default behaviour is to enable
        local_ssl._https_verify_certificates(enable=False)
        local_ssl._https_verify_certificates()
        self.assertIs(local_ssl._create_default_https_context,
                      local_ssl.create_default_context)

    def test__https_verify_envvar(self):
        # Unit test to check the PYTHONHTTPSVERIFY handling
        # Need to use a subprocess so it can still be run under -E
        # Each snippet exits with an error message (non-zero status) when the
        # default HTTPS context factory is not the expected one.
        https_is_verified = """import ssl, sys; \
            status = "Error: _create_default_https_context does not verify certs" \
                       if ssl._create_default_https_context is \
                          ssl._create_unverified_context \
                     else None; \
            sys.exit(status)"""
        https_is_not_verified = """import ssl, sys; \
            status = "Error: _create_default_https_context verifies certs" \
                       if ssl._create_default_https_context is \
                          ssl.create_default_context \
                     else None; \
            sys.exit(status)"""
        extra_env = {}
        # Omitting it leaves verification on
        assert_python_ok("-c", https_is_verified, **extra_env)
        # Setting it to zero turns verification off
        extra_env[ssl._https_verify_envvar] = "0"
        assert_python_ok("-c", https_is_not_verified, **extra_env)
        # Any other value should also leave it on
        for setting in ("", "1", "enabled", "foo"):
            extra_env[ssl._https_verify_envvar] = setting
            assert_python_ok("-c", https_is_verified, **extra_env)

    def test_check_hostname(self):
        # check_hostname defaults to False and can only be enabled while
        # verify_mode is CERT_REQUIRED or CERT_OPTIONAL.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        self.assertFalse(ctx.check_hostname)

        # Requires CERT_REQUIRED or CERT_OPTIONAL
        with self.assertRaises(ValueError):
            ctx.check_hostname = True
        ctx.verify_mode = ssl.CERT_REQUIRED
        self.assertFalse(ctx.check_hostname)
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)

        ctx.verify_mode = ssl.CERT_OPTIONAL
        ctx.check_hostname = True
        self.assertTrue(ctx.check_hostname)

        # Cannot set CERT_NONE with check_hostname enabled
        with self.assertRaises(ValueError):
            ctx.verify_mode = ssl.CERT_NONE
        ctx.check_hostname = False
        self.assertFalse(ctx.check_hostname)
1355
1356
class SSLErrorTests(unittest.TestCase):
    """Tests for ssl.SSLError and its subclasses: str(), errno and the
    library/reason attributes."""

    def test_str(self):
        # The str() of a SSLError doesn't include the errno
        e = ssl.SSLError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)
        # Same for a subclass
        e = ssl.SSLZeroReturnError(1, "foo")
        self.assertEqual(str(e), "foo")
        self.assertEqual(e.errno, 1)

    def test_lib_reason(self):
        # Test the library and reason attributes.
        # Feeding a certificate file to load_dh_params() fails inside the PEM
        # library, which yields a predictable library/reason pair.
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        with self.assertRaises(ssl.SSLError) as cm:
            ctx.load_dh_params(CERTFILE)
        self.assertEqual(cm.exception.library, 'PEM')
        self.assertEqual(cm.exception.reason, 'NO_START_LINE')
        s = str(cm.exception)
        self.assertTrue(s.startswith("[PEM: NO_START_LINE] no start line"), s)

    def test_subclass(self):
        # Check that the appropriate SSLError subclass is raised
        # (this only tests one of them)
        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
        with closing(socket.socket()) as server:
            server.bind(("127.0.0.1", 0))
            server.listen(5)
            # Own the plain client socket with its own context manager so it
            # is closed even if connect() fails before it is wrapped.
            with closing(socket.socket()) as raw:
                raw.connect(server.getsockname())
                raw.setblocking(False)
                # The listening socket never accepts or sends anything, so a
                # non-blocking client handshake has nothing to read yet.
                with closing(ctx.wrap_socket(raw, False,
                                             do_handshake_on_connect=False)) as c:
                    with self.assertRaises(ssl.SSLWantReadError) as cm:
                        c.do_handshake()
                    # Use a separate name so the listening socket isn't
                    # shadowed by the message string.
                    msg = str(cm.exception)
                    self.assertTrue(
                        msg.startswith("The operation did not complete (read)"),
                        msg)
                    # For compatibility
                    self.assertEqual(cm.exception.errno,
                                     ssl.SSL_ERROR_WANT_READ)
1396
1397
class NetworkedTests(unittest.TestCase):
    """Client-side SSL tests against a live server on the Internet.

    The tests talk to REMOTE_HOST (port 443 unless stated otherwise) inside
    support.transient_internet(), presumably so that transient network
    trouble is not reported as a test failure.  Every SSL socket is closed
    in a ``finally`` clause (or via ``closing``) that also covers the
    connect()/handshake call, so a failed connection does not leak it.
    """

    def test_connect(self):
        # CERT_NONE connects and getpeercert() returns an empty dict;
        # CERT_REQUIRED fails without CA certs and succeeds once
        # REMOTE_ROOT_CERT is supplied.
        with support.transient_internet(REMOTE_HOST):
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_NONE)
            try:
                s.connect((REMOTE_HOST, 443))
                self.assertEqual({}, s.getpeercert())
            finally:
                s.close()

            # this should fail because we have no verification certs
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED)
            try:
                self.assertRaisesRegexp(ssl.SSLError, "certificate verify failed",
                                        s.connect, (REMOTE_HOST, 443))
            finally:
                s.close()

            # this should succeed because we specify the root cert
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs=REMOTE_ROOT_CERT)
            try:
                s.connect((REMOTE_HOST, 443))
                self.assertTrue(s.getpeercert())
            finally:
                s.close()

    def test_connect_ex(self):
        # Issue #11326: check connect_ex() implementation
        with support.transient_internet(REMOTE_HOST):
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs=REMOTE_ROOT_CERT)
            try:
                self.assertEqual(0, s.connect_ex((REMOTE_HOST, 443)))
                self.assertTrue(s.getpeercert())
            finally:
                s.close()

    def test_non_blocking_connect_ex(self):
        # Issue #11326: non-blocking connect_ex() should allow handshake
        # to proceed after the socket gets ready.
        with support.transient_internet(REMOTE_HOST):
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs=REMOTE_ROOT_CERT,
                                do_handshake_on_connect=False)
            try:
                s.setblocking(False)
                rc = s.connect_ex((REMOTE_HOST, 443))
                # EWOULDBLOCK under Windows, EINPROGRESS elsewhere
                self.assertIn(rc, (0, errno.EINPROGRESS, errno.EWOULDBLOCK))
                # Wait for connect to finish
                select.select([], [s], [], 5.0)
                # Non-blocking handshake: wait (up to 5s per step) for the
                # direction OpenSSL asks for, then retry.
                while True:
                    try:
                        s.do_handshake()
                        break
                    except ssl.SSLWantReadError:
                        select.select([s], [], [], 5.0)
                    except ssl.SSLWantWriteError:
                        select.select([], [s], [], 5.0)
                # SSL established
                self.assertTrue(s.getpeercert())
            finally:
                s.close()

    def test_timeout_connect_ex(self):
        # Issue #12065: on a timeout, connect_ex() should return the original
        # errno (mimicking the behaviour of non-SSL sockets).
        with support.transient_internet(REMOTE_HOST):
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs=REMOTE_ROOT_CERT,
                                do_handshake_on_connect=False)
            try:
                s.settimeout(0.0000001)
                rc = s.connect_ex((REMOTE_HOST, 443))
                if rc == 0:
                    self.skipTest("REMOTE_HOST responded too quickly")
                self.assertIn(rc, (errno.EAGAIN, errno.EWOULDBLOCK))
            finally:
                s.close()

    def test_connect_ex_error(self):
        # Connecting to port 444 is expected to fail, so connect_ex() must
        # return an errno instead of raising.
        with support.transient_internet(REMOTE_HOST):
            s = ssl.wrap_socket(socket.socket(socket.AF_INET),
                                cert_reqs=ssl.CERT_REQUIRED,
                                ca_certs=REMOTE_ROOT_CERT)
            try:
                rc = s.connect_ex((REMOTE_HOST, 444))
                # Issue #19919: Windows machines or VMs hosted on Windows
                # machines sometimes return EWOULDBLOCK.
                errors = (
                    errno.ECONNREFUSED, errno.EHOSTUNREACH, errno.ETIMEDOUT,
                    errno.EWOULDBLOCK,
                )
                self.assertIn(rc, errors)
            finally:
                s.close()

    def test_connect_with_context(self):
        with support.transient_internet(REMOTE_HOST):
            # Same as test_connect, but with a separately created context
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                s.connect((REMOTE_HOST, 443))
                self.assertEqual({}, s.getpeercert())
            finally:
                s.close()
            # Same with a server hostname
            s = ctx.wrap_socket(socket.socket(socket.AF_INET),
                                server_hostname=REMOTE_HOST)
            try:
                s.connect((REMOTE_HOST, 443))
            finally:
                s.close()
            # This should fail because we have no verification certs
            ctx.verify_mode = ssl.CERT_REQUIRED
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                self.assertRaisesRegexp(ssl.SSLError, "certificate verify failed",
                                        s.connect, (REMOTE_HOST, 443))
            finally:
                s.close()
            # This should succeed because we specify the root cert
            ctx.load_verify_locations(REMOTE_ROOT_CERT)
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)
            finally:
                s.close()

    def test_connect_capath(self):
        # Verify server certificates using the `capath` argument
        # NOTE: the subject hashing algorithm has been changed between
        # OpenSSL 0.9.8n and 1.0.0, as a result the capath directory must
        # contain both versions of each certificate (same content, different
        # filename) for this test to be portable across OpenSSL releases.
        with support.transient_internet(REMOTE_HOST):
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.load_verify_locations(capath=CAPATH)
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)
            finally:
                s.close()
            # Same with a bytes `capath` argument
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.load_verify_locations(capath=BYTES_CAPATH)
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)
            finally:
                s.close()

    def test_connect_cadata(self):
        # Load the trust anchor from memory (`cadata`), first as PEM text
        # and then as DER bytes.
        with open(REMOTE_ROOT_CERT) as f:
            pem = f.read().decode('ascii')
        der = ssl.PEM_cert_to_DER_cert(pem)
        with support.transient_internet(REMOTE_HOST):
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.load_verify_locations(cadata=pem)
            with closing(ctx.wrap_socket(socket.socket(socket.AF_INET))) as s:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)

            # same with DER
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.load_verify_locations(cadata=der)
            with closing(ctx.wrap_socket(socket.socket(socket.AF_INET))) as s:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)

    @unittest.skipIf(os.name == "nt", "Can't use a socket as a file under Windows")
    def test_makefile_close(self):
        # Issue #5238: creating a file-like object with makefile() shouldn't
        # delay closing the underlying "real socket" (here tested with its
        # file descriptor, hence skipping the test under Windows).
        with support.transient_internet(REMOTE_HOST):
            ss = ssl.wrap_socket(socket.socket(socket.AF_INET))
            ss.connect((REMOTE_HOST, 443))
            fd = ss.fileno()
            f = ss.makefile()
            f.close()
            # The fd is still open
            os.read(fd, 0)
            # Closing the SSL socket should close the fd too
            ss.close()
            gc.collect()
            with self.assertRaises(OSError) as e:
                os.read(fd, 0)
            self.assertEqual(e.exception.errno, errno.EBADF)

    def test_non_blocking_handshake(self):
        # Drive the handshake by hand on a non-blocking socket, waiting with
        # select() for whichever direction OpenSSL asks for.
        with support.transient_internet(REMOTE_HOST):
            s = socket.socket(socket.AF_INET)
            try:
                s.connect((REMOTE_HOST, 443))
                s.setblocking(False)
                s = ssl.wrap_socket(s,
                                    cert_reqs=ssl.CERT_NONE,
                                    do_handshake_on_connect=False)
                count = 0
                while True:
                    try:
                        count += 1
                        s.do_handshake()
                        break
                    except ssl.SSLWantReadError:
                        select.select([s], [], [])
                    except ssl.SSLWantWriteError:
                        select.select([], [s], [])
            finally:
                # Closes the SSL socket, or the plain one if connect() or
                # wrapping failed before it was replaced.
                s.close()
            if support.verbose:
                sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)

    def test_get_server_certificate(self):
        def _test_get_server_certificate(host, port, cert=None):
            # Fetch the server cert unverified, check that verifying it
            # against an unrelated CA (CERTFILE) fails, then fetch it again
            # verified against `cert`.
            with support.transient_internet(host):
                pem = ssl.get_server_certificate((host, port))
                if not pem:
                    self.fail("No server certificate on %s:%s!" % (host, port))

                try:
                    pem = ssl.get_server_certificate((host, port),
                                                     ca_certs=CERTFILE)
                except ssl.SSLError as x:
                    # should fail
                    if support.verbose:
                        sys.stdout.write("%s\n" % x)
                else:
                    self.fail("Got server certificate %s for %s:%s!" % (pem, host, port))
                pem = ssl.get_server_certificate((host, port),
                                                 ca_certs=cert)
                if not pem:
                    self.fail("No server certificate on %s:%s!" % (host, port))
                if support.verbose:
                    sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port, pem))

        _test_get_server_certificate(REMOTE_HOST, 443, REMOTE_ROOT_CERT)
        if support.IPV6_ENABLED:
            _test_get_server_certificate('ipv6.google.com', 443)

    def test_ciphers(self):
        remote = (REMOTE_HOST, 443)
        with support.transient_internet(remote[0]):
            with closing(ssl.wrap_socket(socket.socket(socket.AF_INET),
                                         cert_reqs=ssl.CERT_NONE, ciphers="ALL")) as s:
                s.connect(remote)
            with closing(ssl.wrap_socket(socket.socket(socket.AF_INET),
                                         cert_reqs=ssl.CERT_NONE, ciphers="DEFAULT")) as s:
                s.connect(remote)
            # Error checking can happen at instantiation or when connecting
            with self.assertRaisesRegexp(ssl.SSLError, "No cipher can be selected"):
                with closing(socket.socket(socket.AF_INET)) as sock:
                    s = ssl.wrap_socket(sock,
                                        cert_reqs=ssl.CERT_NONE, ciphers="^$:,;?*'dorothyx")
                    s.connect(remote)

    def test_get_ca_certs_capath(self):
        # capath certs are loaded on request
        with support.transient_internet(REMOTE_HOST):
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            ctx.load_verify_locations(capath=CAPATH)
            self.assertEqual(ctx.get_ca_certs(), [])
            s = ctx.wrap_socket(socket.socket(socket.AF_INET))
            try:
                s.connect((REMOTE_HOST, 443))
                cert = s.getpeercert()
                self.assertTrue(cert)
            finally:
                s.close()
            # The CA used during verification is now reported.
            self.assertEqual(len(ctx.get_ca_certs()), 1)

    @needs_sni
    def test_context_setget(self):
        # Check that the context of a connected socket can be replaced.
        with support.transient_internet(REMOTE_HOST):
            ctx1 = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
            ctx2 = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            s = socket.socket(socket.AF_INET)
            with closing(ctx1.wrap_socket(s)) as ss:
                ss.connect((REMOTE_HOST, 443))
                self.assertIs(ss.context, ctx1)
                self.assertIs(ss._sslobj.context, ctx1)
                ss.context = ctx2
                self.assertIs(ss.context, ctx2)
                self.assertIs(ss._sslobj.context, ctx2)
1699
1700try:
1701    import threading
1702except ImportError:
1703    _have_threads = False
1704else:
1705    _have_threads = True
1706
1707    from test.ssl_servers import make_https_server
1708
    class ThreadedEchoServer(threading.Thread):
        """Threaded TCP test server that echoes messages back lowercased.

        __init__ binds a listening socket with support.bind_port() and
        exposes the chosen port as ``self.port``.  run() serves clients one
        at a time: it accepts a connection, starts a ConnectionHandler
        thread for it and joins that thread before accepting again.  Use it
        as a context manager: __enter__ returns once the socket is
        listening, and __exit__ stops and joins the thread.

        Handlers record what they observe on the server for tests to
        inspect: ``selected_npn_protocols``, ``selected_alpn_protocols``
        and ``conn_errors``.
        """

        class ConnectionHandler(threading.Thread):

            """A mildly complicated class, because we want it to work both
            with and without the SSL wrapper around the socket connection, so
            that we can test the STARTTLS functionality."""

            def __init__(self, server, connsock, addr):
                # server: the owning ThreadedEchoServer (provides the SSL
                # context and logging flags, and collects results);
                # connsock/addr: the pair returned by accept().
                self.server = server
                self.running = False
                self.sock = connsock
                self.addr = addr
                # The handler uses blocking I/O on the accepted socket.
                self.sock.setblocking(1)
                # Set by wrap_conn(); None while the connection is plain text.
                self.sslconn = None
                threading.Thread.__init__(self)
                self.daemon = True

            def wrap_conn(self):
                """Run the server side of the TLS handshake on self.sock.

                Returns True on success, after storing the SSL socket in
                self.sslconn and appending the negotiated NPN and ALPN
                protocols to the server's lists.  Returns False when the
                handshake fails with one of the handled errors below; in
                that case self.running is cleared and the connection is
                closed.  Any other error propagates.
                """
                try:
                    self.sslconn = self.server.context.wrap_socket(
                        self.sock, server_side=True)
                    self.server.selected_npn_protocols.append(self.sslconn.selected_npn_protocol())
                    self.server.selected_alpn_protocols.append(self.sslconn.selected_alpn_protocol())
                except (ssl.SSLError, socket.error, OSError) as e:
                    if e.errno in (errno.ECONNRESET, errno.EPIPE, errno.ESHUTDOWN):
                        # Mimick Python 3:
                        #
                        #    except (ConnectionResetError, BrokenPipeError):
                        #
                        # We treat ConnectionResetError as though it were an
                        # SSLError - OpenSSL on Ubuntu abruptly closes the
                        # connection when asked to use an unsupported protocol.
                        #
                        # BrokenPipeError is raised in TLS 1.3 mode, when OpenSSL
                        # tries to send session tickets after handshake.
                        # https://github.com/openssl/openssl/issues/6342
                        #
                        # Only this connection is dropped: the error *text* is
                        # recorded and, unlike the branch below, the server is
                        # not stopped.
                        self.server.conn_errors.append(str(e))
                        if self.server.chatty:
                            handle_error(
                                "\n server:  bad connection attempt from "
                                + repr(self.addr) + ":\n")
                        self.running = False
                        self.close()
                        return False
                    else:
                        # OSError may occur with wrong protocols, e.g. both
                        # sides use PROTOCOL_TLS_SERVER.
                        #
                        # XXX Various errors can have happened here, for example
                        # a mismatching protocol version, an invalid certificate,
                        # or a low-level bug. This should be made more discriminating.
                        #
                        # ECONNRESET was already handled by the branch above, so
                        # in practice this re-raises every error that is not an
                        # SSLError.
                        if not isinstance(e, ssl.SSLError) and e.errno != errno.ECONNRESET:
                            raise
                        # Here the exception object itself is recorded and the
                        # whole server is asked to stop.
                        self.server.conn_errors.append(e)
                        if self.server.chatty:
                            handle_error("\n server:  bad connection attempt from " + repr(self.addr) + ":\n")
                        self.running = False
                        self.server.stop()
                        self.close()
                        return False
                else:
                    # Handshake succeeded.  The peer certificate (both parsed
                    # and DER forms) is only fetched for verbose logging.
                    if self.server.context.verify_mode == ssl.CERT_REQUIRED:
                        cert = self.sslconn.getpeercert()
                        if support.verbose and self.server.chatty:
                            sys.stdout.write(" client cert is " + pprint.pformat(cert) + "\n")
                        cert_binary = self.sslconn.getpeercert(True)
                        if support.verbose and self.server.chatty:
                            sys.stdout.write(" cert binary is " + str(len(cert_binary)) + " bytes\n")
                    cipher = self.sslconn.cipher()
                    if support.verbose and self.server.chatty:
                        sys.stdout.write(" server: connection cipher is now " + str(cipher) + "\n")
                        sys.stdout.write(" server: selected protocol is now "
                                + str(self.sslconn.selected_npn_protocol()) + "\n")
                    return True

            def read(self):
                """Read one chunk: from the TLS layer when wrapped, otherwise
                recv() up to 1024 bytes from the plain socket."""
                if self.sslconn:
                    return self.sslconn.read()
                else:
                    return self.sock.recv(1024)

            def write(self, bytes):
                """Send *bytes* over the TLS layer when wrapped, otherwise over
                the plain socket; returns the underlying call's result.
                (The parameter name shadows the ``bytes`` builtin.)"""
                if self.sslconn:
                    return self.sslconn.write(bytes)
                else:
                    return self.sock.send(bytes)

            def close(self):
                """Close the SSL socket if the connection is wrapped, else the
                plain socket."""
                if self.sslconn:
                    self.sslconn.close()
                else:
                    self.sock.close()

            def run(self):
                """Serve a single client until it disconnects or says 'over'.

                Unless the server is a STARTTLS server, the connection is
                wrapped immediately.  Each message is then dispatched on
                its stripped content:

                * empty -> treated as EOF; the handler closes and stops;
                * ``over`` -> close and return;
                * ``STARTTLS`` (STARTTLS servers only) -> reply ``OK`` and
                  wrap the connection;
                * ``ENDTLS`` (STARTTLS servers, while wrapped) -> reply ``OK``
                  and unwrap back to the plain socket;
                * ``CB tls-unique`` -> reply with the repr of the
                  tls-unique channel binding;
                * anything else -> echoed back lowercased.

                An ssl.SSLError closes the connection and stops the server.
                """
                self.running = True
                if not self.server.starttls_server:
                    if not self.wrap_conn():
                        return
                while self.running:
                    try:
                        msg = self.read()
                        stripped = msg.strip()
                        if not stripped:
                            # eof, so quit this handler
                            self.running = False
                            self.close()
                        elif stripped == b'over':
                            if support.verbose and self.server.connectionchatty:
                                sys.stdout.write(" server: client closed connection\n")
                            self.close()
                            return
                        elif (self.server.starttls_server and
                              stripped == b'STARTTLS'):
                            if support.verbose and self.server.connectionchatty:
                                sys.stdout.write(" server: read STARTTLS from client, sending OK...\n")
                            self.write(b"OK\n")
                            if not self.wrap_conn():
                                return
                        elif (self.server.starttls_server and self.sslconn
                              and stripped == b'ENDTLS'):
                            if support.verbose and self.server.connectionchatty:
                                sys.stdout.write(" server: read ENDTLS from client, sending OK...\n")
                            self.write(b"OK\n")
                            # Shut down TLS and keep talking on the plain socket.
                            self.sock = self.sslconn.unwrap()
                            self.sslconn = None
                            if support.verbose and self.server.connectionchatty:
                                sys.stdout.write(" server: connection is now unencrypted...\n")
                        elif stripped == b'CB tls-unique':
                            if support.verbose and self.server.connectionchatty:
                                sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                            # NOTE(review): assumes the connection is already
                            # wrapped; self.sslconn is None otherwise and this
                            # raises AttributeError.
                            data = self.sslconn.get_channel_binding("tls-unique")
                            self.write(repr(data).encode("us-ascii") + b"\n")
                        else:
                            if (support.verbose and
                                self.server.connectionchatty):
                                ctype = (self.sslconn and "encrypted") or "unencrypted"
                                sys.stdout.write(" server: read %r (%s), sending back %r (%s)...\n"
                                                 % (msg, ctype, msg.lower(), ctype))
                            self.write(msg.lower())
                    # Only ssl.SSLError is handled here; any other exception
                    # ends this handler thread with a traceback.
                    except ssl.SSLError:
                        if self.server.chatty:
                            handle_error("Test server failure:\n")
                        self.close()
                        self.running = False
                        # normally, we'd just stop here, but for the test
                        # harness, we want to stop the server
                        self.server.stop()

        def __init__(self, certificate=None, ssl_version=None,
                     certreqs=None, cacerts=None,
                     chatty=True, connectionchatty=False, starttls_server=False,
                     npn_protocols=None, alpn_protocols=None,
                     ciphers=None, context=None):
            """Create the server and bind its listening socket.

            If *context* is given it is used as-is, and *certificate*,
            *ssl_version*, *certreqs*, *cacerts*, *npn_protocols*,
            *alpn_protocols* and *ciphers* are ignored.  Otherwise a new
            SSLContext is built from them (PROTOCOL_TLS and CERT_NONE when
            *ssl_version* / *certreqs* are None).

            *chatty* reports failed connections through handle_error and,
            under support.verbose, logs connection events.
            *connectionchatty* additionally logs each message exchanged
            (only under support.verbose).  *starttls_server* makes
            handlers start in plain text and wrap only on ``STARTTLS``.
            """
            if context:
                self.context = context
            else:
                self.context = ssl.SSLContext(ssl_version
                                              if ssl_version is not None
                                              else ssl.PROTOCOL_TLS)
                self.context.verify_mode = (certreqs if certreqs is not None
                                            else ssl.CERT_NONE)
                if cacerts:
                    self.context.load_verify_locations(cacerts)
                if certificate:
                    self.context.load_cert_chain(certificate)
                if npn_protocols:
                    self.context.set_npn_protocols(npn_protocols)
                if alpn_protocols:
                    self.context.set_alpn_protocols(alpn_protocols)
                if ciphers:
                    self.context.set_ciphers(ciphers)
            self.chatty = chatty
            self.connectionchatty = connectionchatty
            self.starttls_server = starttls_server
            self.sock = socket.socket()
            self.port = support.bind_port(self.sock)
            # Optional Event set by run() once the socket is listening.
            self.flag = None
            self.active = False
            self.selected_npn_protocols = []
            self.selected_alpn_protocols = []
            self.conn_errors = []
            threading.Thread.__init__(self)
            self.daemon = True

        def __enter__(self):
            """Start the server thread and block until it is listening."""
            self.start(threading.Event())
            self.flag.wait()
            return self

        def __exit__(self, *args):
            """Stop the accept loop and wait for the thread to finish."""
            self.stop()
            self.join()

        def start(self, flag=None):
            """Start the thread; *flag*, if given, is an Event that run()
            sets once the socket is listening."""
            self.flag = flag
            threading.Thread.start(self)

        def run(self):
            """Accept loop.

            The short accept timeout (0.05s) lets the loop re-check
            self.active regularly.  Each connection is handled to
            completion (handler.join()) before the next accept.  The
            listening socket is closed when the loop exits.
            """
            self.sock.settimeout(0.05)
            self.sock.listen(5)
            self.active = True
            if self.flag:
                # signal an event
                self.flag.set()
            while self.active:
                try:
                    newconn, connaddr = self.sock.accept()
                    if support.verbose and self.chatty:
                        sys.stdout.write(' server:  new connection from '
                                         + repr(connaddr) + '\n')
                    handler = self.ConnectionHandler(self, newconn, connaddr)
                    handler.start()
                    handler.join()
                except socket.timeout:
                    pass
                except KeyboardInterrupt:
                    self.stop()
            self.sock.close()

        def stop(self):
            """Ask the accept loop to exit; it notices on its next iteration."""
            self.active = False
1931
1932    class AsyncoreEchoServer(threading.Thread):
1933
1934        class EchoServer(asyncore.dispatcher):
1935
1936            class ConnectionHandler(asyncore.dispatcher_with_send):
1937
1938                def __init__(self, conn, certfile):
1939                    self.socket = ssl.wrap_socket(conn, server_side=True,
1940                                                  certfile=certfile,
1941                                                  do_handshake_on_connect=False)
1942                    asyncore.dispatcher_with_send.__init__(self, self.socket)
1943                    self._ssl_accepting = True
1944                    self._do_ssl_handshake()
1945
1946                def readable(self):
1947                    if isinstance(self.socket, ssl.SSLSocket):
1948                        while self.socket.pending() > 0:
1949                            self.handle_read_event()
1950                    return True
1951
1952                def _do_ssl_handshake(self):
1953                    try:
1954                        self.socket.do_handshake()
1955                    except (ssl.SSLWantReadError, ssl.SSLWantWriteError):
1956                        return
1957                    except ssl.SSLEOFError:
1958                        return self.handle_close()
1959                    except ssl.SSLError:
1960                        raise
1961                    except socket.error, err:
1962                        if err.args[0] == errno.ECONNABORTED:
1963                            return self.handle_close()
1964                    else:
1965                        self._ssl_accepting = False
1966
1967                def handle_read(self):
1968                    if self._ssl_accepting:
1969                        self._do_ssl_handshake()
1970                    else:
1971                        data = self.recv(1024)
1972                        if support.verbose:
1973                            sys.stdout.write(" server:  read %s from client\n" % repr(data))
1974                        if not data:
1975                            self.close()
1976                        else:
1977                            self.send(data.lower())
1978
1979                def handle_close(self):
1980                    self.close()
1981                    if support.verbose:
1982                        sys.stdout.write(" server:  closed connection %s\n" % self.socket)
1983
1984                def handle_error(self):
1985                    raise
1986
1987            def __init__(self, certfile):
1988                self.certfile = certfile
1989                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1990                self.port = support.bind_port(sock, '')
1991                asyncore.dispatcher.__init__(self, sock)
1992                self.listen(5)
1993
1994            def handle_accept(self):
1995                sock_obj, addr = self.accept()
1996                if support.verbose:
1997                    sys.stdout.write(" server:  new connection from %s:%s\n" %addr)
1998                self.ConnectionHandler(sock_obj, self.certfile)
1999
2000            def handle_error(self):
2001                raise
2002
2003        def __init__(self, certfile):
2004            self.flag = None
2005            self.active = False
2006            self.server = self.EchoServer(certfile)
2007            self.port = self.server.port
2008            threading.Thread.__init__(self)
2009            self.daemon = True
2010
2011        def __str__(self):
2012            return "<%s %s>" % (self.__class__.__name__, self.server)
2013
2014        def __enter__(self):
2015            self.start(threading.Event())
2016            self.flag.wait()
2017            return self
2018
2019        def __exit__(self, *args):
2020            if support.verbose:
2021                sys.stdout.write(" cleanup: stopping server.\n")
2022            self.stop()
2023            if support.verbose:
2024                sys.stdout.write(" cleanup: joining server thread.\n")
2025            self.join()
2026            if support.verbose:
2027                sys.stdout.write(" cleanup: successfully joined.\n")
2028            # make sure that ConnectionHandler is removed from socket_map
2029            asyncore.close_all(ignore_all=True)
2030
2031        def start(self, flag=None):
2032            self.flag = flag
2033            threading.Thread.start(self)
2034
2035        def run(self):
2036            self.active = True
2037            if self.flag:
2038                self.flag.set()
2039            while self.active:
2040                try:
2041                    asyncore.loop(1)
2042                except:
2043                    pass
2044
2045        def stop(self):
2046            self.active = False
2047            self.server.close()
2048
2049    def server_params_test(client_context, server_context, indata=b"FOO\n",
2050                           chatty=True, connectionchatty=False, sni_name=None):
2051        """
2052        Launch a server, connect a client to it and try various reads
2053        and writes.
2054        """
2055        stats = {}
2056        server = ThreadedEchoServer(context=server_context,
2057                                    chatty=chatty,
2058                                    connectionchatty=False)
2059        with server:
2060            with closing(client_context.wrap_socket(socket.socket(),
2061                    server_hostname=sni_name)) as s:
2062                s.connect((HOST, server.port))
2063                for arg in [indata, bytearray(indata), memoryview(indata)]:
2064                    if connectionchatty:
2065                        if support.verbose:
2066                            sys.stdout.write(
2067                                " client:  sending %r...\n" % indata)
2068                    s.write(arg)
2069                    outdata = s.read()
2070                    if connectionchatty:
2071                        if support.verbose:
2072                            sys.stdout.write(" client:  read %r\n" % outdata)
2073                    if outdata != indata.lower():
2074                        raise AssertionError(
2075                            "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
2076                            % (outdata[:20], len(outdata),
2077                               indata[:20].lower(), len(indata)))
2078                s.write(b"over\n")
2079                if connectionchatty:
2080                    if support.verbose:
2081                        sys.stdout.write(" client:  closing connection.\n")
2082                stats.update({
2083                    'compression': s.compression(),
2084                    'cipher': s.cipher(),
2085                    'peercert': s.getpeercert(),
2086                    'client_alpn_protocol': s.selected_alpn_protocol(),
2087                    'client_npn_protocol': s.selected_npn_protocol(),
2088                    'version': s.version(),
2089                })
2090                s.close()
2091            stats['server_alpn_protocols'] = server.selected_alpn_protocols
2092            stats['server_npn_protocols'] = server.selected_npn_protocols
2093        return stats
2094
2095    def try_protocol_combo(server_protocol, client_protocol, expect_success,
2096                           certsreqs=None, server_options=0, client_options=0):
2097        """
2098        Try to SSL-connect using *client_protocol* to *server_protocol*.
2099        If *expect_success* is true, assert that the connection succeeds,
2100        if it's false, assert that the connection fails.
2101        Also, if *expect_success* is a string, assert that it is the protocol
2102        version actually used by the connection.
2103        """
2104        if certsreqs is None:
2105            certsreqs = ssl.CERT_NONE
2106        certtype = {
2107            ssl.CERT_NONE: "CERT_NONE",
2108            ssl.CERT_OPTIONAL: "CERT_OPTIONAL",
2109            ssl.CERT_REQUIRED: "CERT_REQUIRED",
2110        }[certsreqs]
2111        if support.verbose:
2112            formatstr = (expect_success and " %s->%s %s\n") or " {%s->%s} %s\n"
2113            sys.stdout.write(formatstr %
2114                             (ssl.get_protocol_name(client_protocol),
2115                              ssl.get_protocol_name(server_protocol),
2116                              certtype))
2117        client_context = ssl.SSLContext(client_protocol)
2118        client_context.options |= client_options
2119        server_context = ssl.SSLContext(server_protocol)
2120        server_context.options |= server_options
2121
2122        # NOTE: we must enable "ALL" ciphers on the client, otherwise an
2123        # SSLv23 client will send an SSLv3 hello (rather than SSLv2)
2124        # starting from OpenSSL 1.0.0 (see issue #8322).
2125        if client_context.protocol == ssl.PROTOCOL_SSLv23:
2126            client_context.set_ciphers("ALL")
2127
2128        for ctx in (client_context, server_context):
2129            ctx.verify_mode = certsreqs
2130            ctx.load_cert_chain(CERTFILE)
2131            ctx.load_verify_locations(CERTFILE)
2132        try:
2133            stats = server_params_test(client_context, server_context,
2134                                       chatty=False, connectionchatty=False)
2135        # Protocol mismatch can result in either an SSLError, or a
2136        # "Connection reset by peer" error.
2137        except ssl.SSLError:
2138            if expect_success:
2139                raise
2140        except socket.error as e:
2141            if expect_success or e.errno != errno.ECONNRESET:
2142                raise
2143        else:
2144            if not expect_success:
2145                raise AssertionError(
2146                    "Client protocol %s succeeded with server protocol %s!"
2147                    % (ssl.get_protocol_name(client_protocol),
2148                       ssl.get_protocol_name(server_protocol)))
2149            elif (expect_success is not True
2150                  and expect_success != stats['version']):
2151                raise AssertionError("version mismatch: expected %r, got %r"
2152                                     % (expect_success, stats['version']))
2153
2154
2155    class ThreadedTests(unittest.TestCase):
2156
2157        @skip_if_broken_ubuntu_ssl
2158        def test_echo(self):
2159            """Basic test of an SSL client connecting to a server"""
2160            if support.verbose:
2161                sys.stdout.write("\n")
2162            for protocol in PROTOCOLS:
2163                context = ssl.SSLContext(protocol)
2164                context.load_cert_chain(CERTFILE)
2165                server_params_test(context, context,
2166                                   chatty=True, connectionchatty=True)
2167
2168        def test_getpeercert(self):
2169            if support.verbose:
2170                sys.stdout.write("\n")
2171            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2172            context.verify_mode = ssl.CERT_REQUIRED
2173            context.load_verify_locations(CERTFILE)
2174            context.load_cert_chain(CERTFILE)
2175            server = ThreadedEchoServer(context=context, chatty=False)
2176            with server:
2177                s = context.wrap_socket(socket.socket(),
2178                                        do_handshake_on_connect=False)
2179                s.connect((HOST, server.port))
2180                # getpeercert() raise ValueError while the handshake isn't
2181                # done.
2182                with self.assertRaises(ValueError):
2183                    s.getpeercert()
2184                s.do_handshake()
2185                cert = s.getpeercert()
2186                self.assertTrue(cert, "Can't get peer certificate.")
2187                cipher = s.cipher()
2188                if support.verbose:
2189                    sys.stdout.write(pprint.pformat(cert) + '\n')
2190                    sys.stdout.write("Connection cipher is " + str(cipher) + '.\n')
2191                if 'subject' not in cert:
2192                    self.fail("No subject field in certificate: %s." %
2193                              pprint.pformat(cert))
2194                if ((('organizationName', 'Python Software Foundation'),)
2195                    not in cert['subject']):
2196                    self.fail(
2197                        "Missing or invalid 'organizationName' field in certificate subject; "
2198                        "should be 'Python Software Foundation'.")
2199                self.assertIn('notBefore', cert)
2200                self.assertIn('notAfter', cert)
2201                before = ssl.cert_time_to_seconds(cert['notBefore'])
2202                after = ssl.cert_time_to_seconds(cert['notAfter'])
2203                self.assertLess(before, after)
2204                s.close()
2205
2206        @unittest.skipUnless(have_verify_flags(),
2207                            "verify_flags need OpenSSL > 0.9.8")
2208        def test_crl_check(self):
2209            if support.verbose:
2210                sys.stdout.write("\n")
2211
2212            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2213            server_context.load_cert_chain(SIGNED_CERTFILE)
2214
2215            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2216            context.verify_mode = ssl.CERT_REQUIRED
2217            context.load_verify_locations(SIGNING_CA)
2218            tf = getattr(ssl, "VERIFY_X509_TRUSTED_FIRST", 0)
2219            self.assertEqual(context.verify_flags, ssl.VERIFY_DEFAULT | tf)
2220
2221            # VERIFY_DEFAULT should pass
2222            server = ThreadedEchoServer(context=server_context, chatty=True)
2223            with server:
2224                with closing(context.wrap_socket(socket.socket())) as s:
2225                    s.connect((HOST, server.port))
2226                    cert = s.getpeercert()
2227                    self.assertTrue(cert, "Can't get peer certificate.")
2228
2229            # VERIFY_CRL_CHECK_LEAF without a loaded CRL file fails
2230            context.verify_flags |= ssl.VERIFY_CRL_CHECK_LEAF
2231
2232            server = ThreadedEchoServer(context=server_context, chatty=True)
2233            with server:
2234                with closing(context.wrap_socket(socket.socket())) as s:
2235                    with self.assertRaisesRegexp(ssl.SSLError,
2236                                                "certificate verify failed"):
2237                        s.connect((HOST, server.port))
2238
2239            # now load a CRL file. The CRL file is signed by the CA.
2240            context.load_verify_locations(CRLFILE)
2241
2242            server = ThreadedEchoServer(context=server_context, chatty=True)
2243            with server:
2244                with closing(context.wrap_socket(socket.socket())) as s:
2245                    s.connect((HOST, server.port))
2246                    cert = s.getpeercert()
2247                    self.assertTrue(cert, "Can't get peer certificate.")
2248
2249        def test_check_hostname(self):
2250            if support.verbose:
2251                sys.stdout.write("\n")
2252
2253            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2254            server_context.load_cert_chain(SIGNED_CERTFILE)
2255
2256            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2257            context.verify_mode = ssl.CERT_REQUIRED
2258            context.check_hostname = True
2259            context.load_verify_locations(SIGNING_CA)
2260
2261            # correct hostname should verify
2262            server = ThreadedEchoServer(context=server_context, chatty=True)
2263            with server:
2264                with closing(context.wrap_socket(socket.socket(),
2265                                                 server_hostname="localhost")) as s:
2266                    s.connect((HOST, server.port))
2267                    cert = s.getpeercert()
2268                    self.assertTrue(cert, "Can't get peer certificate.")
2269
2270            # incorrect hostname should raise an exception
2271            server = ThreadedEchoServer(context=server_context, chatty=True)
2272            with server:
2273                with closing(context.wrap_socket(socket.socket(),
2274                                                 server_hostname="invalid")) as s:
2275                    with self.assertRaisesRegexp(ssl.CertificateError,
2276                                                "hostname 'invalid' doesn't match u?'localhost'"):
2277                        s.connect((HOST, server.port))
2278
2279            # missing server_hostname arg should cause an exception, too
2280            server = ThreadedEchoServer(context=server_context, chatty=True)
2281            with server:
2282                with closing(socket.socket()) as s:
2283                    with self.assertRaisesRegexp(ValueError,
2284                                                "check_hostname requires server_hostname"):
2285                        context.wrap_socket(s)
2286
2287        def test_wrong_cert(self):
2288            """Connecting when the server rejects the client's certificate
2289
2290            Launch a server with CERT_REQUIRED, and check that trying to
2291            connect to it with a wrong client certificate fails.
2292            """
2293            certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
2294                                       "keycert.pem")
2295            server = ThreadedEchoServer(SIGNED_CERTFILE,
2296                                        certreqs=ssl.CERT_REQUIRED,
2297                                        cacerts=SIGNING_CA, chatty=False,
2298                                        connectionchatty=False)
2299            with server, \
2300                    closing(socket.socket()) as sock, \
2301                    closing(ssl.wrap_socket(sock,
2302                                        certfile=certfile,
2303                                        ssl_version=ssl.PROTOCOL_TLSv1)) as s:
2304                try:
2305                    # Expect either an SSL error about the server rejecting
2306                    # the connection, or a low-level connection reset (which
2307                    # sometimes happens on Windows)
2308                    s.connect((HOST, server.port))
2309                except ssl.SSLError as e:
2310                    if support.verbose:
2311                        sys.stdout.write("\nSSLError is %r\n" % e)
2312                except socket.error as e:
2313                    if e.errno != errno.ECONNRESET:
2314                        raise
2315                    if support.verbose:
2316                        sys.stdout.write("\nsocket.error is %r\n" % e)
2317                else:
2318                    self.fail("Use of invalid cert should have failed!")
2319
2320        def test_rude_shutdown(self):
2321            """A brutal shutdown of an SSL server should raise an OSError
2322            in the client when attempting handshake.
2323            """
2324            listener_ready = threading.Event()
2325            listener_gone = threading.Event()
2326
2327            s = socket.socket()
2328            port = support.bind_port(s, HOST)
2329
2330            # `listener` runs in a thread.  It sits in an accept() until
2331            # the main thread connects.  Then it rudely closes the socket,
2332            # and sets Event `listener_gone` to let the main thread know
2333            # the socket is gone.
2334            def listener():
2335                s.listen(5)
2336                listener_ready.set()
2337                newsock, addr = s.accept()
2338                newsock.close()
2339                s.close()
2340                listener_gone.set()
2341
2342            def connector():
2343                listener_ready.wait()
2344                with closing(socket.socket()) as c:
2345                    c.connect((HOST, port))
2346                    listener_gone.wait()
2347                    try:
2348                        ssl_sock = ssl.wrap_socket(c)
2349                    except socket.error:
2350                        pass
2351                    else:
2352                        self.fail('connecting to closed SSL socket should have failed')
2353
2354            t = threading.Thread(target=listener)
2355            t.start()
2356            try:
2357                connector()
2358            finally:
2359                t.join()
2360
2361        @skip_if_broken_ubuntu_ssl
2362        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'),
2363                             "OpenSSL is compiled without SSLv2 support")
2364        def test_protocol_sslv2(self):
2365            """Connecting to an SSLv2 server with various client options"""
2366            if support.verbose:
2367                sys.stdout.write("\n")
2368            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True)
2369            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_OPTIONAL)
2370            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv2, True, ssl.CERT_REQUIRED)
2371            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False)
2372            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv3, False)
2373            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_TLSv1, False)
2374            # SSLv23 client with specific SSL options
2375            if no_sslv2_implies_sslv3_hello():
2376                # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
2377                try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2378                                   client_options=ssl.OP_NO_SSLv2)
2379            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2380                               client_options=ssl.OP_NO_SSLv3)
2381            try_protocol_combo(ssl.PROTOCOL_SSLv2, ssl.PROTOCOL_SSLv23, False,
2382                               client_options=ssl.OP_NO_TLSv1)
2383
2384        @skip_if_broken_ubuntu_ssl
2385        @skip_if_openssl_cnf_minprotocol_gt_tls1
2386        def test_protocol_sslv23(self):
2387            """Connecting to an SSLv23 server with various client options"""
2388            if support.verbose:
2389                sys.stdout.write("\n")
2390            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2391                try:
2392                    try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv2, True)
2393                except socket.error as x:
2394                    # this fails on some older versions of OpenSSL (0.9.7l, for instance)
2395                    if support.verbose:
2396                        sys.stdout.write(
2397                            " SSL2 client to SSL23 server test unexpectedly failed:\n %s\n"
2398                            % str(x))
2399            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2400                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False)
2401            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True)
2402            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1')
2403
2404            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2405                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_OPTIONAL)
2406            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_OPTIONAL)
2407            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
2408
2409            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2410                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False, ssl.CERT_REQUIRED)
2411            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True, ssl.CERT_REQUIRED)
2412            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
2413
2414            # Server with specific SSL options
2415            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2416                try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv3, False,
2417                               server_options=ssl.OP_NO_SSLv3)
2418            # Will choose TLSv1
2419            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23, True,
2420                               server_options=ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
2421            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1, False,
2422                               server_options=ssl.OP_NO_TLSv1)
2423
2424
2425        @skip_if_broken_ubuntu_ssl
2426        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv3'),
2427                             "OpenSSL is compiled without SSLv3 support")
2428        def test_protocol_sslv3(self):
2429            """Connecting to an SSLv3 server with various client options"""
2430            if support.verbose:
2431                sys.stdout.write("\n")
2432            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3')
2433            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_OPTIONAL)
2434            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv3, 'SSLv3', ssl.CERT_REQUIRED)
2435            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2436                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv2, False)
2437            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23, False,
2438                               client_options=ssl.OP_NO_SSLv3)
2439            try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_TLSv1, False)
2440            if no_sslv2_implies_sslv3_hello():
2441                # No SSLv2 => client will use an SSLv3 hello on recent OpenSSLs
2442                try_protocol_combo(ssl.PROTOCOL_SSLv3, ssl.PROTOCOL_SSLv23,
2443                                   False, client_options=ssl.OP_NO_SSLv2)
2444
2445        @skip_if_broken_ubuntu_ssl
2446        def test_protocol_tlsv1(self):
2447            """Connecting to a TLSv1 server with various client options"""
2448            if support.verbose:
2449                sys.stdout.write("\n")
2450            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1')
2451            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_OPTIONAL)
2452            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1, 'TLSv1', ssl.CERT_REQUIRED)
2453            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2454                try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv2, False)
2455            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2456                try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv3, False)
2457            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_SSLv23, False,
2458                               client_options=ssl.OP_NO_TLSv1)
2459
2460        @skip_if_broken_ubuntu_ssl
2461        @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_1"),
2462                             "TLS version 1.1 not supported.")
2463        @skip_if_openssl_cnf_minprotocol_gt_tls1
2464        def test_protocol_tlsv1_1(self):
2465            """Connecting to a TLSv1.1 server with various client options.
2466               Testing against older TLS versions."""
2467            if support.verbose:
2468                sys.stdout.write("\n")
2469            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
2470            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2471                try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv2, False)
2472            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2473                try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv3, False)
2474            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_SSLv23, False,
2475                               client_options=ssl.OP_NO_TLSv1_1)
2476
2477            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_1, 'TLSv1.1')
2478            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1, False)
2479            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_1, False)
2480
2481
2482        @skip_if_broken_ubuntu_ssl
2483        @unittest.skipUnless(hasattr(ssl, "PROTOCOL_TLSv1_2"),
2484                             "TLS version 1.2 not supported.")
2485        def test_protocol_tlsv1_2(self):
2486            """Connecting to a TLSv1.2 server with various client options.
2487               Testing against older TLS versions."""
2488            if support.verbose:
2489                sys.stdout.write("\n")
2490            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2',
2491                               server_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,
2492                               client_options=ssl.OP_NO_SSLv3|ssl.OP_NO_SSLv2,)
2493            if hasattr(ssl, 'PROTOCOL_SSLv2'):
2494                try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv2, False)
2495            if hasattr(ssl, 'PROTOCOL_SSLv3'):
2496                try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv3, False)
2497            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_SSLv23, False,
2498                               client_options=ssl.OP_NO_TLSv1_2)
2499
2500            try_protocol_combo(ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_TLSv1_2, 'TLSv1.2')
2501            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1, False)
2502            try_protocol_combo(ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1_2, False)
2503            try_protocol_combo(ssl.PROTOCOL_TLSv1_2, ssl.PROTOCOL_TLSv1_1, False)
2504            try_protocol_combo(ssl.PROTOCOL_TLSv1_1, ssl.PROTOCOL_TLSv1_2, False)
2505
        def test_starttls(self):
            """Switching from clear text to encrypted and back again."""
            # The STARTTLS / ENDTLS entries trigger a mode switch when the
            # server answers with something starting with "ok"; replies to
            # the other messages are only logged, not checked.
            msgs = (b"msg 1", b"MSG 2", b"STARTTLS", b"MSG 3", b"msg 4", b"ENDTLS", b"msg 5", b"msg 6")

            server = ThreadedEchoServer(CERTFILE,
                                        ssl_version=ssl.PROTOCOL_TLSv1,
                                        starttls_server=True,
                                        chatty=True,
                                        connectionchatty=True)
            # True while `conn` (the TLS-wrapped socket) is the active channel;
            # False while the plain socket `s` is used.
            wrapped = False
            with server:
                s = socket.socket()
                s.setblocking(1)
                s.connect((HOST, server.port))
                if support.verbose:
                    sys.stdout.write("\n")
                for indata in msgs:
                    if support.verbose:
                        sys.stdout.write(
                            " client:  sending %r...\n" % indata)
                    if wrapped:
                        conn.write(indata)
                        outdata = conn.read()
                    else:
                        s.send(indata)
                        # NOTE(review): assumes each clear-text reply arrives
                        # in a single recv() of at most 1024 bytes.
                        outdata = s.recv(1024)
                    msg = outdata.strip().lower()
                    if indata == b"STARTTLS" and msg.startswith(b"ok"):
                        # STARTTLS ok, switch to secure mode
                        if support.verbose:
                            sys.stdout.write(
                                " client:  read %r from server, starting TLS...\n"
                                % msg)
                        conn = ssl.wrap_socket(s, ssl_version=ssl.PROTOCOL_TLSv1)
                        wrapped = True
                    elif indata == b"ENDTLS" and msg.startswith(b"ok"):
                        # ENDTLS ok, switch back to clear text
                        if support.verbose:
                            sys.stdout.write(
                                " client:  read %r from server, ending TLS...\n"
                                % msg)
                        # unwrap() shuts TLS down and returns the plain socket
                        # used for the remaining clear-text messages.
                        s = conn.unwrap()
                        wrapped = False
                    else:
                        if support.verbose:
                            sys.stdout.write(
                                " client:  read %r from server\n" % msg)
                if support.verbose:
                    sys.stdout.write(" client:  closing connection.\n")
                # Tell the server to end the session over whichever channel
                # is currently active, then close that channel.
                if wrapped:
                    conn.write(b"over\n")
                else:
                    s.send(b"over\n")
                if wrapped:
                    conn.close()
                else:
                    s.close()
2563
2564        def test_socketserver(self):
2565            """Using a SocketServer to create and manage SSL connections."""
2566            server = make_https_server(self, certfile=CERTFILE)
2567            # try to connect
2568            if support.verbose:
2569                sys.stdout.write('\n')
2570            with open(CERTFILE, 'rb') as f:
2571                d1 = f.read()
2572            d2 = ''
2573            # now fetch the same data from the HTTPS server
2574            url = 'https://localhost:%d/%s' % (
2575                server.port, os.path.split(CERTFILE)[1])
2576            context = ssl.create_default_context(cafile=CERTFILE)
2577            f = urllib2.urlopen(url, context=context)
2578            try:
2579                dlen = f.info().getheader("content-length")
2580                if dlen and (int(dlen) > 0):
2581                    d2 = f.read(int(dlen))
2582                    if support.verbose:
2583                        sys.stdout.write(
2584                            " client: read %d bytes from remote server '%s'\n"
2585                            % (len(d2), server))
2586            finally:
2587                f.close()
2588            self.assertEqual(d1, d2)
2589
2590        def test_asyncore_server(self):
2591            """Check the example asyncore integration."""
2592            if support.verbose:
2593                sys.stdout.write("\n")
2594
2595            indata = b"FOO\n"
2596            server = AsyncoreEchoServer(CERTFILE)
2597            with server:
2598                s = ssl.wrap_socket(socket.socket())
2599                s.connect(('127.0.0.1', server.port))
2600                if support.verbose:
2601                    sys.stdout.write(
2602                        " client:  sending %r...\n" % indata)
2603                s.write(indata)
2604                outdata = s.read()
2605                if support.verbose:
2606                    sys.stdout.write(" client:  read %r\n" % outdata)
2607                if outdata != indata.lower():
2608                    self.fail(
2609                        "bad data <<%r>> (%d) received; expected <<%r>> (%d)\n"
2610                        % (outdata[:20], len(outdata),
2611                           indata[:20].lower(), len(indata)))
2612                s.write(b"over\n")
2613                if support.verbose:
2614                    sys.stdout.write(" client:  closing connection.\n")
2615                s.close()
2616                if support.verbose:
2617                    sys.stdout.write(" client:  connection closed.\n")
2618
2619        def test_recv_send(self):
2620            """Test recv(), send() and friends."""
2621            if support.verbose:
2622                sys.stdout.write("\n")
2623
2624            server = ThreadedEchoServer(CERTFILE,
2625                                        certreqs=ssl.CERT_NONE,
2626                                        ssl_version=ssl.PROTOCOL_TLSv1,
2627                                        cacerts=CERTFILE,
2628                                        chatty=True,
2629                                        connectionchatty=False)
2630            with server:
2631                s = ssl.wrap_socket(socket.socket(),
2632                                    server_side=False,
2633                                    certfile=CERTFILE,
2634                                    ca_certs=CERTFILE,
2635                                    cert_reqs=ssl.CERT_NONE,
2636                                    ssl_version=ssl.PROTOCOL_TLSv1)
2637                s.connect((HOST, server.port))
2638                # helper methods for standardising recv* method signatures
2639                def _recv_into():
2640                    b = bytearray(b"\0"*100)
2641                    count = s.recv_into(b)
2642                    return b[:count]
2643
2644                def _recvfrom_into():
2645                    b = bytearray(b"\0"*100)
2646                    count, addr = s.recvfrom_into(b)
2647                    return b[:count]
2648
2649                # (name, method, whether to expect success, *args)
2650                send_methods = [
2651                    ('send', s.send, True, []),
2652                    ('sendto', s.sendto, False, ["some.address"]),
2653                    ('sendall', s.sendall, True, []),
2654                ]
2655                recv_methods = [
2656                    ('recv', s.recv, True, []),
2657                    ('recvfrom', s.recvfrom, False, ["some.address"]),
2658                    ('recv_into', _recv_into, True, []),
2659                    ('recvfrom_into', _recvfrom_into, False, []),
2660                ]
2661                data_prefix = u"PREFIX_"
2662
2663                for meth_name, send_meth, expect_success, args in send_methods:
2664                    indata = (data_prefix + meth_name).encode('ascii')
2665                    try:
2666                        send_meth(indata, *args)
2667                        outdata = s.read()
2668                        if outdata != indata.lower():
2669                            self.fail(
2670                                "While sending with <<{name:s}>> bad data "
2671                                "<<{outdata:r}>> ({nout:d}) received; "
2672                                "expected <<{indata:r}>> ({nin:d})\n".format(
2673                                    name=meth_name, outdata=outdata[:20],
2674                                    nout=len(outdata),
2675                                    indata=indata[:20], nin=len(indata)
2676                                )
2677                            )
2678                    except ValueError as e:
2679                        if expect_success:
2680                            self.fail(
2681                                "Failed to send with method <<{name:s}>>; "
2682                                "expected to succeed.\n".format(name=meth_name)
2683                            )
2684                        if not str(e).startswith(meth_name):
2685                            self.fail(
2686                                "Method <<{name:s}>> failed with unexpected "
2687                                "exception message: {exp:s}\n".format(
2688                                    name=meth_name, exp=e
2689                                )
2690                            )
2691
2692                for meth_name, recv_meth, expect_success, args in recv_methods:
2693                    indata = (data_prefix + meth_name).encode('ascii')
2694                    try:
2695                        s.send(indata)
2696                        outdata = recv_meth(*args)
2697                        if outdata != indata.lower():
2698                            self.fail(
2699                                "While receiving with <<{name:s}>> bad data "
2700                                "<<{outdata:r}>> ({nout:d}) received; "
2701                                "expected <<{indata:r}>> ({nin:d})\n".format(
2702                                    name=meth_name, outdata=outdata[:20],
2703                                    nout=len(outdata),
2704                                    indata=indata[:20], nin=len(indata)
2705                                )
2706                            )
2707                    except ValueError as e:
2708                        if expect_success:
2709                            self.fail(
2710                                "Failed to receive with method <<{name:s}>>; "
2711                                "expected to succeed.\n".format(name=meth_name)
2712                            )
2713                        if not str(e).startswith(meth_name):
2714                            self.fail(
2715                                "Method <<{name:s}>> failed with unexpected "
2716                                "exception message: {exp:s}\n".format(
2717                                    name=meth_name, exp=e
2718                                )
2719                            )
2720                        # consume data
2721                        s.read()
2722
2723                # read(-1, buffer) is supported, even though read(-1) is not
2724                data = b"data"
2725                s.send(data)
2726                buffer = bytearray(len(data))
2727                self.assertEqual(s.read(-1, buffer), len(data))
2728                self.assertEqual(buffer, data)
2729
2730                self.assertRaises(NotImplementedError, s.dup)
2731                s.write(b"over\n")
2732
2733                self.assertRaises(ValueError, s.recv, -1)
2734                self.assertRaises(ValueError, s.read, -1)
2735
2736                s.close()
2737
2738        def test_recv_zero(self):
2739            server = ThreadedEchoServer(CERTFILE)
2740            server.__enter__()
2741            self.addCleanup(server.__exit__, None, None)
2742            s = socket.create_connection((HOST, server.port))
2743            self.addCleanup(s.close)
2744            s = ssl.wrap_socket(s, suppress_ragged_eofs=False)
2745            self.addCleanup(s.close)
2746
2747            # recv/read(0) should return no data
2748            s.send(b"data")
2749            self.assertEqual(s.recv(0), b"")
2750            self.assertEqual(s.read(0), b"")
2751            self.assertEqual(s.read(), b"data")
2752
2753            # Should not block if the other end sends no data
2754            s.setblocking(False)
2755            self.assertEqual(s.recv(0), b"")
2756            self.assertEqual(s.recv_into(bytearray()), 0)
2757
2758        def test_handshake_timeout(self):
2759            # Issue #5103: SSL handshake must respect the socket timeout
2760            server = socket.socket(socket.AF_INET)
2761            host = "127.0.0.1"
2762            port = support.bind_port(server)
2763            started = threading.Event()
2764            finish = False
2765
2766            def serve():
2767                server.listen(5)
2768                started.set()
2769                conns = []
2770                while not finish:
2771                    r, w, e = select.select([server], [], [], 0.1)
2772                    if server in r:
2773                        # Let the socket hang around rather than having
2774                        # it closed by garbage collection.
2775                        conns.append(server.accept()[0])
2776                for sock in conns:
2777                    sock.close()
2778
2779            t = threading.Thread(target=serve)
2780            t.start()
2781            started.wait()
2782
2783            try:
2784                try:
2785                    c = socket.socket(socket.AF_INET)
2786                    c.settimeout(0.2)
2787                    c.connect((host, port))
2788                    # Will attempt handshake and time out
2789                    self.assertRaisesRegexp(ssl.SSLError, "timed out",
2790                                            ssl.wrap_socket, c)
2791                finally:
2792                    c.close()
2793                try:
2794                    c = socket.socket(socket.AF_INET)
2795                    c = ssl.wrap_socket(c)
2796                    c.settimeout(0.2)
2797                    # Will attempt handshake and time out
2798                    self.assertRaisesRegexp(ssl.SSLError, "timed out",
2799                                            c.connect, (host, port))
2800                finally:
2801                    c.close()
2802            finally:
2803                finish = True
2804                t.join()
2805                server.close()
2806
2807        def test_server_accept(self):
2808            # Issue #16357: accept() on a SSLSocket created through
2809            # SSLContext.wrap_socket().
2810            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2811            context.verify_mode = ssl.CERT_REQUIRED
2812            context.load_verify_locations(CERTFILE)
2813            context.load_cert_chain(CERTFILE)
2814            server = socket.socket(socket.AF_INET)
2815            host = "127.0.0.1"
2816            port = support.bind_port(server)
2817            server = context.wrap_socket(server, server_side=True)
2818
2819            evt = threading.Event()
2820            remote = [None]
2821            peer = [None]
2822            def serve():
2823                server.listen(5)
2824                # Block on the accept and wait on the connection to close.
2825                evt.set()
2826                remote[0], peer[0] = server.accept()
2827                remote[0].send(remote[0].recv(4))
2828
2829            t = threading.Thread(target=serve)
2830            t.start()
2831            # Client wait until server setup and perform a connect.
2832            evt.wait()
2833            client = context.wrap_socket(socket.socket())
2834            client.connect((host, port))
2835            client.send(b'data')
2836            client.recv()
2837            client_addr = client.getsockname()
2838            client.close()
2839            t.join()
2840            remote[0].close()
2841            server.close()
2842            # Sanity checks.
2843            self.assertIsInstance(remote[0], ssl.SSLSocket)
2844            self.assertEqual(peer[0], client_addr)
2845
2846        def test_getpeercert_enotconn(self):
2847            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2848            with closing(context.wrap_socket(socket.socket())) as sock:
2849                with self.assertRaises(socket.error) as cm:
2850                    sock.getpeercert()
2851                self.assertEqual(cm.exception.errno, errno.ENOTCONN)
2852
2853        def test_do_handshake_enotconn(self):
2854            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2855            with closing(context.wrap_socket(socket.socket())) as sock:
2856                with self.assertRaises(socket.error) as cm:
2857                    sock.do_handshake()
2858                self.assertEqual(cm.exception.errno, errno.ENOTCONN)
2859
2860        def test_no_shared_ciphers(self):
2861            server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2862            server_context.load_cert_chain(SIGNED_CERTFILE)
2863            client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2864            client_context.verify_mode = ssl.CERT_REQUIRED
2865            client_context.check_hostname = True
2866
2867            # OpenSSL enables all TLS 1.3 ciphers, enforce TLS 1.2 for test
2868            client_context.options |= ssl.OP_NO_TLSv1_3
2869            # Force different suites on client and master
2870            client_context.set_ciphers("AES128")
2871            server_context.set_ciphers("AES256")
2872            with ThreadedEchoServer(context=server_context) as server:
2873                s = client_context.wrap_socket(
2874                        socket.socket(),
2875                        server_hostname="localhost")
2876                with self.assertRaises(ssl.SSLError):
2877                    s.connect((HOST, server.port))
2878            self.assertIn("no shared cipher", str(server.conn_errors[0]))
2879
2880        def test_version_basic(self):
2881            """
2882            Basic tests for SSLSocket.version().
2883            More tests are done in the test_protocol_*() methods.
2884            """
2885            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2886            with ThreadedEchoServer(CERTFILE,
2887                                    ssl_version=ssl.PROTOCOL_TLSv1,
2888                                    chatty=False) as server:
2889                with closing(context.wrap_socket(socket.socket())) as s:
2890                    self.assertIs(s.version(), None)
2891                    s.connect((HOST, server.port))
2892                    self.assertEqual(s.version(), 'TLSv1')
2893                self.assertIs(s.version(), None)
2894
2895        @unittest.skipUnless(ssl.HAS_TLSv1_3,
2896                             "test requires TLSv1.3 enabled OpenSSL")
2897        def test_tls1_3(self):
2898            context = ssl.SSLContext(ssl.PROTOCOL_TLS)
2899            context.load_cert_chain(CERTFILE)
2900            # disable all but TLS 1.3
2901            context.options |= (
2902                ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
2903            )
2904            with ThreadedEchoServer(context=context) as server:
2905                s = context.wrap_socket(socket.socket())
2906                with closing(s):
2907                    s.connect((HOST, server.port))
2908                    self.assertIn(s.cipher()[0], [
2909                        'TLS_AES_256_GCM_SHA384',
2910                        'TLS_CHACHA20_POLY1305_SHA256',
2911                        'TLS_AES_128_GCM_SHA256',
2912                    ])
2913
2914        @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
2915        def test_default_ecdh_curve(self):
2916            # Issue #21015: elliptic curve-based Diffie Hellman key exchange
2917            # should be enabled by default on SSL contexts.
2918            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2919            context.load_cert_chain(CERTFILE)
2920            # TLSv1.3 defaults to PFS key agreement and no longer has KEA in
2921            # cipher name.
2922            context.options |= ssl.OP_NO_TLSv1_3
2923            # Prior to OpenSSL 1.0.0, ECDH ciphers have to be enabled
2924            # explicitly using the 'ECCdraft' cipher alias.  Otherwise,
2925            # our default cipher list should prefer ECDH-based ciphers
2926            # automatically.
2927            if ssl.OPENSSL_VERSION_INFO < (1, 0, 0):
2928                context.set_ciphers("ECCdraft:ECDH")
2929            with ThreadedEchoServer(context=context) as server:
2930                with closing(context.wrap_socket(socket.socket())) as s:
2931                    s.connect((HOST, server.port))
2932                    self.assertIn("ECDH", s.cipher()[0])
2933
2934        @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
2935                             "'tls-unique' channel binding not available")
2936        def test_tls_unique_channel_binding(self):
2937            """Test tls-unique channel binding."""
2938            if support.verbose:
2939                sys.stdout.write("\n")
2940
2941            server = ThreadedEchoServer(CERTFILE,
2942                                        certreqs=ssl.CERT_NONE,
2943                                        ssl_version=ssl.PROTOCOL_TLSv1,
2944                                        cacerts=CERTFILE,
2945                                        chatty=True,
2946                                        connectionchatty=False)
2947            with server:
2948                s = ssl.wrap_socket(socket.socket(),
2949                                    server_side=False,
2950                                    certfile=CERTFILE,
2951                                    ca_certs=CERTFILE,
2952                                    cert_reqs=ssl.CERT_NONE,
2953                                    ssl_version=ssl.PROTOCOL_TLSv1)
2954                s.connect((HOST, server.port))
2955                # get the data
2956                cb_data = s.get_channel_binding("tls-unique")
2957                if support.verbose:
2958                    sys.stdout.write(" got channel binding data: {0!r}\n"
2959                                     .format(cb_data))
2960
2961                # check if it is sane
2962                self.assertIsNotNone(cb_data)
2963                self.assertEqual(len(cb_data), 12) # True for TLSv1
2964
2965                # and compare with the peers version
2966                s.write(b"CB tls-unique\n")
2967                peer_data_repr = s.read().strip()
2968                self.assertEqual(peer_data_repr,
2969                                 repr(cb_data).encode("us-ascii"))
2970                s.close()
2971
2972                # now, again
2973                s = ssl.wrap_socket(socket.socket(),
2974                                    server_side=False,
2975                                    certfile=CERTFILE,
2976                                    ca_certs=CERTFILE,
2977                                    cert_reqs=ssl.CERT_NONE,
2978                                    ssl_version=ssl.PROTOCOL_TLSv1)
2979                s.connect((HOST, server.port))
2980                new_cb_data = s.get_channel_binding("tls-unique")
2981                if support.verbose:
2982                    sys.stdout.write(" got another channel binding data: {0!r}\n"
2983                                     .format(new_cb_data))
2984                # is it really unique
2985                self.assertNotEqual(cb_data, new_cb_data)
2986                self.assertIsNotNone(cb_data)
2987                self.assertEqual(len(cb_data), 12) # True for TLSv1
2988                s.write(b"CB tls-unique\n")
2989                peer_data_repr = s.read().strip()
2990                self.assertEqual(peer_data_repr,
2991                                 repr(new_cb_data).encode("us-ascii"))
2992                s.close()
2993
2994        def test_compression(self):
2995            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
2996            context.load_cert_chain(CERTFILE)
2997            stats = server_params_test(context, context,
2998                                       chatty=True, connectionchatty=True)
2999            if support.verbose:
3000                sys.stdout.write(" got compression: {!r}\n".format(stats['compression']))
3001            self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' })
3002
3003        @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'),
3004                             "ssl.OP_NO_COMPRESSION needed for this test")
3005        def test_compression_disabled(self):
3006            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3007            context.load_cert_chain(CERTFILE)
3008            context.options |= ssl.OP_NO_COMPRESSION
3009            stats = server_params_test(context, context,
3010                                       chatty=True, connectionchatty=True)
3011            self.assertIs(stats['compression'], None)
3012
3013        def test_dh_params(self):
3014            # Check we can get a connection with ephemeral Diffie-Hellman
3015            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3016            context.load_cert_chain(CERTFILE)
3017            context.load_dh_params(DHFILE)
3018            context.set_ciphers("kEDH")
3019            stats = server_params_test(context, context,
3020                                       chatty=True, connectionchatty=True)
3021            cipher = stats["cipher"][0]
3022            parts = cipher.split("-")
3023            if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts:
3024                self.fail("Non-DH cipher: " + cipher[0])
3025
3026        def test_selected_alpn_protocol(self):
3027            # selected_alpn_protocol() is None unless ALPN is used.
3028            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3029            context.load_cert_chain(CERTFILE)
3030            stats = server_params_test(context, context,
3031                                       chatty=True, connectionchatty=True)
3032            self.assertIs(stats['client_alpn_protocol'], None)
3033
3034        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support required")
3035        def test_selected_alpn_protocol_if_server_uses_alpn(self):
3036            # selected_alpn_protocol() is None unless ALPN is used by the client.
3037            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3038            client_context.load_verify_locations(CERTFILE)
3039            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3040            server_context.load_cert_chain(CERTFILE)
3041            server_context.set_alpn_protocols(['foo', 'bar'])
3042            stats = server_params_test(client_context, server_context,
3043                                       chatty=True, connectionchatty=True)
3044            self.assertIs(stats['client_alpn_protocol'], None)
3045
3046        @unittest.skipUnless(ssl.HAS_ALPN, "ALPN support needed for this test")
3047        def test_alpn_protocols(self):
3048            server_protocols = ['foo', 'bar', 'milkshake']
3049            protocol_tests = [
3050                (['foo', 'bar'], 'foo'),
3051                (['bar', 'foo'], 'foo'),
3052                (['milkshake'], 'milkshake'),
3053                (['http/3.0', 'http/4.0'], None)
3054            ]
3055            for client_protocols, expected in protocol_tests:
3056                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
3057                server_context.load_cert_chain(CERTFILE)
3058                server_context.set_alpn_protocols(server_protocols)
3059                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
3060                client_context.load_cert_chain(CERTFILE)
3061                client_context.set_alpn_protocols(client_protocols)
3062
3063                try:
3064                    stats = server_params_test(client_context,
3065                                               server_context,
3066                                               chatty=True,
3067                                               connectionchatty=True)
3068                except ssl.SSLError as e:
3069                    stats = e
3070
3071                if (expected is None and IS_OPENSSL_1_1
3072                        and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
3073                    # OpenSSL 1.1.0 to 1.1.0e raises handshake error
3074                    self.assertIsInstance(stats, ssl.SSLError)
3075                else:
3076                    msg = "failed trying %s (s) and %s (c).\n" \
3077                        "was expecting %s, but got %%s from the %%s" \
3078                            % (str(server_protocols), str(client_protocols),
3079                                str(expected))
3080                    client_result = stats['client_alpn_protocol']
3081                    self.assertEqual(client_result, expected,
3082                                     msg % (client_result, "client"))
3083                    server_result = stats['server_alpn_protocols'][-1] \
3084                        if len(stats['server_alpn_protocols']) else 'nothing'
3085                    self.assertEqual(server_result, expected,
3086                                     msg % (server_result, "server"))
3087
3088        def test_selected_npn_protocol(self):
3089            # selected_npn_protocol() is None unless NPN is used
3090            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3091            context.load_cert_chain(CERTFILE)
3092            stats = server_params_test(context, context,
3093                                       chatty=True, connectionchatty=True)
3094            self.assertIs(stats['client_npn_protocol'], None)
3095
3096        @unittest.skipUnless(ssl.HAS_NPN, "NPN support needed for this test")
3097        def test_npn_protocols(self):
3098            server_protocols = ['http/1.1', 'spdy/2']
3099            protocol_tests = [
3100                (['http/1.1', 'spdy/2'], 'http/1.1'),
3101                (['spdy/2', 'http/1.1'], 'http/1.1'),
3102                (['spdy/2', 'test'], 'spdy/2'),
3103                (['abc', 'def'], 'abc')
3104            ]
3105            for client_protocols, expected in protocol_tests:
3106                server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3107                server_context.load_cert_chain(CERTFILE)
3108                server_context.set_npn_protocols(server_protocols)
3109                client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3110                client_context.load_cert_chain(CERTFILE)
3111                client_context.set_npn_protocols(client_protocols)
3112                stats = server_params_test(client_context, server_context,
3113                                           chatty=True, connectionchatty=True)
3114
3115                msg = "failed trying %s (s) and %s (c).\n" \
3116                      "was expecting %s, but got %%s from the %%s" \
3117                          % (str(server_protocols), str(client_protocols),
3118                             str(expected))
3119                client_result = stats['client_npn_protocol']
3120                self.assertEqual(client_result, expected, msg % (client_result, "client"))
3121                server_result = stats['server_npn_protocols'][-1] \
3122                    if len(stats['server_npn_protocols']) else 'nothing'
3123                self.assertEqual(server_result, expected, msg % (server_result, "server"))
3124
3125        def sni_contexts(self):
3126            server_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3127            server_context.load_cert_chain(SIGNED_CERTFILE)
3128            other_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3129            other_context.load_cert_chain(SIGNED_CERTFILE2)
3130            client_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
3131            client_context.verify_mode = ssl.CERT_REQUIRED
3132            client_context.load_verify_locations(SIGNING_CA)
3133            return server_context, other_context, client_context
3134
3135        def check_common_name(self, stats, name):
3136            cert = stats['peercert']
3137            self.assertIn((('commonName', name),), cert['subject'])
3138
3139        @needs_sni
3140        def test_sni_callback(self):
3141            calls = []
3142            server_context, other_context, client_context = self.sni_contexts()
3143
3144            def servername_cb(ssl_sock, server_name, initial_context):
3145                calls.append((server_name, initial_context))
3146                if server_name is not None:
3147                    ssl_sock.context = other_context
3148            server_context.set_servername_callback(servername_cb)
3149
3150            stats = server_params_test(client_context, server_context,
3151                                       chatty=True,
3152                                       sni_name='supermessage')
3153            # The hostname was fetched properly, and the certificate was
3154            # changed for the connection.
3155            self.assertEqual(calls, [("supermessage", server_context)])
3156            # CERTFILE4 was selected
3157            self.check_common_name(stats, 'fakehostname')
3158
3159            calls = []
3160            # The callback is called with server_name=None
3161            stats = server_params_test(client_context, server_context,
3162                                       chatty=True,
3163                                       sni_name=None)
3164            self.assertEqual(calls, [(None, server_context)])
3165            self.check_common_name(stats, 'localhost')
3166
3167            # Check disabling the callback
3168            calls = []
3169            server_context.set_servername_callback(None)
3170
3171            stats = server_params_test(client_context, server_context,
3172                                       chatty=True,
3173                                       sni_name='notfunny')
3174            # Certificate didn't change
3175            self.check_common_name(stats, 'localhost')
3176            self.assertEqual(calls, [])
3177
3178        @needs_sni
3179        def test_sni_callback_alert(self):
3180            # Returning a TLS alert is reflected to the connecting client
3181            server_context, other_context, client_context = self.sni_contexts()
3182
3183            def cb_returning_alert(ssl_sock, server_name, initial_context):
3184                return ssl.ALERT_DESCRIPTION_ACCESS_DENIED
3185            server_context.set_servername_callback(cb_returning_alert)
3186
3187            with self.assertRaises(ssl.SSLError) as cm:
3188                stats = server_params_test(client_context, server_context,
3189                                           chatty=False,
3190                                           sni_name='supermessage')
3191            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_ACCESS_DENIED')
3192
3193        @needs_sni
3194        def test_sni_callback_raising(self):
3195            # Raising fails the connection with a TLS handshake failure alert.
3196            server_context, other_context, client_context = self.sni_contexts()
3197
3198            def cb_raising(ssl_sock, server_name, initial_context):
3199                1.0/0.0
3200            server_context.set_servername_callback(cb_raising)
3201
3202            with self.assertRaises(ssl.SSLError) as cm, \
3203                 support.captured_stderr() as stderr:
3204                stats = server_params_test(client_context, server_context,
3205                                           chatty=False,
3206                                           sni_name='supermessage')
3207            self.assertEqual(cm.exception.reason, 'SSLV3_ALERT_HANDSHAKE_FAILURE')
3208            self.assertIn("ZeroDivisionError", stderr.getvalue())
3209
3210        @needs_sni
3211        def test_sni_callback_wrong_return_type(self):
3212            # Returning the wrong return type terminates the TLS connection
3213            # with an internal error alert.
3214            server_context, other_context, client_context = self.sni_contexts()
3215
3216            def cb_wrong_return_type(ssl_sock, server_name, initial_context):
3217                return "foo"
3218            server_context.set_servername_callback(cb_wrong_return_type)
3219
3220            with self.assertRaises(ssl.SSLError) as cm, \
3221                 support.captured_stderr() as stderr:
3222                stats = server_params_test(client_context, server_context,
3223                                           chatty=False,
3224                                           sni_name='supermessage')
3225            self.assertEqual(cm.exception.reason, 'TLSV1_ALERT_INTERNAL_ERROR')
3226            self.assertIn("TypeError", stderr.getvalue())
3227
3228        def test_read_write_after_close_raises_valuerror(self):
3229            context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
3230            context.verify_mode = ssl.CERT_REQUIRED
3231            context.load_verify_locations(CERTFILE)
3232            context.load_cert_chain(CERTFILE)
3233            server = ThreadedEchoServer(context=context, chatty=False)
3234
3235            with server:
3236                s = context.wrap_socket(socket.socket())
3237                s.connect((HOST, server.port))
3238                s.close()
3239
3240                self.assertRaises(ValueError, s.read, 1024)
3241                self.assertRaises(ValueError, s.write, b'hello')
3242
3243
def test_main(verbose=False):
    """Check the certificate fixtures exist and run the SSL test classes.

    *verbose* is accepted but not used; verbosity is read from
    ``support.verbose``.  Raises ``support.TestFailed`` if a required
    certificate/key file is missing.
    """
    if support.verbose:
        # Probe platforms in a fixed order (dict iteration order is arbitrary
        # on Python 2) and report the first one that identifies itself.
        plats = [
            ('Linux', platform.linux_distribution),
            ('Mac', platform.mac_ver),
            ('Windows', platform.win32_ver),
        ]
        for name, func in plats:
            plat = func()
            if plat and plat[0]:
                plat = '%s %r' % (name, plat)
                break
        else:
            plat = repr(platform.platform())
        print("test_ssl: testing with %r %r" %
            (ssl.OPENSSL_VERSION, ssl.OPENSSL_VERSION_INFO))
        print("          under %s" % plat)
        print("          HAS_SNI = %r" % ssl.HAS_SNI)
        # Zero-pad so small values don't print as e.g. "0x     400".
        print("          OP_ALL = 0x%08x" % ssl.OP_ALL)
        try:
            print("          OP_NO_TLSv1_1 = 0x%08x" % ssl.OP_NO_TLSv1_1)
        except AttributeError:
            pass

    for filename in [
        CERTFILE, REMOTE_ROOT_CERT, BYTES_CERTFILE,
        ONLYCERT, ONLYKEY, BYTES_ONLYCERT, BYTES_ONLYKEY,
        SIGNED_CERTFILE, SIGNED_CERTFILE2, SIGNING_CA,
        BADCERT, BADKEY, EMPTYCERT]:
        if not os.path.exists(filename):
            raise support.TestFailed("Can't read certificate file %r" % filename)

    tests = [ContextTests, BasicTests, BasicSocketTests, SSLErrorTests]

    if support.is_resource_enabled('network'):
        tests.append(NetworkedTests)

    if _have_threads:
        thread_info = support.threading_setup()
        if thread_info:
            tests.append(ThreadedTests)

    try:
        support.run_unittest(*tests)
    finally:
        if _have_threads:
            support.threading_cleanup(*thread_info)
3291
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    test_main()
3294