1# Copyright (c) 2014, 2021, Oracle and/or its affiliates.
2#
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU General Public License, version 2.0, as
5# published by the Free Software Foundation.
6#
7# This program is also distributed with certain software (including
8# but not limited to OpenSSL) that is licensed under separate terms,
9# as designated in a particular file or component or in included license
10# documentation.  The authors of MySQL hereby grant you an
11# additional permission to link the program and your derivative works
12# with the separately licensed software that they have included with
13# MySQL.
14#
15# Without limiting anything contained in the foregoing, this file,
16# which is part of MySQL Connector/Python, is also subject to the
17# Universal FOSS Exception, version 1.0, a copy of which can be found at
18# http://oss.oracle.com/licenses/universal-foss-exception.
19#
20# This program is distributed in the hope that it will be useful, but
21# WITHOUT ANY WARRANTY; without even the implied warranty of
22# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
23# See the GNU General Public License, version 2.0, for more details.
24#
25# You should have received a copy of the GNU General Public License
26# along with this program; if not, write to the Free Software Foundation, Inc.,
27# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301  USA
28
29"""Implementing support for MySQL Authentication Plugins"""
30
31from base64 import b64encode, b64decode
32from hashlib import sha1, sha256
33import getpass
34import hmac
35import logging
36import os
37import struct
38
39
40from urllib.parse import quote
41from uuid import uuid4
42
43try:
44    from cryptography.exceptions import UnsupportedAlgorithm
45    from cryptography.hazmat.primitives import hashes, serialization
46    from cryptography.hazmat.primitives.asymmetric import padding
47    CRYPTOGRAPHY_AVAILABLE = True
48except ImportError:
49    CRYPTOGRAPHY_AVAILABLE = False
50
51try:
52    import gssapi
53except:
54    gssapi = None
55
56from . import errors
57from .utils import (normalize_unicode_string as norm_ustr,
58                    validate_normalized_unicode_string as valid_norm)
59
# Module-level logger. A NullHandler is attached, following the logging
# documentation's recommendation for libraries.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())
63
64
class BaseAuthPlugin:
    """Base class for authentication plugins.

    Subclasses implement prepare_password(), which builds the payload sent
    to the server. auth_data is required when instantiating; username,
    password and database are optional, and ssl_enabled tells the plugin
    whether SSL is active on the connection.

    auth_response() checks the plugin's SSL requirement and then returns
    the payload built by prepare_password().
    """

    requires_ssl = False
    plugin_name = ''

    def __init__(self, auth_data, username=None, password=None, database=None,
                 ssl_enabled=False, instance=None):
        """Store the authentication inputs on the plugin.

        The ``instance`` argument is accepted for interface compatibility
        but is not used by this class.
        """
        self._ssl_enabled = ssl_enabled
        self._database = database
        self._password = password
        self._username = username
        self._auth_data = auth_data

    def prepare_password(self):
        """Build the payload to be sent to MySQL.

        Subclasses must override this; auth_response() delegates to it.

        Raises NotImplementedError.
        """
        raise NotImplementedError

    def auth_response(self):
        """Return the payload produced by prepare_password().

        Raises InterfaceError when the plugin requires SSL but SSL is not
        enabled.

        Returns whatever prepare_password() returns.
        """
        ssl_missing = self.requires_ssl and not self._ssl_enabled
        if ssl_missing:
            raise errors.InterfaceError(
                "{name} requires SSL".format(name=self.plugin_name))
        return self.prepare_password()
113
114
class MySQLNativePasswordAuthPlugin(BaseAuthPlugin):
    """Class implementing the MySQL Native Password authentication plugin"""

    requires_ssl = False
    plugin_name = 'mysql_native_password'

    def prepare_password(self):
        """Scramble the password as a native MySQL 4.1+ password.

        The scramble is SHA1(password) XOR SHA1(seed + SHA1(SHA1(password))),
        where the seed is self._auth_data. A str password is UTF-8 encoded
        first.

        Returns bytes: the 20-byte scramble, or b'' when no password is set.

        Raises InterfaceError when the seed is missing or when hashing fails.
        """
        seed = self._auth_data
        if not seed:
            raise errors.InterfaceError("Missing authentication data (seed)")

        secret = self._password
        if not secret:
            return b''
        if isinstance(secret, str):
            secret = secret.encode('utf-8')

        try:
            stage1 = sha1(secret).digest()
            stage2 = sha1(stage1).digest()
            salted = sha1(seed + stage2).digest()
            return struct.pack('20B', *(a ^ b for a, b in zip(stage1, salted)))
        except Exception as exc:
            raise errors.InterfaceError(
                "Failed scrambling password; {0}".format(exc))
149
150
class MySQLClearPasswordAuthPlugin(BaseAuthPlugin):
    """Class implementing the MySQL Clear Password authentication plugin

    The password is sent as clear text, so the plugin is marked as
    requiring SSL.
    """

    requires_ssl = True
    plugin_name = 'mysql_clear_password'

    def prepare_password(self):
        """Return the password as NUL-terminated clear text.

        A str password is UTF-8 encoded; a missing or empty password yields
        a single NUL byte.
        """
        secret = self._password or b''
        if isinstance(secret, str):
            secret = secret.encode('utf8')
        return secret + b'\x00'
167
168
class MySQLSHA256PasswordAuthPlugin(BaseAuthPlugin):
    """Class implementing the MySQL SHA256 authentication plugin

    RSA encryption of the password is not implemented here. The password is
    only sent as clear text, which is why the plugin requires SSL.
    """

    requires_ssl = True
    plugin_name = 'sha256_password'

    def prepare_password(self):
        """Return the password as NUL-terminated clear text.

        A str password is UTF-8 encoded; a missing or empty password yields
        a single NUL byte.
        """
        if not self._password:
            return b'\x00'
        raw = self._password
        encoded = raw.encode('utf8') if isinstance(raw, str) else raw
        return encoded + b'\x00'
189
190
class MySQLCachingSHA2PasswordAuthPlugin(BaseAuthPlugin):
    """Class implementing the MySQL caching_sha2_password authentication plugin

    Note that encrypting using RSA is not supported since the Python
    Standard Library does not provide this OpenSSL functionality. Full
    authentication therefore sends the password as clear text and requires
    SSL (see _full_authentication()).
    """
    requires_ssl = False
    plugin_name = 'caching_sha2_password'
    # Single-byte auth data value that prepare_password() answers with the
    # clear-text password. fast_auth_success is not referenced in this class.
    perform_full_authentication = 4
    fast_auth_success = 3

    def _scramble(self):
        """Return a scramble of the password using the nonce sent by the server.

        The scramble is of the form:
        XOR(SHA256(password), SHA256(SHA256(SHA256(password)) + Nonce))

        Returns bytes: the 32-byte scramble, or b'' when no password is set.

        Raises InterfaceError if the nonce (auth data) is missing.
        """
        if not self._auth_data:
            raise errors.InterfaceError("Missing authentication data (seed)")

        if not self._password:
            return b''

        password = self._password.encode('utf-8') \
            if isinstance(self._password, str) else self._password
        auth_data = self._auth_data

        hash1 = sha256(password).digest()
        hash2 = sha256()
        hash2.update(sha256(hash1).digest())
        hash2.update(auth_data)
        hash2 = hash2.digest()
        xored = [h1 ^ h2 for (h1, h2) in zip(hash1, hash2)]
        return struct.pack('32B', *xored)

    def prepare_password(self):
        """Return the response matching the auth data received.

        - More than one byte: the data is a nonce; return the scramble.
        - The single byte perform_full_authentication: return the clear-text
          password (requires SSL).
        - Any other single byte: return None.

        Raises InterfaceError if the auth data is missing or empty, instead
        of failing with a raw TypeError/IndexError.
        """
        if not self._auth_data:
            raise errors.InterfaceError("Missing authentication data (seed)")
        if len(self._auth_data) > 1:
            return self._scramble()
        if self._auth_data[0] == self.perform_full_authentication:
            return self._full_authentication()
        return None

    def _full_authentication(self):
        """Return the password as NUL-terminated clear text.

        A str password is UTF-8 encoded; a missing or empty password yields
        a single NUL byte.

        Raises InterfaceError if SSL is not enabled.
        """
        if not self._ssl_enabled:
            raise errors.InterfaceError("{name} requires SSL".format(
                name=self.plugin_name))

        if not self._password:
            return b'\x00'
        password = self._password

        if isinstance(password, str):
            password = password.encode('utf8')

        return password + b'\x00'
250
251
class MySQLLdapSaslPasswordAuthPlugin(BaseAuthPlugin):
    """Class implementing the MySQL ldap sasl authentication plugin.

    The mechanism name is read from the auth data sent by the server. This
    implementation supports SCRAM-SHA-1, SCRAM-SHA-256 and GSSAPI (Kerberos).
    GSSAPI additionally requires the optional ``gssapi`` module.

    SCRAM-SHA-1 and SCRAM-SHA-256
        This method requires 2 messages from client and 2 responses from
        server.

        The first client message is generated by auth_response(). The
        server's first response must be passed to auth_continue(), which
        returns the second client message. After sending that message, the
        server's second response must be passed to auth_finalize() to finish
        the authentication process.

    GSSAPI
        auth_response() returns the initial token of the security context.
        Each server challenge is passed to auth_continue_krb() until the
        context is complete. auth_accept_close_handshake() then builds the
        closing handshake message.
    """
    sasl_mechanisms = ['SCRAM-SHA-1', 'SCRAM-SHA-256', 'GSSAPI']
    requires_ssl = False
    plugin_name = 'authentication_ldap_sasl_client'
    # Hash function used by HMAC and H(); auth_response() switches it to
    # sha256 for SCRAM-SHA-256.
    def_digest_mode = sha1
    client_nonce = None
    client_salt = None
    server_salt = None
    krb_service_principal = None
    iterations = 0
    server_auth_var = None

    def _xor(self, bytes1, bytes2):
        """Return the byte-wise XOR of two byte sequences.

        The result is truncated to the shorter input, as zip() does.
        """
        return bytes([b1 ^ b2 for b1, b2 in zip(bytes1, bytes2)])

    def _hmac(self, password, salt):
        """Return HMAC(password, salt) using self.def_digest_mode."""
        digest_maker = hmac.new(password, salt, self.def_digest_mode)
        return digest_maker.digest()

    def _hi(self, password, salt, count):
        """Compute Hi(password, salt, count) as defined by RFC 5802.

        Hi(p, s, i) is PBKDF2(HMAC, p, s, i, output length of H):
        U1 = HMAC(p, s + INT(1)), Ui = HMAC(p, U(i-1)), and the result is
        U1 XOR U2 XOR ... XOR Ui.

        Args:
            password (str): normalized password, encoded with str.encode().
            salt (bytes): the server's salt, already base64-decoded.
            count (int): the iteration count sent by the server.

        Returns bytes.
        """
        pw = password.encode()
        hi = self._hmac(pw, salt + b'\x00\x00\x00\x01')
        aux = hi
        for _ in range(count - 1):
            aux = self._hmac(pw, aux)
            hi = self._xor(hi, aux)
        return hi

    def _normalize(self, string):
        """Normalize ``string`` and reject prohibited characters.

        Uses the unicode helpers imported from ``.utils`` (norm_ustr and
        valid_norm).

        Raises InterfaceError if the normalized string breaks a rule.
        """
        norm_str = norm_ustr(string)
        broken_rule = valid_norm(norm_str)
        if broken_rule is not None:
            # This is also called on the password, so the input string is
            # deliberately not echoed in the error message.
            raise errors.InterfaceError("broken_rule: {}".format(broken_rule))
        return norm_str

    def _first_message(self):
        """Generate the client-first message that starts the SCRAM exchange.

        The client-first message consists of a gs2-header, the desired
        username, and a randomly generated client nonce. It has the form:
            b'n,a=<user_name>,n=<user_name>,r=<client_nonce>'

        The nonce is stored in self.client_nonce so that the server's reply
        can be checked against it.

        Returns bytes: the client's first message, UTF-8 encoded.
        """
        # 32 lowercase hex characters, identical in form to the previous
        # str(uuid4()).replace("-", "").
        self.client_nonce = uuid4().hex
        user_name = self._normalize(self._username)
        cfm = "n,a={user_name},n={user_name},r={client_nonce}".format(
            user_name=user_name, client_nonce=self.client_nonce)
        return cfm.encode('utf8')

    def _first_message_krb(self):
        """Get a TGT Authentication request and initiate the security context.

        This method will contact the Kerberos KDC in order to obtain a TGT.
        Credentials from the default credential store are tried first.
        If they cannot be retrieved and a password was given, credentials
        are acquired with that password.

        Returns the initial client token produced by the security context.

        Raises InterfaceError if no credentials can be retrieved and no
        password was given, if stored credentials have expired, or if the
        security context cannot be initiated. Raises ProgrammingError if
        credentials cannot be acquired with the given password.
        """
        _LOGGER.debug("# user name: %s", self._username)
        user_name = gssapi.raw.names.import_name(self._username.encode('utf8'),
                                                 name_type=gssapi.NameType.user)

        # Attempt to retrieve credentials from the default credential store
        # (e.g. a ccache such as 'FILE:/tmp/krb5cc_1000').
        try:
            cred = gssapi.Credentials()
            _LOGGER.debug("# Stored credentials found, if password was given it"
                          " will be ignored.")
            try:
                # Reading lifetime validates the credentials have not expired.
                cred.lifetime
            except gssapi.raw.exceptions.ExpiredCredentialsError as err:
                _LOGGER.warning(" Credentials has expired: %s", err)
                # NOTE(review): the result of acquire() is discarded and the
                # error below is raised regardless; if acquire() itself raises
                # GSSError, the outer handler runs instead. Confirm intent.
                cred.acquire(user_name)
                raise errors.InterfaceError("Credentials has expired: {}".format(err))
        except gssapi.raw.misc.GSSError as err:
            if not self._password:
                _LOGGER.error(" Unable to retrieve stored credentials: %s", err)
                raise errors.InterfaceError(
                    "Unable to retrieve stored credentials error: {}".format(err))
            else:
                try:
                    _LOGGER.debug("# Attempt to retrieve credentials with "
                                  "given password")
                    acquire_cred_result = gssapi.raw.acquire_cred_with_password(
                        user_name, self._password.encode('utf8'), usage="initiate")
                    cred = acquire_cred_result[0]
                except gssapi.raw.misc.GSSError as err:
                    _LOGGER.error(" Unable to retrieve credentials with the given "
                                  "password: %s", err)
                    raise errors.ProgrammingError(
                        "Unable to retrieve credentials with the given password: "
                        "{}".format(err))

        flags_l = (gssapi.RequirementFlag.mutual_authentication,
                   gssapi.RequirementFlag.extended_error,
                   gssapi.RequirementFlag.delegate_to_peer
        )

        if self.krb_service_principal:
            service_principal = self.krb_service_principal
        else:
            service_principal = "ldap/ldapauth"
        _LOGGER.debug("# service principal: %s", service_principal)
        servk = gssapi.Name(service_principal, name_type=gssapi.NameType.kerberos_principal)
        self.target_name = servk
        self.ctx = gssapi.SecurityContext(name=servk,
                                          creds=cred,
                                          flags=sum(flags_l),
                                          usage='initiate')

        try:
            initial_client_token = self.ctx.step()
        except gssapi.raw.misc.GSSError as err:
            _LOGGER.error("Unable to initiate security context: %s", err)
            raise errors.InterfaceError("Unable to initiate security context: {}".format(err))

        _LOGGER.debug("# initial client token: %s", initial_client_token)
        return initial_client_token

    def auth_continue_krb(self, tgt_auth_challenge):
        """Continue with the Kerberos TGT service request.

        With the TGT authentication service given response generate a TGT
        service request. This method must be invoked sequentially (in a loop)
        until the security context is completed and an empty response needs to
        be sent to acknowledge the server.

        Args:
            tgt_auth_challenge the challenge for the negotiation.

        Returns: tuple (bytearray TGS service request,
                        bool True if context is completed otherwise False).
        """
        _LOGGER.debug("tgt_auth challenge: %s", tgt_auth_challenge)

        resp = self.ctx.step(tgt_auth_challenge)
        _LOGGER.debug("# context step response: %s", resp)
        _LOGGER.debug("# context completed?: %s", self.ctx.complete)

        return resp, self.ctx.complete

    def auth_accept_close_handshake(self, message):
        """Accept handshake and generate closing handshake message for server.

        This method verifies the server authenticity from the given message
        and included signature and generates the closing handshake for the
        server.

        When this method is invoked the security context is already established
        and the client and server can send GSSAPI formatted secure messages.

        To finish the authentication handshake the server sends a message
        with the security layer availability and the maximum buffer size.

        Since the connector only uses the GSSAPI authentication mechanism to
        authenticate the user with the server, the server will verify the
        client's message signature and terminate the GSSAPI authentication
        and send two messages: an authentication acceptance
        b'\x01\x00\x00\x08\x01' and an OK packet (that must be received after
        sending the returned message from this method).

        Args:
            message a wrapped gssapi message from the server.

        Returns: bytearray closing handshake message to be sent to the server.

        Raises ProgrammingError if the security context is not complete, and
        InterfaceError if the server message cannot be unwrapped.
        """
        if not self.ctx.complete:
            raise errors.ProgrammingError("Security context is not completed.")
        _LOGGER.debug("# servers message: %s", message)
        _LOGGER.debug("# GSSAPI flags in use: %s", self.ctx.actual_flags)
        try:
            unwraped = self.ctx.unwrap(message)
            _LOGGER.debug("# unwraped: %s", unwraped)
        except gssapi.raw.exceptions.BadMICError as err:
            _LOGGER.debug("Unable to unwrap server message: %s", err)
            raise errors.InterfaceError("Unable to unwrap server message: {}"
                                        "".format(err)) from err

        _LOGGER.debug("# unwrapped server message: %s", unwraped)
        # The message contents for the client's closing message:
        #   - security level 1 byte, must be always 1.
        #   - conciliated buffer size 3 bytes, without importance as no
        #     further GSSAPI messages will be sent.
        response = bytearray(b"\x01\x00\x00\x00")
        # Closing handshake must not be encrypted.
        _LOGGER.debug("# message response: %s", response)
        wraped = self.ctx.wrap(response, encrypt=False)
        _LOGGER.debug("# wrapped message response: %s, length: %d",
                      wraped[0], len(wraped[0]))

        return wraped.message

    def auth_response(self, krb_service_principal=None):
        """Prepare the first message to the server.

        Args:
            krb_service_principal (str): Kerberos service principal used for
                GSSAPI; "ldap/ldapauth" is used when not given.

        Returns the first message to send to the server: the SCRAM
        client-first message (bytes) or the initial GSSAPI token.

        Raises InterfaceError for unsupported mechanisms, and
        ProgrammingError when GSSAPI is requested but gssapi is not
        installed.
        """
        auth_mechanism = self._auth_data.decode()
        self.krb_service_principal = krb_service_principal
        _LOGGER.debug("read_method_name_from_server: %s", auth_mechanism)
        if auth_mechanism not in self.sasl_mechanisms:
            raise errors.InterfaceError(
                'The sasl authentication method "{}" requested from the server '
                'is not supported. Only "{}" and "{}" are supported'.format(
                    auth_mechanism, '", "'.join(self.sasl_mechanisms[:-1]),
                    self.sasl_mechanisms[-1]))

        if b'GSSAPI' in self._auth_data:
            if not gssapi:
                raise errors.ProgrammingError(
                    "Module gssapi is required for GSSAPI authentication "
                    "mechanism but was not found. Unable to authenticate "
                    "with the server")
            return self._first_message_krb()

        if self._auth_data == b'SCRAM-SHA-256':
            self.def_digest_mode = sha256

        return self._first_message()

    def _second_message(self):
        """Generate the second (client-final) message to the server.

        The second message consists of the client header, the combined
        client and server nonce, and the client proof:

        c=<n,a=<user_name>>,r=<server_nonce>,p=<client_proof>
        where:
            <client_proof>: xor(<client_key>, <client_signature>)

            <client_key>: hmac(salted_password, b"Client Key")
            <client_signature>: hmac(<stored_key>, <auth_msg>)
            <stored_key>: h(<client_key>)
            <auth_msg>: <client_first_no_header>,<servers_first>,
                        c=<client_header>,r=<server_nonce>
            <client_first_no_header>: n=<username>,r=<client_nonce>

        As a side effect, the expected server signature is stored in
        self.server_auth_var for auth_finalize().

        Returns bytes.
        """
        if not self._auth_data:
            raise errors.InterfaceError("Missing authentication data (seed)")

        passw = self._normalize(self._password)
        salted_password = self._hi(passw,
                                   b64decode(self.server_salt),
                                   self.iterations)

        # The salted password and the keys derived from it are
        # password-equivalent, so they are intentionally never logged.
        client_key = self._hmac(salted_password, b"Client Key")
        stored_key = self.def_digest_mode(client_key).digest()
        server_key = self._hmac(salted_password, b"Server Key")

        user_name = self._normalize(self._username)
        client_first_no_header = ",".join([
            "n={}".format(user_name),
            "r={}".format(self.client_nonce)])
        _LOGGER.debug("client_first_no_header: %s", client_first_no_header)
        client_header = b64encode(
            "n,a={},".format(user_name).encode()).decode()
        auth_msg = ','.join([
            client_first_no_header,
            self.servers_first,
            "c={}".format(client_header),
            "r={}".format(self.server_nonce)])
        _LOGGER.debug("auth_msg: %s", auth_msg)

        # client_signature is not logged either: XOR-ed with the proof it
        # would reveal client_key.
        client_signature = self._hmac(stored_key, auth_msg.encode())

        client_proof = self._xor(client_key, client_signature)
        _LOGGER.debug("client_proof: %s", b64encode(client_proof).decode())

        self.server_auth_var = b64encode(
            self._hmac(server_key, auth_msg.encode())).decode()
        _LOGGER.debug("server_auth_var: %s", self.server_auth_var)

        msg = ",".join(["c={}".format(client_header),
                        "r={}".format(self.server_nonce),
                        "p={}".format(b64encode(client_proof).decode())])
        _LOGGER.debug("second_message: %s", msg)
        return msg.encode()

    def _validate_first_reponse(self, servers_first):
        """Validate the first message from the server.

        Extracts the server's nonce, salt and iterations from the server's
        first response, which is in the form:
            r=<server_nonce>,s=<server_salt>,i=<iterations>

        Per RFC 5802 the server nonce must start with the client nonce, and
        the iteration count must be a positive integer.

        Raises InterfaceError if the message is malformed or fails these
        checks.
        """
        if not servers_first or not isinstance(servers_first, (bytearray, bytes)):
            raise errors.InterfaceError("Unexpected server message: {}"
                                        "".format(servers_first))
        try:
            servers_first = servers_first.decode()
            self.servers_first = servers_first
            r_server_nonce, s_salt, i_counter = servers_first.split(",")
        except ValueError:
            # Covers UnicodeDecodeError and a wrong number of fields.
            raise errors.InterfaceError("Unexpected server message: {}"
                                        "".format(servers_first))
        if not r_server_nonce.startswith("r=") or \
           not s_salt.startswith("s=") or \
           not i_counter.startswith("i="):
            raise errors.InterfaceError("Incomplete reponse from the server: {}"
                                        "".format(servers_first))
        server_nonce = r_server_nonce[2:]
        # A substring test would accept the client nonce anywhere in the
        # server nonce; RFC 5802 requires it to be the prefix.
        if not self.client_nonce or \
           not server_nonce.startswith(self.client_nonce):
            raise errors.InterfaceError("Unable to authenticate response: "
                                        "response not well formed {}"
                                        "".format(servers_first))
        self.server_nonce = server_nonce
        _LOGGER.debug("server_nonce: %s", self.server_nonce)
        self.server_salt = s_salt[2:]
        _LOGGER.debug("server_salt: %s length: %s", self.server_salt,
                      len(self.server_salt))
        i_counter = i_counter[2:]
        _LOGGER.debug("iterations: %s", i_counter)
        try:
            iterations = int(i_counter)
        except ValueError as err:
            raise errors.InterfaceError("Unable to authenticate: iterations "
                                        "not found {}".format(servers_first)) from err
        if iterations < 1:
            raise errors.InterfaceError("Unable to authenticate: invalid "
                                        "iterations {}".format(servers_first))
        self.iterations = iterations

    def auth_continue(self, servers_first_response):
        """Return the second message from the client.

        Returns bytes to send to the server as the second message.
        """
        self._validate_first_reponse(servers_first_response)
        return self._second_message()

    def _validate_second_reponse(self, servers_second):
        """Validate the second message from the server.

        The client and the server prove to each other they have the same Auth
        variable.

        The second message from the server consists of the server's proof:
            server_proof = HMAC(<server_key>, <auth_msg>)
            where:
                <server_key>: hmac(<salted_password>, b"Server Key")
                <auth_msg>: <client_first_no_header>,<servers_first>,
                            c=<client_header>,r=<server_nonce>

        Our server_proof must be equal to the Auth variable sent in this
        second response.

        Returns True if the proofs match, False otherwise.

        Raises InterfaceError if the message is not of the form b"v=...".
        """
        if not servers_second or \
           not isinstance(servers_second, (bytearray, bytes)) or \
           len(servers_second) <= 2 or not servers_second.startswith(b"v="):
            raise errors.InterfaceError("The server's proof is not well formated.")
        server_var = servers_second[2:].decode()
        _LOGGER.debug("server auth variable: %s", server_var)
        if self.server_auth_var is None:
            return False
        # Constant-time comparison; bytes are used so that non-ASCII input
        # from the server cannot make compare_digest raise TypeError.
        return hmac.compare_digest(self.server_auth_var.encode(),
                                   server_var.encode())

    def auth_finalize(self, servers_second_response):
        """Finalize the authentication process.

        Raises errors.InterfaceError if the servers_second_response is invalid
        or the server's proof does not match.

        Returns True on successful authentication.
        """
        if not self._validate_second_reponse(servers_second_response):
            raise errors.InterfaceError("Authentication failed: Unable to "
                                        "proof server identity.")
        return True
652
653
654class MySQLKerberosAuthPlugin(BaseAuthPlugin):
655    """Implement the MySQL Kerberos authentication plugin."""
656
657    plugin_name = "authentication_kerberos_client"
658    requires_ssl = False
659    context = None
660
661    @staticmethod
662    def get_user_from_credentials():
663        """Get user from credentials without realm."""
664        try:
665            creds = gssapi.Credentials(usage="initiate")
666            user = str(creds.name)
667            if user.find("@") != -1:
668                user, _ = user.split("@", 1)
669            return user
670        except gssapi.raw.misc.GSSError as err:
671            return getpass.getuser()
672
673    def _acquire_cred_with_password(self, upn):
674        """Acquire credentials through provided password."""
675        _LOGGER.debug(
676            "Attempt to acquire credentials through provided password"
677        )
678
679        username = gssapi.raw.names.import_name(
680            upn.encode("utf-8"),
681            name_type=gssapi.NameType.user
682        )
683
684        try:
685            acquire_cred_result = (
686                gssapi.raw.acquire_cred_with_password(
687                    username,
688                    self._password.encode("utf-8"),
689                    usage="initiate"
690                )
691            )
692        except gssapi.raw.misc.GSSError as err:
693            raise errors.ProgrammingError(
694                f"Unable to acquire credentials with the given password: {err}"
695            )
696        creds = acquire_cred_result[0]
697        return creds
698
699    def _parse_auth_data(self, packet):
700        """Parse authentication data.
701
702        Get the SPN and REALM from the authentication data packet.
703
704        Format:
705            SPN string length two bytes <B1> <B2> +
706            SPN string +
707            UPN realm string length two bytes <B1> <B2> +
708            UPN realm string
709
710        Returns:
711            tuple: With 'spn' and 'realm'.
712        """
713        spn_len = struct.unpack("<H", packet[:2])[0]
714        packet = packet[2:]
715
716        spn = struct.unpack(f"<{spn_len}s", packet[:spn_len])[0]
717        packet = packet[spn_len:]
718
719        realm_len = struct.unpack("<H", packet[:2])[0]
720        realm = struct.unpack(f"<{realm_len}s", packet[2:])[0]
721
722        return spn.decode(), realm.decode()
723
724    def prepare_password(self):
725        """Return password as as clear text."""
726        if not self._password:
727            return b"\x00"
728        password = self._password
729
730        if isinstance(password, str):
731            password = password.encode("utf8")
732
733        return password + b"\x00"
734
735    def auth_response(self, auth_data=None):
736        """Prepare the fist message to the server."""
737        spn = None
738        realm = None
739
740        if auth_data:
741            try:
742                spn, realm = self._parse_auth_data(auth_data)
743            except struct.error as err:
744                raise InterruptedError(f"Invalid authentication data: {err}")
745
746        if spn is None:
747            return self.prepare_password()
748
749        upn = f"{self._username}@{realm}" if self._username else None
750
751        _LOGGER.debug("Service Principal: %s", spn)
752        _LOGGER.debug("Realm: %s", realm)
753        _LOGGER.debug("Username: %s", self._username)
754
755        try:
756            # Attempt to retrieve credentials from default cache file
757            creds = gssapi.Credentials(usage="initiate")
758            creds_upn = str(creds.name)
759
760            _LOGGER.debug("Cached credentials found")
761            _LOGGER.debug("Cached credentials UPN: %s", creds_upn)
762
763            # Remove the realm from user
764            if creds_upn.find("@") != -1:
765                creds_user, creds_realm = creds_upn.split("@", 1)
766            else:
767                creds_user = creds_upn
768                creds_realm = None
769
770            upn = f"{self._username}@{realm}" if self._username else creds_upn
771
772            # The user from cached credentials matches with the given user?
773            if self._username and self._username != creds_user:
774                _LOGGER.debug(
775                    "The user from cached credentials doesn't match with the "
776                    "given user"
777                )
778                if self._password is not None:
779                    creds = self._acquire_cred_with_password(upn)
780            if (
781                creds_realm and creds_realm != realm and
782                self._password is not None
783            ):
784                creds = self._acquire_cred_with_password(upn)
785        except gssapi.raw.exceptions.ExpiredCredentialsError as err:
786            if upn and self._password is not None:
787                creds = self._acquire_cred_with_password(upn)
788            else:
789                raise errors.InterfaceError(f"Credentials has expired: {err}")
790        except gssapi.raw.misc.GSSError as err:
791            if upn and self._password is not None:
792                creds = self._acquire_cred_with_password(upn)
793            else:
794                raise errors.InterfaceError(
795                    f"Unable to retrieve cached credentials error: {err}"
796                )
797
798        flags = (
799            gssapi.RequirementFlag.mutual_authentication,
800            gssapi.RequirementFlag.extended_error,
801            gssapi.RequirementFlag.delegate_to_peer
802        )
803        name = gssapi.Name(
804            spn,
805            name_type=gssapi.NameType.kerberos_principal
806        )
807        cname = name.canonicalize(gssapi.MechType.kerberos)
808        self.context = gssapi.SecurityContext(
809            name=cname,
810            creds=creds,
811            flags=sum(flags),
812            usage="initiate"
813        )
814
815        try:
816            initial_client_token = self.context.step()
817        except gssapi.raw.misc.GSSError as err:
818            raise errors.InterfaceError(
819                f"Unable to initiate security context: {err}"
820            )
821
822        _LOGGER.debug("Initial client token: %s", initial_client_token)
823        return initial_client_token
824
825    def auth_continue(self, tgt_auth_challenge):
826        """Continue with the Kerberos TGT service request.
827
828        With the TGT authentication service given response generate a TGT
829        service request. This method must be invoked sequentially (in a loop)
830        until the security context is completed and an empty response needs to
831        be send to acknowledge the server.
832
833        Args:
834            tgt_auth_challenge: the challenge for the negotiation.
835
836        Returns:
837            tuple (bytearray TGS service request,
838            bool True if context is completed otherwise False).
839        """
840        _LOGGER.debug("tgt_auth challenge: %s", tgt_auth_challenge)
841
842        resp = self.context.step(tgt_auth_challenge)
843
844        _LOGGER.debug("Context step response: %s", resp)
845        _LOGGER.debug("Context completed?: %s", self.context.complete)
846
847        return resp, self.context.complete
848
849    def auth_accept_close_handshake(self, message):
850        """Accept handshake and generate closing handshake message for server.
851
852        This method verifies the server authenticity from the given message
853        and included signature and generates the closing handshake for the
854        server.
855
856        When this method is invoked the security context is already established
857        and the client and server can send GSSAPI formated secure messages.
858
859        To finish the authentication handshake the server sends a message
860        with the security layer availability and the maximum buffer size.
861
862        Since the connector only uses the GSSAPI authentication mechanism to
863        authenticate the user with the server, the server will verify clients
864        message signature and terminate the GSSAPI authentication and send two
865        messages; an authentication acceptance b'\x01\x00\x00\x08\x01' and a
866        OK packet (that must be received after sent the returned message from
867        this method).
868
869        Args:
870            message: a wrapped gssapi message from the server.
871
872        Returns:
873            bytearray (closing handshake message to be send to the server).
874        """
875        if not self.context.complete:
876            raise errors.ProgrammingError("Security context is not completed")
877        _LOGGER.debug("Server message: %s", message)
878        _LOGGER.debug("GSSAPI flags in use: %s", self.context.actual_flags)
879        try:
880            unwraped = self.context.unwrap(message)
881            _LOGGER.debug("Unwraped: %s", unwraped)
882        except gssapi.raw.exceptions.BadMICError as err:
883            _LOGGER.debug("Unable to unwrap server message: %s", err)
884            raise errors.InterfaceError(
885                "Unable to unwrap server message: {}".format(err)
886            )
887
888        _LOGGER.debug("Unwrapped server message: %s", unwraped)
889        # The message contents for the clients closing message:
890        #   - security level 1 byte, must be always 1.
891        #   - conciliated buffer size 3 bytes, without importance as no
892        #     further GSSAPI messages will be sends.
893        response = bytearray(b"\x01\x00\x00\00")
894        # Closing handshake must not be encrypted.
895        _LOGGER.debug("Message response: %s", response)
896        wraped = self.context.wrap(response, encrypt=False)
897        _LOGGER.debug(
898            "Wrapped message response: %s, length: %d",
899            wraped[0],
900            len(wraped[0])
901        )
902
903        return wraped.message
904
905
class MySQL_OCI_AuthPlugin(BaseAuthPlugin):
    """Implement the MySQL OCI IAM authentication plugin."""

    plugin_name = "authentication_oci_client"
    requires_ssl = False
    context = None

    def _prepare_auth_response(self, signature, oci_config):
        """Prepare the client's authentication response in JSON format.

        Args:
            signature:  the server's nonce as signed by the client's API key.
            oci_config: OCI configuration mapping; "fingerprint" is read.

        Returns:
            Compact JSON string {"fingerprint": string, "signature": string},
            where the signature is base64-encoded.
        """
        import json  # local import keeps this method self-contained

        auth_response = {
            "fingerprint": oci_config["fingerprint"],
            "signature": b64encode(signature).decode(),
        }
        # json.dumps always yields valid JSON, even for values containing
        # quotes, backslashes or spaces (a patched dict repr does not).
        return json.dumps(auth_response, separators=(",", ":"))

    def _get_private_key(self, key_path):
        """Load the private key (API key) from the given PEM file location.

        Args:
            key_path: path to the PEM file; a leading "~" is expanded.

        Returns:
            The private key object returned by ``load_pem_private_key``.

        Raises:
            errors.ProgrammingError: if ``cryptography`` is not installed or
                the key cannot be read or loaded (it is loaded with no
                password).
        """
        if not CRYPTOGRAPHY_AVAILABLE:
            raise errors.ProgrammingError(
                "Package 'cryptography' is not installed"
            )
        try:
            with open(os.path.expanduser(key_path), "rb") as key_file:
                private_key = serialization.load_pem_private_key(
                    key_file.read(),
                    password=None,
                )
        except (TypeError, OSError, ValueError, UnsupportedAlgorithm) as err:
            raise errors.ProgrammingError(
                f'An error occurred while reading the API_KEY from "{key_path}":'
                f" {err}") from err

        return private_key

    def _get_valid_oci_config(self, oci_path=None, profile_name="DEFAULT"):
        """Get a valid OCI config from the given configuration file path.

        Args:
            oci_path: configuration file path; defaults to the OCI SDK's
                ``config.DEFAULT_LOCATION`` when empty.
            profile_name: profile to load from the file.

        Returns:
            The configuration loaded by ``oci.config.from_file``.

        Raises:
            errors.ProgrammingError: if the ``oci`` package is missing, the
                file/profile cannot be loaded, or a required parameter
                ("fingerprint", "key_file") is missing, empty or invalid.
        """
        try:
            from oci import config, exceptions
        except ImportError as err:
            raise errors.ProgrammingError(
                'Package "oci" (Oracle Cloud Infrastructure Python SDK)'
                ' is not installed.') from err
        if not oci_path:
            oci_path = config.DEFAULT_LOCATION

        error_list = []
        # Required parameters and the check each present value must pass.
        req_keys = {
            "fingerprint": (lambda x: len(x) > 32),
            "key_file": (lambda x: os.path.exists(os.path.expanduser(x)))
        }

        try:
            # key_file is validated by oci.config if present
            oci_config = config.from_file(oci_path, profile_name)
            for req_key, is_valid in req_keys.items():
                try:
                    value = oci_config[req_key]
                except KeyError:
                    error_list.append(f"Does not contain parameter {req_key}")
                    continue
                # A present but empty value is as unusable as one failing
                # its check, so both are reported as invalid.
                if not value or not is_valid(value):
                    error_list.append(f'Parameter "{req_key}" is invalid')
        except (
            exceptions.ConfigFileNotFound,
            exceptions.InvalidConfig,
            exceptions.InvalidKeyFilePath,
            exceptions.InvalidPrivateKey,
            exceptions.MissingPrivateKeyPassphrase,
            exceptions.ProfileNotFound
        ) as err:
            error_list.append(str(err))

        # Raise errors if any
        if error_list:
            raise errors.ProgrammingError(
                f'Invalid profile {profile_name} in: "{oci_path}". '
                f" Errors found: {error_list}")

        return oci_config

    def auth_response(self, oci_path=None):
        """Prepare the authentication string for the server.

        The server nonce (``self._auth_data``) is signed with the private key
        named by the OCI config's "key_file", using PKCS#1 v1.5 padding and
        SHA-256.

        Args:
            oci_path: OCI configuration file path, or None for the default.

        Returns:
            bytes: the JSON authentication response, encoded with str.encode().

        Raises:
            errors.ProgrammingError: if ``cryptography`` is not installed or
                the OCI configuration/private key is unusable.
        """
        if not CRYPTOGRAPHY_AVAILABLE:
            raise errors.ProgrammingError(
                "Package 'cryptography' is not installed"
            )
        _LOGGER.debug("server nonce: %s, len %d",
                      self._auth_data, len(self._auth_data))
        _LOGGER.debug("OCI configuration file location: %s", oci_path)

        oci_config = self._get_valid_oci_config(oci_path)

        private_key = self._get_private_key(oci_config['key_file'])
        signature = private_key.sign(
            self._auth_data,
            padding.PKCS1v15(),
            hashes.SHA256()
        )

        auth_response = self._prepare_auth_response(signature, oci_config)
        _LOGGER.debug("authentication response: %s", auth_response)
        return auth_response.encode()
1018
1019
def get_auth_plugin(plugin_name):
    """Look up the authentication plugin class registered under a name.

    The direct subclasses of BaseAuthPlugin are scanned in order. The first
    one whose ``plugin_name`` attribute equals ``plugin_name`` is returned.

    Args:
        plugin_name: name of the requested authentication plugin.

    Returns:
        The matching subclass of BaseAuthPlugin.

    Raises:
        errors.NotSupportedError: when no subclass handles ``plugin_name``.
    """
    candidates = BaseAuthPlugin.__subclasses__()  # pylint: disable=E1101
    match = next(
        (cls for cls in candidates if cls.plugin_name == plugin_name), None
    )
    if match is None:
        raise errors.NotSupportedError(
            f"Authentication plugin '{plugin_name}' is not supported")
    return match
1036