1"""
2"""
3
4# Created on 2013.08.05
5#
6# Author: Giovanni Cannata
7#
8# Copyright 2013 - 2020 Giovanni Cannata
9#
10# This file is part of ldap3.
11#
12# ldap3 is free software: you can redistribute it and/or modify
13# it under the terms of the GNU Lesser General Public License as published
14# by the Free Software Foundation, either version 3 of the License, or
15# (at your option) any later version.
16#
17# ldap3 is distributed in the hope that it will be useful,
18# but WITHOUT ANY WARRANTY; without even the implied warranty of
19# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
20# GNU Lesser General Public License for more details.
21#
22# You should have received a copy of the GNU Lesser General Public License
23# along with ldap3 in the COPYING and COPYING.LESSER files.
24# If not, see <http://www.gnu.org/licenses/>.
25
26from .exceptions import LDAPSSLNotSupportedError, LDAPSSLConfigurationError, LDAPCertificateError, start_tls_exception_factory, LDAPStartTLSError
27from .. import SEQUENCE_TYPES
28from ..utils.log import log, log_enabled, ERROR, BASIC, NETWORK
29
30try:
31    # noinspection PyUnresolvedReferences
32    import ssl
33except ImportError:
34    if log_enabled(ERROR):
35        log(ERROR, 'SSL not supported in this Python interpreter')
36    raise LDAPSSLNotSupportedError('SSL not supported in this Python interpreter')
37
38try:
39    from ssl import match_hostname, CertificateError  # backport for python2 missing ssl functionalities
40except ImportError:
41    from ..utils.tls_backport import CertificateError
42    from ..utils.tls_backport import match_hostname
43    if log_enabled(BASIC):
44        log(BASIC, 'using tls_backport')
45
46try:  # try to use SSLContext
47    # noinspection PyUnresolvedReferences
48    from ssl import create_default_context, Purpose  # defined in Python 3.4 and Python 2.7.9
49    use_ssl_context = True
50except ImportError:
51    use_ssl_context = False
52    if log_enabled(BASIC):
53        log(BASIC, 'SSLContext unavailable')
54
55from os import path
56
57
58# noinspection PyProtectedMember
class Tls(object):
    """
    TLS/SSL configuration for a Server object.

    Starting from Python 2.7.9 and Python 3.4 an SSLContext object is used;
    when no explicit protocol ``version`` is given the context is built with
    ``ssl.create_default_context``, which tries to read the CAs defined at system level.

    :param local_private_key_file: client private key file (used together with local_certificate_file)
    :param local_certificate_file: client certificate file presented to the server
    :param validate: ssl.CERT_NONE, ssl.CERT_OPTIONAL or ssl.CERT_REQUIRED; any other
        falsy value (e.g. None) is treated as ssl.CERT_NONE
    :param version: ssl protocol constant; None selects the default context
    :param ssl_options: values OR-ed into the SSLContext options (valid only when using SSLContext)
    :param ca_certs_file: file containing CA certificates; it must exist
    :param valid_names: additional host name, or sequence of host names, accepted when
        checking the server certificate; '*' accepts any certificate
    :param ca_certs_path: directory of CA certificates; it must exist (valid only when using SSLContext)
    :param ca_certs_data: CA certificates given as data (valid only when using SSLContext)
    :param local_private_key_password: password of the private key (valid only when using SSLContext)
    :param ciphers: cipher list passed to the ssl layer
    :param sni: server name for Server Name Indication (used only with SSLContext)
    :raises LDAPSSLConfigurationError: invalid validate value or missing CA file/path
    :raises LDAPSSLNotSupportedError: an SSLContext-only option is used while SSLContext is unavailable
    """

    def __init__(self,
                 local_private_key_file=None,
                 local_certificate_file=None,
                 validate=ssl.CERT_NONE,
                 version=None,
                 ssl_options=None,
                 ca_certs_file=None,
                 valid_names=None,
                 ca_certs_path=None,
                 ca_certs_data=None,
                 local_private_key_password=None,
                 ciphers=None,
                 sni=None):
        self.ssl_options = [] if ssl_options is None else ssl_options

        if validate in [ssl.CERT_NONE, ssl.CERT_OPTIONAL, ssl.CERT_REQUIRED]:
            self.validate = validate
        elif validate:
            if log_enabled(ERROR):
                log(ERROR, 'invalid validate parameter <%s>', validate)
            raise LDAPSSLConfigurationError('invalid validate parameter')
        else:
            # falsy values such as None mean "no validation"; the attribute must always
            # be set because __str__, __repr__ and wrap_socket read it
            self.validate = ssl.CERT_NONE

        if ca_certs_file and path.exists(ca_certs_file):
            self.ca_certs_file = ca_certs_file
        elif ca_certs_file:
            if log_enabled(ERROR):
                log(ERROR, 'invalid CA public key file <%s>', ca_certs_file)
            raise LDAPSSLConfigurationError('invalid CA public key file')
        else:
            self.ca_certs_file = None

        if ca_certs_path and use_ssl_context and path.exists(ca_certs_path):
            self.ca_certs_path = ca_certs_path
        elif ca_certs_path and not use_ssl_context:
            if log_enabled(ERROR):
                log(ERROR, 'cannot use CA public keys path, SSLContext not available')
            raise LDAPSSLNotSupportedError('cannot use CA public keys path, SSLContext not available')
        elif ca_certs_path:
            if log_enabled(ERROR):
                log(ERROR, 'invalid CA public keys path <%s>', ca_certs_path)
            raise LDAPSSLConfigurationError('invalid CA public keys path')
        else:
            self.ca_certs_path = None

        if ca_certs_data and use_ssl_context:
            self.ca_certs_data = ca_certs_data
        elif ca_certs_data:
            if log_enabled(ERROR):
                log(ERROR, 'cannot use CA data, SSLContext not available')
            raise LDAPSSLNotSupportedError('cannot use CA data, SSLContext not available')
        else:
            self.ca_certs_data = None

        if local_private_key_password and use_ssl_context:
            self.private_key_password = local_private_key_password
        elif local_private_key_password:
            if log_enabled(ERROR):
                log(ERROR, 'cannot use local private key password, SSLContext not available')
            raise LDAPSSLNotSupportedError('cannot use local private key password, SSLContext is not available')
        else:
            self.private_key_password = None

        self.version = version
        self.private_key_file = local_private_key_file
        self.certificate_file = local_certificate_file
        self.valid_names = valid_names
        self.ciphers = ciphers
        self.sni = sni

        if log_enabled(BASIC):
            # lazy argument: the repr is not re-interpreted as a format string by log()
            log(BASIC, 'instantiated Tls: <%r>', self)

    def __str__(self):
        """Human readable summary; secrets and file names are reported only as present/not present."""
        s = [
            'protocol: ' + str(self.version),
            'client private key: ' + ('present ' if self.private_key_file else 'not present'),
            'client certificate: ' + ('present ' if self.certificate_file else 'not present'),
            'private key password: ' + ('present ' if self.private_key_password else 'not present'),
            'CA certificates file: ' + ('present ' if self.ca_certs_file else 'not present'),
            'CA certificates path: ' + ('present ' if self.ca_certs_path else 'not present'),
            'CA certificates data: ' + ('present ' if self.ca_certs_data else 'not present'),
            'verify mode: ' + str(self.validate),
            'valid names: ' + str(self.valid_names),
            'ciphers: ' + str(self.ciphers),
            'sni: ' + str(self.sni)
        ]
        return ' - '.join(s)

    def __repr__(self):
        """Constructor-like representation listing only the non-None settings (the key password is not included)."""
        r = '' if self.private_key_file is None else ', local_private_key_file={0.private_key_file!r}'.format(self)
        r += '' if self.certificate_file is None else ', local_certificate_file={0.certificate_file!r}'.format(self)
        r += '' if self.validate is None else ', validate={0.validate!r}'.format(self)
        r += '' if self.version is None else ', version={0.version!r}'.format(self)
        r += '' if self.ca_certs_file is None else ', ca_certs_file={0.ca_certs_file!r}'.format(self)
        r += '' if self.ca_certs_path is None else ', ca_certs_path={0.ca_certs_path!r}'.format(self)
        r += '' if self.ca_certs_data is None else ', ca_certs_data={0.ca_certs_data!r}'.format(self)
        r += '' if self.ciphers is None else ', ciphers={0.ciphers!r}'.format(self)
        r += '' if self.sni is None else ', sni={0.sni!r}'.format(self)
        r = 'Tls(' + r[2:] + ')'  # drop the leading ', '
        return r

    def _create_ssl_context(self):
        """Build an SSLContext configured from this Tls object (SSLContext interpreters only)."""
        if self.version is None:  # uses the default ssl context for reasonable security
            ssl_context = create_default_context(purpose=Purpose.SERVER_AUTH,
                                                 cafile=self.ca_certs_file,
                                                 capath=self.ca_certs_path,
                                                 cadata=self.ca_certs_data)
        else:  # code from create_default_context in the Python standard library 3.5.1, with the specified protocol version
            ssl_context = ssl.SSLContext(self.version)
            if self.ca_certs_file or self.ca_certs_path or self.ca_certs_data:
                ssl_context.load_verify_locations(self.ca_certs_file, self.ca_certs_path, self.ca_certs_data)
            elif self.validate != ssl.CERT_NONE:
                ssl_context.load_default_certs(Purpose.SERVER_AUTH)

        if self.certificate_file:
            ssl_context.load_cert_chain(self.certificate_file, keyfile=self.private_key_file, password=self.private_key_password)
        # check_hostname must be disabled before verify_mode may be set to CERT_NONE;
        # host names are checked by check_hostname() in wrap_socket when the handshake is done there
        ssl_context.check_hostname = False
        ssl_context.verify_mode = self.validate
        for option in self.ssl_options:
            ssl_context.options |= option

        if self.ciphers:
            try:
                ssl_context.set_ciphers(self.ciphers)
            except ssl.SSLError as e:
                # best effort: keep the default ciphers, but do not hide the misconfiguration
                if log_enabled(ERROR):
                    log(ERROR, 'unable to set ciphers <%s>: %s', self.ciphers, e)
        return ssl_context

    def _wrap_socket_legacy(self, sock, do_handshake):
        """Wrap sock with the module level ssl.wrap_socket (interpreters without SSLContext)."""
        if self.version is None and hasattr(ssl, 'PROTOCOL_SSLv23'):
            self.version = ssl.PROTOCOL_SSLv23
        wrap_arguments = dict(keyfile=self.private_key_file,
                              certfile=self.certificate_file,
                              server_side=False,
                              cert_reqs=self.validate,
                              ssl_version=self.version,
                              ca_certs=self.ca_certs_file,
                              do_handshake_on_connect=do_handshake)
        if self.ciphers:
            try:
                return ssl.wrap_socket(sock, ciphers=self.ciphers, **wrap_arguments)
            except TypeError:  # in python2.6 no ciphers argument is present, failback to self.ciphers=None
                self.ciphers = None
        return ssl.wrap_socket(sock, **wrap_arguments)

    def wrap_socket(self, connection, do_handshake=False):
        """
        Adds TLS to the connection socket.

        connection.socket is replaced by the wrapped socket. When do_handshake is True
        and validate is CERT_REQUIRED or CERT_OPTIONAL, the server certificate is also
        checked against connection.server.host and valid_names.

        :param connection: connection whose socket is wrapped
        :param do_handshake: perform the TLS handshake while wrapping
        :raises LDAPCertificateError: (from check_hostname) when no accepted name matches
        """
        if use_ssl_context:
            ssl_context = self._create_ssl_context()
            if self.sni:
                wrapped_socket = ssl_context.wrap_socket(connection.socket, server_side=False, do_handshake_on_connect=do_handshake, server_hostname=self.sni)
            else:
                wrapped_socket = ssl_context.wrap_socket(connection.socket, server_side=False, do_handshake_on_connect=do_handshake)
            if log_enabled(NETWORK):
                log(NETWORK, 'socket wrapped with SSL using SSLContext for <%s>', connection)
        else:
            wrapped_socket = self._wrap_socket_legacy(connection.socket, do_handshake)
            if log_enabled(NETWORK):
                log(NETWORK, 'socket wrapped with SSL for <%s>', connection)

        if do_handshake and self.validate in (ssl.CERT_REQUIRED, ssl.CERT_OPTIONAL):
            check_hostname(wrapped_socket, connection.server.host, self.valid_names)

        connection.socket = wrapped_socket

    def start_tls(self, connection):
        """
        Send the StartTLS extended request (1.3.6.1.4.1.1466.20037) on connection.

        :return: False if the server already uses SSL or TLS cannot be started now
            (TLS already started, outstanding operations or SASL in progress).
            With an asynchronous strategy, True when a response is received (the
            socket is wrapped later by the strategy). With a synchronous strategy,
            the result of _start_tls.
        :raises LDAPStartTLSError: synchronous strategy and the server did not answer 'success'
        """
        if connection.server.ssl:  # ssl already established at server level
            return False

        if (connection.tls_started and not connection._executing_deferred) or connection.strategy._outstanding or connection.sasl_in_progress:
            # Per RFC 4513 (3.1.1)
            if log_enabled(ERROR):
                log(ERROR, "can't start tls because operations are in progress for <%s>", connection)
            return False
        connection.starting_tls = True
        if log_enabled(BASIC):
            log(BASIC, 'starting tls for <%s>', connection)
        if not connection.strategy.sync:
            connection._awaiting_for_async_start_tls = True  # some flaky servers (OpenLDAP) doesn't return the extended response name in response
        result = connection.extended('1.3.6.1.4.1.1466.20037')
        if not connection.strategy.sync:
            # asynchronous - _start_tls must be executed by the strategy
            response = connection.get_response(result)
            if response != (None, None):
                if log_enabled(BASIC):
                    log(BASIC, 'tls started for <%s>', connection)
                return True
            if log_enabled(BASIC):
                log(BASIC, 'tls not started for <%s>', connection)
            return False

        if connection.result['description'] != 'success':
            # startTLS failed
            connection.last_error = 'startTLS failed - ' + str(connection.result['description'])
            if log_enabled(ERROR):
                log(ERROR, '%s for <%s>', connection.last_error, connection)
            raise LDAPStartTLSError(connection.last_error)
        if log_enabled(BASIC):
            log(BASIC, 'tls started for <%s>', connection)
        return self._start_tls(connection)

    def _start_tls(self, connection):
        """
        Wrap the connection socket (with handshake) after a successful StartTLS response.

        connection.starting_tls is always reset. On success the wrapped-socket usage
        counter is incremented (when usage tracking is on) and tls_started is set.

        :return: True
        :raises: the exception class chosen by start_tls_exception_factory for the wrapping error
        """
        try:
            self.wrap_socket(connection, do_handshake=True)
        except Exception as e:
            connection.last_error = 'wrap socket error: ' + str(e)
            if log_enabled(ERROR):
                log(ERROR, 'error <%s> wrapping socket for TLS in <%s>', connection.last_error, connection)
            raise start_tls_exception_factory(e)(connection.last_error)
        finally:
            connection.starting_tls = False

        if connection.usage:
            connection._usage.wrapped_sockets += 1
        connection.tls_started = True
        return True
297
298
def check_hostname(sock, server_name, additional_names):
    """
    Check that the peer certificate of sock matches one of the accepted host names.

    The accepted names are server_name plus additional_names. Empty names are
    skipped, and the name '*' accepts any certificate.

    :param sock: TLS-wrapped socket whose peer certificate is checked
    :param server_name: host name the connection was opened to
    :param additional_names: a single extra name, an instance of SEQUENCE_TYPES of
        names, or a falsy value for none
    :raises LDAPCertificateError: if no accepted name matches the certificate
    """
    server_certificate = sock.getpeercert()
    if log_enabled(NETWORK):
        log(NETWORK, 'certificate found for %s: %s', sock, server_certificate)

    host_names = [server_name]
    if additional_names:
        if isinstance(additional_names, SEQUENCE_TYPES):
            # extend() accepts any iterable; list + tuple/set/generator would raise TypeError
            host_names.extend(additional_names)
        else:
            host_names.append(additional_names)

    for host_name in host_names:
        if not host_name:
            continue
        if host_name == '*':
            if log_enabled(NETWORK):
                log(NETWORK, 'certificate matches * wildcard')
            return  # valid

        try:
            match_hostname(server_certificate, host_name)  # raise CertificateError if certificate doesn't match server name
        except CertificateError as e:
            if log_enabled(NETWORK):
                # pass the message as an argument so a '%' in it is not treated as a format directive
                log(NETWORK, '%s', e)
            continue
        if log_enabled(NETWORK):
            log(NETWORK, 'certificate matches host name <%s>', host_name)
        return  # valid

    if log_enabled(ERROR):
        log(ERROR, "hostname doesn't match certificate")
    raise LDAPCertificateError("certificate %s doesn't match any name in %s " % (server_certificate, str(host_names)))
328