1# Wrapper module for _ssl, providing some additional facilities
2# implemented in Python.  Written by Bill Janssen.
3
4"""This module provides some more Pythonic support for SSL.
5
6Object types:
7
8  SSLSocket -- subtype of socket.socket which does SSL over the socket
9
10Exceptions:
11
12  SSLError -- exception raised for I/O errors
13
14Functions:
15
  cert_time_to_seconds -- convert a time string used in the certificate
                          notBefore and notAfter fields to an integer
                          number of seconds past the Epoch (the time
                          values returned from time.time())
20
21  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
22                          by the server running on HOST at port PORT.  No
23                          validation of the certificate is performed.
24
25Integer constants:
26
27SSL_ERROR_ZERO_RETURN
28SSL_ERROR_WANT_READ
29SSL_ERROR_WANT_WRITE
30SSL_ERROR_WANT_X509_LOOKUP
31SSL_ERROR_SYSCALL
32SSL_ERROR_SSL
33SSL_ERROR_WANT_CONNECT
34
35SSL_ERROR_EOF
36SSL_ERROR_INVALID_ERROR_CODE
37
The following constants define the certificate requirements that one side
allows or requires from the other side:
40
41CERT_NONE - no certificates from the other side are required (or will
42            be looked at if provided)
43CERT_OPTIONAL - certificates are not required, but if provided will be
44                validated, and if validation fails, the connection will
45                also fail
46CERT_REQUIRED - certificates are required, and will be validated, and
47                if validation fails, the connection will also fail
48
49The following constants identify various SSL protocol variants:
50
51PROTOCOL_SSLv2
52PROTOCOL_SSLv3
53PROTOCOL_SSLv23
54PROTOCOL_TLS
55PROTOCOL_TLSv1
56PROTOCOL_TLSv1_1
57PROTOCOL_TLSv1_2
58
59The following constants identify various SSL alert message descriptions as per
60http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6
61
62ALERT_DESCRIPTION_CLOSE_NOTIFY
63ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
64ALERT_DESCRIPTION_BAD_RECORD_MAC
65ALERT_DESCRIPTION_RECORD_OVERFLOW
66ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
67ALERT_DESCRIPTION_HANDSHAKE_FAILURE
68ALERT_DESCRIPTION_BAD_CERTIFICATE
69ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
70ALERT_DESCRIPTION_CERTIFICATE_REVOKED
71ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
72ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
73ALERT_DESCRIPTION_ILLEGAL_PARAMETER
74ALERT_DESCRIPTION_UNKNOWN_CA
75ALERT_DESCRIPTION_ACCESS_DENIED
76ALERT_DESCRIPTION_DECODE_ERROR
77ALERT_DESCRIPTION_DECRYPT_ERROR
78ALERT_DESCRIPTION_PROTOCOL_VERSION
79ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
80ALERT_DESCRIPTION_INTERNAL_ERROR
81ALERT_DESCRIPTION_USER_CANCELLED
82ALERT_DESCRIPTION_NO_RENEGOTIATION
83ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
84ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
85ALERT_DESCRIPTION_UNRECOGNIZED_NAME
86ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
87ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
88ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
89"""
90
91import textwrap
92import re
93import sys
94import os
95from collections import namedtuple
96from contextlib import closing
97
98import _ssl             # if we can't import it, let the error propagate
99
100from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
101from _ssl import _SSLContext
102from _ssl import (
103    SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
104    SSLSyscallError, SSLEOFError,
105    )
106from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
107from _ssl import txt2obj as _txt2obj, nid2obj as _nid2obj
108from _ssl import RAND_status, RAND_add
109try:
110    from _ssl import RAND_egd
111except ImportError:
112    # LibreSSL does not provide RAND_egd
113    pass
114
def _import_symbols(prefix):
    """Re-export every ``_ssl`` attribute whose name starts with *prefix*.

    The attributes are copied into this module's global namespace under
    their original names, in ``dir(_ssl)`` order.
    """
    namespace = globals()
    for attr_name in dir(_ssl):
        if attr_name.startswith(prefix):
            namespace[attr_name] = getattr(_ssl, attr_name)

# Pull the option, alert, error-code, protocol and verification-flag
# constants up from the C module.
for _prefix in ('OP_', 'ALERT_DESCRIPTION_', 'SSL_ERROR_', 'PROTOCOL_',
                'VERIFY_'):
    _import_symbols(_prefix)
del _prefix
125
126from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN, HAS_TLSv1_3
127
128from _ssl import _OPENSSL_API_VERSION
129
# Reverse lookup from protocol number to constant name.  PROTOCOL_SSLv23 is
# skipped because it is re-bound to PROTOCOL_TLS's value just below, so the
# shared number is reported under the other name.
_PROTOCOL_NAMES = dict(
    (value, name)
    for name, value in globals().items()
    if name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23'
)
PROTOCOL_SSLv23 = PROTOCOL_TLS

# PROTOCOL_SSLv2 is only present when the _ssl build exported it; fall back
# to None so comparisons against it are always safe.
_SSLv2_IF_EXISTS = globals().get('PROTOCOL_SSLv2')
139
140from socket import socket, _fileobject, _delegate_methods, error as socket_error
141if sys.platform == "win32":
142    from _ssl import enum_certificates, enum_crls
143
144from socket import socket, AF_INET, SOCK_STREAM, create_connection
145from socket import SOL_SOCKET, SO_TYPE
146import base64        # for DER-to-PEM translation
147import errno
148import warnings
149
# Channel-binding types this build can report; 'tls-unique' needs support
# from the C module.
CHANNEL_BINDING_TYPES = ['tls-unique'] if _ssl.HAS_TLS_UNIQUE else []
154
155
# Default OpenSSL cipher string for contexts built by this module.
#
# OpenSSL's own default is 'DEFAULT:!aNULL:!eNULL'; this list is more
# selective.  It has been explicitly chosen to:
#   * list the TLS 1.3 AES-GCM and ChaCha20 cipher suites first
#   * prefer cipher suites that offer perfect forward secrecy (DHE/ECDHE)
#   * prefer ECDHE over DHE for better performance
#   * prefer AEAD over CBC for better performance and security
#   * prefer AES-GCM over ChaCha20 because most platforms have AES-NI
#     (ChaCha20 needs OpenSSL 1.1.0 or patched 1.0.2)
#   * prefer any AES-GCM and ChaCha20 over any AES-CBC for better
#     performance and security
#   * then use HIGH cipher suites as a fallback
#   * disable NULL authentication, NULL encryption, 3DES and MD5 MACs
#     for security reasons
#
# NOTE(review): the 'TLS13-*' entries use a pre-release OpenSSL 1.1.1
# naming scheme; confirm whether the linked OpenSSL honours them at all.
_DEFAULT_CIPHERS = ':'.join((
    'TLS13-AES-256-GCM-SHA384', 'TLS13-CHACHA20-POLY1305-SHA256',
    'TLS13-AES-128-GCM-SHA256',
    'ECDH+AESGCM', 'ECDH+CHACHA20', 'DH+AESGCM', 'DH+CHACHA20',
    'ECDH+AES256', 'DH+AES256', 'ECDH+AES128', 'DH+AES',
    'ECDH+HIGH', 'DH+HIGH', 'RSA+AESGCM', 'RSA+AES', 'RSA+HIGH',
    '!aNULL', '!eNULL', '!MD5', '!3DES',
))

# Stricter cipher string used for server-side (CLIENT_AUTH) contexts.
# Same preference order as above, and additionally excludes DSS and RC4,
# so that NULL authentication, NULL encryption, MD5 MACs, DSS, RC4 and
# 3DES are all disabled.
_RESTRICTED_SERVER_CIPHERS = ':'.join((
    'TLS13-AES-256-GCM-SHA384', 'TLS13-CHACHA20-POLY1305-SHA256',
    'TLS13-AES-128-GCM-SHA256',
    'ECDH+AESGCM', 'ECDH+CHACHA20', 'DH+AESGCM', 'DH+CHACHA20',
    'ECDH+AES256', 'DH+AES256', 'ECDH+AES128', 'DH+AES',
    'ECDH+HIGH', 'DH+HIGH', 'RSA+AESGCM', 'RSA+AES', 'RSA+HIGH',
    '!aNULL', '!eNULL', '!MD5', '!DSS', '!RC4', '!3DES',
))
198
199
class CertificateError(ValueError):
    """Raised when certificate name checking fails.

    Examples are a hostname that matches none of the certificate's names,
    or a malformed wildcard name.
    """
202
203
def _dnsname_match(dn, hostname, max_wildcards=1):
    """Matching according to RFC 6125, section 6.4.3

    http://tools.ietf.org/html/rfc6125#section-6.4.3

    Return True if the certificate DNS name *dn* matches *hostname*
    (case-insensitively), False otherwise.  Only the left-most label of
    *dn* may act as a wildcard; '*' in any other label is matched
    literally.

    Raises CertificateError if the left-most label contains more than
    *max_wildcards* '*' characters.
    """
    if not dn:
        return False

    pieces = dn.split('.')
    leftmost = pieces[0]
    remainder = pieces[1:]

    wildcards = leftmost.count('*')
    if wildcards > max_wildcards:
        # Issue #17980: avoid denials of service by refusing more
        # than one wildcard per fragment.  A survey of established
        # policy among SSL implementations showed it to be a
        # reasonable choice.
        raise CertificateError(
            "too many wildcards in certificate DNS name: " + repr(dn))

    # speed up common case w/o wildcards
    if not wildcards:
        return dn.lower() == hostname.lower()

    pats = []
    # RFC 6125, section 6.4.3, subitem 1.
    # The client SHOULD NOT attempt to match a presented identifier in which
    # the wildcard character comprises a label other than the left-most label.
    if leftmost == '*':
        # When '*' is a fragment by itself, it matches a non-empty dotless
        # fragment.
        pats.append('[^.]+')
    elif (leftmost.lower().startswith('xn--')
          or hostname.lower().startswith('xn--')):
        # RFC 6125, section 6.4.3, subitem 3.
        # The client SHOULD NOT attempt to match a presented identifier
        # where the wildcard character is embedded within an A-label or
        # U-label of an internationalized domain name.
        # The ACE prefix is case-insensitive, so "XN--" must be caught too;
        # otherwise an upper-case A-label would be wildcard-expanded.
        pats.append(re.escape(leftmost))
    else:
        # Otherwise, '*' matches any dotless string, e.g. www*
        pats.append(re.escape(leftmost).replace(r'\*', '[^.]*'))

    # add the remaining fragments, ignore any wildcards
    pats.extend(re.escape(frag) for frag in remainder)

    pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE)
    return pat.match(hostname) is not None
253
254
def match_hostname(cert, hostname):
    """Check that the decoded certificate *cert* is valid for *hostname*.

    *cert* is in the dictionary form returned by SSLSocket.getpeercert().
    dNSName entries of subjectAltName are tried first.  The subject's
    commonName values are consulted only when there are no dNSName entries.
    The matching follows RFC 2818 and RFC 6125; IP addresses are not
    accepted for *hostname*.

    Returns None on success.  Raises ValueError for an empty certificate and
    CertificateError when no name matches.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")

    candidates = []
    for kind, name in cert.get('subjectAltName', ()):
        if kind != 'DNS':
            continue
        if _dnsname_match(name, hostname):
            return
        candidates.append(name)

    if not candidates:
        # Fall back to the subject only when subjectAltName had no DNS names.
        for rdn in cert.get('subject', ()):
            for kind, name in rdn:
                # NOTE: RFC 2818 says the most specific Common Name should
                # be used; here every commonName is tried in turn.
                if kind != 'commonName':
                    continue
                if _dnsname_match(name, hostname):
                    return
                candidates.append(name)

    if not candidates:
        raise CertificateError("no appropriate commonName or "
            "subjectAltName fields were found")
    if len(candidates) == 1:
        raise CertificateError("hostname %r "
            "doesn't match %r"
            % (hostname, candidates[0]))
    raise CertificateError("hostname %r "
        "doesn't match either of %s"
        % (hostname, ', '.join(map(repr, candidates))))
296
297
DefaultVerifyPaths = namedtuple("DefaultVerifyPaths", [
    "cafile", "capath",
    "openssl_cafile_env", "openssl_cafile",
    "openssl_capath_env", "openssl_capath",
])

def get_default_verify_paths():
    """Return the default CA file and CA directory as DefaultVerifyPaths.

    The first two fields are the effective cafile/capath: the value of the
    OpenSSL environment variable if set, otherwise the compiled-in default.
    Each is replaced by None when it is not an existing file or directory,
    respectively.  The remaining four fields are the raw values reported
    by _ssl.get_default_verify_paths().
    """
    raw = _ssl.get_default_verify_paths()
    # raw = (cafile env var name, default cafile,
    #        capath env var name, default capath)
    cafile = os.environ.get(raw[0], raw[1])
    capath = os.environ.get(raw[2], raw[3])
    usable_cafile = cafile if os.path.isfile(cafile) else None
    usable_capath = capath if os.path.isdir(capath) else None
    return DefaultVerifyPaths(usable_cafile, usable_capath, *raw)
314
315
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """Immutable (nid, shortname, longname, oid) record for an ASN.1 object
    identifier, resolved through OpenSSL's object table.
    """
    __slots__ = ()

    @classmethod
    def _build(cls, fields):
        # Bypass our own __new__ (which expects a dotted OID) and populate the
        # tuple directly from the four values OpenSSL reported.
        return super(_ASN1Object, cls).__new__(cls, *fields)

    def __new__(cls, oid):
        return cls._build(_txt2obj(oid, name=False))

    @classmethod
    def fromnid(cls, nid):
        """Look up an object by its OpenSSL numeric identifier."""
        return cls._build(_nid2obj(nid))

    @classmethod
    def fromname(cls, name):
        """Look up an object by short name, long name or dotted OID."""
        return cls._build(_txt2obj(name, name=True))


class Purpose(_ASN1Object):
    """SSLContext purpose flags with X509v3 Extended Key Usage objects
    """

# Extended Key Usage OIDs used to select which certificates to trust.
Purpose.SERVER_AUTH = Purpose('1.3.6.1.5.5.7.3.1')
Purpose.CLIENT_AUTH = Purpose('1.3.6.1.5.5.7.3.2')
343
344
class SSLContext(_SSLContext):
    """Container for SSL configuration shared by sockets: protocol,
    certificates and private keys, cipher selection and related options."""

    __slots__ = ('protocol', '__weakref__')
    _windows_cert_stores = ("CA", "ROOT")

    def __new__(cls, protocol, *args, **kwargs):
        context = _SSLContext.__new__(cls, protocol)
        # Every protocol other than SSLv2 (when that constant exists) starts
        # from this module's curated cipher list.
        if protocol != _SSLv2_IF_EXISTS:
            context.set_ciphers(_DEFAULT_CIPHERS)
        return context

    def __init__(self, protocol):
        # The C base class already received *protocol* in __new__; keep a
        # copy as a plain attribute so it can be read back.
        self.protocol = protocol

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None):
        """Return an SSLSocket that wraps *sock* using this context."""
        options = dict(server_side=server_side,
                       do_handshake_on_connect=do_handshake_on_connect,
                       suppress_ragged_eofs=suppress_ragged_eofs,
                       server_hostname=server_hostname)
        return SSLSocket(sock=sock, _context=self, **options)

    @staticmethod
    def _encode_protocol_list(protocols, error_message):
        # Serialize names into the length-prefixed wire format shared by
        # NPN and ALPN: one length byte followed by the ASCII name.
        wire = bytearray()
        for name in protocols:
            encoded = name.encode('ascii')
            if not 0 < len(encoded) <= 255:
                raise SSLError(error_message)
            wire.append(len(encoded))
            wire.extend(encoded)
        return wire

    def set_npn_protocols(self, npn_protocols):
        """Advertise *npn_protocols* (an iterable of names) via NPN."""
        self._set_npn_protocols(self._encode_protocol_list(
            npn_protocols, 'NPN protocols must be 1 to 255 in length'))

    def set_alpn_protocols(self, alpn_protocols):
        """Advertise *alpn_protocols* (an iterable of names) via ALPN."""
        self._set_alpn_protocols(self._encode_protocol_list(
            alpn_protocols, 'ALPN protocols must be 1 to 255 in length'))

    def _load_windows_store_certs(self, storename, purpose):
        # Collect DER certificates from a Windows system store that are
        # trusted for *purpose*, load them, and return the raw bytes.
        collected = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding != "x509_asn":
                    continue
                if trust is True or purpose.oid in trust:
                    collected.extend(cert)
        except OSError:
            warnings.warn("unable to enumerate Windows certificate store")
        if collected:
            self.load_verify_locations(cadata=collected)
        return collected

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load the platform's default CA certificates for *purpose*.

        On Windows the system certificate stores are read first; the
        OpenSSL default verify paths are configured on every platform.
        Raises TypeError if *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for store in self._windows_cert_stores:
                self._load_windows_store_certs(store, purpose)
        self.set_default_verify_paths()
414
415
def create_default_context(purpose=Purpose.SERVER_AUTH, cafile=None,
                           capath=None, cadata=None):
    """Create a SSLContext object with default settings.

    For Purpose.SERVER_AUTH (client side) certificates and the hostname are
    verified.  For Purpose.CLIENT_AUTH (server side) the restricted server
    cipher list is applied.  CA material comes from *cafile*, *capath* or
    *cadata* when any is given; otherwise the system defaults are loaded if
    verification is enabled.

    NOTE: The protocol and settings may change anytime without prior
          deprecation. The values represent a fair balance between maximum
          compatibility and security.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    context = SSLContext(PROTOCOL_TLS)

    if purpose == Purpose.SERVER_AUTH:
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True
    elif purpose == Purpose.CLIENT_AUTH:
        context.set_ciphers(_RESTRICTED_SERVER_CIPHERS)

    has_explicit_ca = cafile or capath or cadata
    if has_explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # Verification is on but no CA source was supplied: try the system
        # root certificates for this purpose.  This may fail silently.
        context.load_default_certs(purpose)
    return context
447
def _create_unverified_context(protocol=PROTOCOL_TLS, cert_reqs=None,
                               check_hostname=False,
                               purpose=Purpose.SERVER_AUTH,
                               certfile=None, keyfile=None,
                               cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place. The configuration
    is less restrictive than create_default_context()'s to increase backward
    compatibility: certificate verification and hostname checking are only
    enabled when requested through *cert_reqs* and *check_hostname*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    context = SSLContext(protocol)

    # Verification settings: only override verify_mode when asked to.
    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    context.check_hostname = check_hostname

    # Our own certificate chain; a key file without a cert file is an error.
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    wants_chain = certfile or keyfile
    if wants_chain:
        context.load_cert_chain(certfile, keyfile)

    # CA root certificates: explicit sources win; otherwise fall back to the
    # system defaults whenever peer verification is enabled (this may fail
    # silently).
    has_explicit_ca = cafile or capath or cadata
    if has_explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        context.load_default_certs(purpose)

    return context
486
# Backwards compatibility alias, even though it's not a public name.
_create_stdlib_context = _create_unverified_context

# PEP 493: Verify HTTPS by default, but allow envvar to override that
_https_verify_envvar = 'PYTHONHTTPSVERIFY'

def _get_https_context_factory():
    """Choose the default HTTPS context factory (PEP 493).

    Returns _create_unverified_context only when environment variables are
    honoured (no -E flag) and PYTHONHTTPSVERIFY is exactly '0'; otherwise
    returns create_default_context.
    """
    env_allowed = not sys.flags.ignore_environment
    if env_allowed and os.environ.get(_https_verify_envvar) == '0':
        return _create_unverified_context
    return create_default_context

_create_default_https_context = _get_https_context_factory()

# PEP 493: "private" API to configure HTTPS defaults without monkeypatching
def _https_verify_certificates(enable=True):
    """Verify server HTTPS certificates by default?

    Rebinds the module-level _create_default_https_context to the verifying
    factory when *enable* is true, or to the unverified one otherwise.
    """
    global _create_default_https_context
    _create_default_https_context = (
        create_default_context if enable else _create_unverified_context)
510
511
512class SSLSocket(socket):
513    """This class implements a subtype of socket.socket that wraps
514    the underlying OS socket in an SSL context when necessary, and
515    provides read and write methods over that channel."""
516
517    def __init__(self, sock=None, keyfile=None, certfile=None,
518                 server_side=False, cert_reqs=CERT_NONE,
519                 ssl_version=PROTOCOL_TLS, ca_certs=None,
520                 do_handshake_on_connect=True,
521                 family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
522                 suppress_ragged_eofs=True, npn_protocols=None, ciphers=None,
523                 server_hostname=None,
524                 _context=None):
525
526        self._makefile_refs = 0
527        if _context:
528            self._context = _context
529        else:
530            if server_side and not certfile:
531                raise ValueError("certfile must be specified for server-side "
532                                 "operations")
533            if keyfile and not certfile:
534                raise ValueError("certfile must be specified")
535            if certfile and not keyfile:
536                keyfile = certfile
537            self._context = SSLContext(ssl_version)
538            self._context.verify_mode = cert_reqs
539            if ca_certs:
540                self._context.load_verify_locations(ca_certs)
541            if certfile:
542                self._context.load_cert_chain(certfile, keyfile)
543            if npn_protocols:
544                self._context.set_npn_protocols(npn_protocols)
545            if ciphers:
546                self._context.set_ciphers(ciphers)
547            self.keyfile = keyfile
548            self.certfile = certfile
549            self.cert_reqs = cert_reqs
550            self.ssl_version = ssl_version
551            self.ca_certs = ca_certs
552            self.ciphers = ciphers
553        # Can't use sock.type as other flags (such as SOCK_NONBLOCK) get
554        # mixed in.
555        if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
556            raise NotImplementedError("only stream sockets are supported")
557        socket.__init__(self, _sock=sock._sock)
558        # The initializer for socket overrides the methods send(), recv(), etc.
559        # in the instancce, which we don't need -- but we want to provide the
560        # methods defined in SSLSocket.
561        for attr in _delegate_methods:
562            try:
563                delattr(self, attr)
564            except AttributeError:
565                pass
566        if server_side and server_hostname:
567            raise ValueError("server_hostname can only be specified "
568                             "in client mode")
569        if self._context.check_hostname and not server_hostname:
570            raise ValueError("check_hostname requires server_hostname")
571        self.server_side = server_side
572        self.server_hostname = server_hostname
573        self.do_handshake_on_connect = do_handshake_on_connect
574        self.suppress_ragged_eofs = suppress_ragged_eofs
575
576        # See if we are connected
577        try:
578            self.getpeername()
579        except socket_error as e:
580            if e.errno != errno.ENOTCONN:
581                raise
582            connected = False
583        else:
584            connected = True
585
586        self._closed = False
587        self._sslobj = None
588        self._connected = connected
589        if connected:
590            # create the SSL object
591            try:
592                self._sslobj = self._context._wrap_socket(self._sock, server_side,
593                                                          server_hostname, ssl_sock=self)
594                if do_handshake_on_connect:
595                    timeout = self.gettimeout()
596                    if timeout == 0.0:
597                        # non-blocking
598                        raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
599                    self.do_handshake()
600
601            except (OSError, ValueError):
602                self.close()
603                raise
604
605    @property
606    def context(self):
607        return self._context
608
609    @context.setter
610    def context(self, ctx):
611        self._context = ctx
612        self._sslobj.context = ctx
613
614    def dup(self):
615        raise NotImplementedError("Can't dup() %s instances" %
616                                  self.__class__.__name__)
617
618    def _checkClosed(self, msg=None):
619        # raise an exception here if you wish to check for spurious closes
620        pass
621
622    def _check_connected(self):
623        if not self._connected:
624            # getpeername() will raise ENOTCONN if the socket is really
625            # not connected; note that we can be connected even without
626            # _connected being set, e.g. if connect() first returned
627            # EAGAIN.
628            self.getpeername()
629
630    def read(self, len=1024, buffer=None):
631        """Read up to LEN bytes and return them.
632        Return zero-length string on EOF."""
633
634        self._checkClosed()
635        if not self._sslobj:
636            raise ValueError("Read on closed or unwrapped SSL socket.")
637        try:
638            if buffer is not None:
639                v = self._sslobj.read(len, buffer)
640            else:
641                v = self._sslobj.read(len)
642            return v
643        except SSLError as x:
644            if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
645                if buffer is not None:
646                    return 0
647                else:
648                    return b''
649            else:
650                raise
651
652    def write(self, data):
653        """Write DATA to the underlying SSL channel.  Returns
654        number of bytes of DATA actually transmitted."""
655
656        self._checkClosed()
657        if not self._sslobj:
658            raise ValueError("Write on closed or unwrapped SSL socket.")
659        return self._sslobj.write(data)
660
661    def getpeercert(self, binary_form=False):
662        """Returns a formatted version of the data in the
663        certificate provided by the other end of the SSL channel.
664        Return None if no certificate was provided, {} if a
665        certificate was provided, but not validated."""
666
667        self._checkClosed()
668        self._check_connected()
669        return self._sslobj.peer_certificate(binary_form)
670
671    def selected_npn_protocol(self):
672        self._checkClosed()
673        if not self._sslobj or not _ssl.HAS_NPN:
674            return None
675        else:
676            return self._sslobj.selected_npn_protocol()
677
678    def selected_alpn_protocol(self):
679        self._checkClosed()
680        if not self._sslobj or not _ssl.HAS_ALPN:
681            return None
682        else:
683            return self._sslobj.selected_alpn_protocol()
684
685    def cipher(self):
686        self._checkClosed()
687        if not self._sslobj:
688            return None
689        else:
690            return self._sslobj.cipher()
691
692    def compression(self):
693        self._checkClosed()
694        if not self._sslobj:
695            return None
696        else:
697            return self._sslobj.compression()
698
699    def send(self, data, flags=0):
700        self._checkClosed()
701        if self._sslobj:
702            if flags != 0:
703                raise ValueError(
704                    "non-zero flags not allowed in calls to send() on %s" %
705                    self.__class__)
706            try:
707                v = self._sslobj.write(data)
708            except SSLError as x:
709                if x.args[0] == SSL_ERROR_WANT_READ:
710                    return 0
711                elif x.args[0] == SSL_ERROR_WANT_WRITE:
712                    return 0
713                else:
714                    raise
715            else:
716                return v
717        else:
718            return self._sock.send(data, flags)
719
720    def sendto(self, data, flags_or_addr, addr=None):
721        self._checkClosed()
722        if self._sslobj:
723            raise ValueError("sendto not allowed on instances of %s" %
724                             self.__class__)
725        elif addr is None:
726            return self._sock.sendto(data, flags_or_addr)
727        else:
728            return self._sock.sendto(data, flags_or_addr, addr)
729
730
731    def sendall(self, data, flags=0):
732        self._checkClosed()
733        if self._sslobj:
734            if flags != 0:
735                raise ValueError(
736                    "non-zero flags not allowed in calls to sendall() on %s" %
737                    self.__class__)
738            amount = len(data)
739            count = 0
740            while (count < amount):
741                v = self.send(data[count:])
742                count += v
743            return amount
744        else:
745            return socket.sendall(self, data, flags)
746
747    def recv(self, buflen=1024, flags=0):
748        self._checkClosed()
749        if self._sslobj:
750            if flags != 0:
751                raise ValueError(
752                    "non-zero flags not allowed in calls to recv() on %s" %
753                    self.__class__)
754            return self.read(buflen)
755        else:
756            return self._sock.recv(buflen, flags)
757
758    def recv_into(self, buffer, nbytes=None, flags=0):
759        self._checkClosed()
760        if buffer and (nbytes is None):
761            nbytes = len(buffer)
762        elif nbytes is None:
763            nbytes = 1024
764        if self._sslobj:
765            if flags != 0:
766                raise ValueError(
767                  "non-zero flags not allowed in calls to recv_into() on %s" %
768                  self.__class__)
769            return self.read(nbytes, buffer)
770        else:
771            return self._sock.recv_into(buffer, nbytes, flags)
772
773    def recvfrom(self, buflen=1024, flags=0):
774        self._checkClosed()
775        if self._sslobj:
776            raise ValueError("recvfrom not allowed on instances of %s" %
777                             self.__class__)
778        else:
779            return self._sock.recvfrom(buflen, flags)
780
781    def recvfrom_into(self, buffer, nbytes=None, flags=0):
782        self._checkClosed()
783        if self._sslobj:
784            raise ValueError("recvfrom_into not allowed on instances of %s" %
785                             self.__class__)
786        else:
787            return self._sock.recvfrom_into(buffer, nbytes, flags)
788
789
790    def pending(self):
791        self._checkClosed()
792        if self._sslobj:
793            return self._sslobj.pending()
794        else:
795            return 0
796
797    def shutdown(self, how):
798        self._checkClosed()
799        self._sslobj = None
800        socket.shutdown(self, how)
801
802    def close(self):
803        if self._makefile_refs < 1:
804            self._sslobj = None
805            socket.close(self)
806        else:
807            self._makefile_refs -= 1
808
809    def unwrap(self):
810        if self._sslobj:
811            s = self._sslobj.shutdown()
812            self._sslobj = None
813            return s
814        else:
815            raise ValueError("No SSL wrapper around " + str(self))
816
817    def _real_close(self):
818        self._sslobj = None
819        socket._real_close(self)
820
821    def do_handshake(self, block=False):
822        """Perform a TLS/SSL handshake."""
823        self._check_connected()
824        timeout = self.gettimeout()
825        try:
826            if timeout == 0.0 and block:
827                self.settimeout(None)
828            self._sslobj.do_handshake()
829        finally:
830            self.settimeout(timeout)
831
832        if self.context.check_hostname:
833            if not self.server_hostname:
834                raise ValueError("check_hostname needs server_hostname "
835                                 "argument")
836            match_hostname(self.getpeercert(), self.server_hostname)
837
838    def _real_connect(self, addr, connect_ex):
839        if self.server_side:
840            raise ValueError("can't connect in server-side mode")
841        # Here we assume that the socket is client-side, and not
842        # connected at the time of the call.  We connect it, then wrap it.
843        if self._connected:
844            raise ValueError("attempt to connect already-connected SSLSocket!")
845        self._sslobj = self.context._wrap_socket(self._sock, False, self.server_hostname, ssl_sock=self)
846        try:
847            if connect_ex:
848                rc = socket.connect_ex(self, addr)
849            else:
850                rc = None
851                socket.connect(self, addr)
852            if not rc:
853                self._connected = True
854                if self.do_handshake_on_connect:
855                    self.do_handshake()
856            return rc
857        except (OSError, ValueError):
858            self._sslobj = None
859            raise
860
861    def connect(self, addr):
862        """Connects to remote ADDR, and then wraps the connection in
863        an SSL channel."""
864        self._real_connect(addr, False)
865
866    def connect_ex(self, addr):
867        """Connects to remote ADDR, and then wraps the connection in
868        an SSL channel."""
869        return self._real_connect(addr, True)
870
871    def accept(self):
872        """Accepts a new connection from a remote client, and returns
873        a tuple containing that new connection wrapped with a server-side
874        SSL channel, and the address of the remote client."""
875
876        newsock, addr = socket.accept(self)
877        newsock = self.context.wrap_socket(newsock,
878                    do_handshake_on_connect=self.do_handshake_on_connect,
879                    suppress_ragged_eofs=self.suppress_ragged_eofs,
880                    server_side=True)
881        return newsock, addr
882
883    def makefile(self, mode='r', bufsize=-1):
884
885        """Make and return a file-like object that
886        works with the SSL connection.  Just use the code
887        from the socket module."""
888
889        self._makefile_refs += 1
890        # close=True so as to decrement the reference count when done with
891        # the file-like object.
892        return _fileobject(self, mode, bufsize, close=True)
893
894    def get_channel_binding(self, cb_type="tls-unique"):
895        """Get channel binding data for current connection.  Raise ValueError
896        if the requested `cb_type` is not supported.  Return bytes of the data
897        or None if the data is not available (e.g. before the handshake).
898        """
899        if cb_type not in CHANNEL_BINDING_TYPES:
900            raise ValueError("Unsupported channel binding type")
901        if cb_type != "tls-unique":
902            raise NotImplementedError(
903                            "{0} channel binding type not implemented"
904                            .format(cb_type))
905        if self._sslobj is None:
906            return None
907        return self._sslobj.tls_unique_cb()
908
909    def version(self):
910        """
911        Return a string identifying the protocol version used by the
912        current SSL channel, or None if there is no established channel.
913        """
914        if self._sslobj is None:
915            return None
916        return self._sslobj.version()
917
918
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap *sock* in a new SSLSocket and return it.

    Module-level convenience; every argument is passed through unchanged
    to the SSLSocket constructor.
    """
    return SSLSocket(sock=sock,
                     keyfile=keyfile,
                     certfile=certfile,
                     server_side=server_side,
                     cert_reqs=cert_reqs,
                     ssl_version=ssl_version,
                     ca_certs=ca_certs,
                     do_handshake_on_connect=do_handshake_on_connect,
                     suppress_ragged_eofs=suppress_ragged_eofs,
                     ciphers=ciphers)
932
933# some utility functions
934
def cert_time_to_seconds(cert_time):
    """Return the integer number of seconds since the Epoch for a
    certificate "notBefore"/"notAfter" timestring.

    The expected layout is ``"%b %d %H:%M:%S %Y %Z"`` (C locale), with the
    zone written as GMT (see ASN1_TIME_print()); RFC 5280 requires these
    dates to be UTC.  The month must be one of
    Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec (case-insensitive).

    Raises ValueError if the month abbreviation or the remainder of the
    string does not match.
    """
    from time import strptime
    from calendar import timegm

    month_numbers = dict(
        (abbr, number) for number, abbr in enumerate(
            ("Jan", "Feb", "Mar", "Apr", "May", "Jun",
             "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"), 1))
    # The month is matched against fixed English abbreviations instead of
    # strptime's locale-dependent %b; the rest uses a fixed-GMT format.
    rest_format = ' %d %H:%M:%S %Y GMT'
    month = month_numbers.get(cert_time[:3].title())
    if month is None:
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, rest_format))
    parsed = strptime(cert_time[3:], rest_format)
    # timegm interprets the tuple as UTC; it returns an int, since the
    # fractional seconds are always zero here.
    return timegm((parsed[0], month) + tuple(parsed[2:6]))
964
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"

def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The DER bytes are base64-encoded and wrapped at 64 characters per line
    between PEM_HEADER and PEM_FOOTER; the result ends with a newline.
    """

    f = base64.standard_b64encode(der_cert_bytes).decode('ascii')
    return (PEM_HEADER + '\n' +
            textwrap.fill(f, 64) + '\n' +
            PEM_FOOTER + '\n')

def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    The string must start with PEM_HEADER (no leading whitespace) and,
    after stripping surrounding whitespace, end with PEM_FOOTER.
    ValueError is raised otherwise.
    """
    import binascii

    if not pem_cert_string.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem_cert_string.strip().endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
    # base64.decodestring() is a deprecated alias for this call (removed
    # in Python 3.9).  Calling binascii directly decodes identically,
    # discarding the line breaks of the PEM body, on every version.
    return binascii.a2b_base64(d.encode('ASCII', 'strict'))
989
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Retrieve the certificate from the server at the specified address,
    and return it as a PEM-encoded string.

    *addr* is a ``(host, port)`` pair.
    If 'ca_certs' is specified, validate the server cert against it.
    If 'ssl_version' is specified, use it in the connection attempt.
    When SNI is available (HAS_SNI), *host* is sent as the server name, so
    that servers hosting several names present the matching certificate.
    """

    host, port = addr
    if ca_certs is not None:
        cert_reqs = CERT_REQUIRED
    else:
        cert_reqs = CERT_NONE
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    # Fall back to no server name when the linked OpenSSL lacks SNI;
    # passing one there would make wrap_socket fail.
    server_hostname = host if HAS_SNI else None
    with closing(create_connection(addr)) as sock:
        with closing(context.wrap_socket(
                sock, server_hostname=server_hostname)) as sslsock:
            dercert = sslsock.getpeercert(True)
    return DER_cert_to_PEM_cert(dercert)
1008
def get_protocol_name(protocol_code):
    """Return the name registered for *protocol_code* in _PROTOCOL_NAMES,
    or '<unknown>' if there is none."""
    try:
        return _PROTOCOL_NAMES[protocol_code]
    except KeyError:
        return '<unknown>'
1011
1012
1013# a replacement for the old socket.ssl function
1014
def sslwrap_simple(sock, keyfile=None, certfile=None):
    """A replacement for the old socket.ssl function.  Designed
    for compatibility with Python 2.5 and earlier.  Will disappear in
    Python 3.0.

    If *sock* has a ``_sock`` attribute, that inner socket is the one
    wrapped.  A PROTOCOL_SSLv23 context is created, loading *certfile* and
    *keyfile* when either is given.  The handshake is performed immediately
    only if the socket already has a peer (getpeername() succeeds).
    """
    raw_sock = sock._sock if hasattr(sock, "_sock") else sock

    ctx = SSLContext(PROTOCOL_SSLv23)
    if keyfile or certfile:
        ctx.load_cert_chain(certfile, keyfile)
    ssl_sock = ctx._wrap_socket(raw_sock, server_side=False)

    try:
        raw_sock.getpeername()
    except socket_error:
        # Not connected yet; the handshake has to wait.
        connected = False
    else:
        connected = True
    if connected:
        ssl_sock.do_handshake()
    return ssl_sock
1036