1import hashlib
2import unicodedata
3from enum import IntEnum, unique
4from functools import wraps
5from operator import attrgetter
6from os import urandom
7from stringprep import (
8    in_table_a1,
9    in_table_b1,
10    in_table_c12,
11    in_table_c21_c22,
12    in_table_c3,
13    in_table_c4,
14    in_table_c5,
15    in_table_c6,
16    in_table_c7,
17    in_table_c8,
18    in_table_c9,
19    in_table_d1,
20    in_table_d2,
21)
22from uuid import uuid4
23
24from asn1crypto.x509 import Certificate
25
26from scramp.utils import b64dec, b64enc, h, hi, hmac, uenc, xor
27
28# https://tools.ietf.org/html/rfc5802
29# https://www.rfc-editor.org/rfc/rfc7677.txt
30
31
@unique
class ClientStage(IntEnum):
    """Call order of the ScramClient methods; each value is the call's 1-based position."""

    get_client_first = 1  # produce the client-first message
    set_server_first = 2  # consume the server-first message
    get_client_final = 3  # produce the client-final message
    set_server_final = 4  # consume the server-final message
38
39
@unique
class ServerStage(IntEnum):
    """Call order of the ScramServer methods; each value is the call's 1-based position."""

    set_client_first = 1  # consume the client-first message
    get_server_first = 2  # produce the server-first message
    set_client_final = 3  # consume the client-final message
    get_server_final = 4  # produce the server-final message
46
47
def _check_stage(Stages, current_stage, next_stage):
    """Ensure *next_stage* is the step that legally follows *current_stage*.

    Args:
        Stages: the stage enum (ClientStage or ServerStage), used only to name
            the expected method in error messages.
        current_stage: the stage reached so far, or None before the first call.
        next_stage: the stage the caller is about to enter.

    Raises:
        ScramException: if the sequence has not started with stage 1, has
            already reached stage 4, or would skip/repeat a stage.
    """
    if current_stage is None:
        if next_stage == 1:
            return
        raise ScramException(f"The method {Stages(1).name} must be called first.")

    if current_stage == 4:
        raise ScramException("The authentication sequence has already finished.")

    expected = current_stage + 1
    if next_stage == expected:
        return
    raise ScramException(
        f"The next method to be called is "
        f"{Stages(expected).name}, not this method."
    )
59
60
class ScramException(Exception):
    """Error raised for any SCRAM protocol or usage failure.

    Args:
        message: human-readable description.
        server_error: optional RFC 5802 server-error value. When set, it is
            appended to ``str(exc)`` after a space.
    """

    def __init__(self, message, server_error=None):
        super().__init__(message)
        self.server_error = server_error

    def __str__(self):
        text = super().__str__()
        if self.server_error is None:
            return text
        return f"{text} {self.server_error}"
69
70
# Every SASL mechanism name this module accepts. Each hash appears twice: the
# plain name and its "-PLUS" (channel-binding) variant.
MECHANISMS = (
    "SCRAM-SHA-1", "SCRAM-SHA-1-PLUS",
    "SCRAM-SHA-256", "SCRAM-SHA-256-PLUS",
    "SCRAM-SHA-512", "SCRAM-SHA-512-PLUS",
    "SCRAM-SHA3-512", "SCRAM-SHA3-512-PLUS",
)


# Channel-binding type names accepted as the first element of a
# ``(type, data)`` channel_binding tuple.
CHANNEL_TYPES = ("tls-server-end-point", "tls-unique", "tls-unique-for-telnet")
88
89
def make_channel_binding(name, ssl_socket):
    """Build a ``(type, data)`` channel-binding tuple from a TLS socket.

    Args:
        name: "tls-unique" or "tls-server-end-point".
        ssl_socket: a connected socket exposing ``get_channel_binding`` and
            ``getpeercert`` (e.g. ``ssl.SSLSocket``).

    Returns:
        ``(name, data)`` where *data* is bytes suitable for the
        ``channel_binding`` parameter of ScramClient / ScramServer.

    Raises:
        ScramException: if *name* is not supported, the socket yields no
            binding data / peer certificate, or the certificate's hash
            algorithm is not available in hashlib.
    """
    if name == "tls-unique":
        data = ssl_socket.get_channel_binding(name)
        # ssl returns None when the handshake has not completed; passing None
        # on would only fail later with a less helpful validation error.
        if data is None:
            raise ScramException(
                "No tls-unique channel binding data is available from the socket."
            )
        return (name, data)

    if name != "tls-server-end-point":
        raise ScramException(f"Channel binding name {name} not recognized.")

    cert_bin = ssl_socket.getpeercert(binary_form=True)
    # getpeercert returns None when the peer presented no certificate.
    if cert_bin is None:
        raise ScramException(
            "The peer did not provide a certificate, so tls-server-end-point "
            "channel binding is not possible."
        )
    cert = Certificate.load(cert_bin)

    # Find the hash algorithm to use according to
    # https://tools.ietf.org/html/rfc5929#section-4 : MD5 and SHA-1 are
    # upgraded to SHA-256, otherwise the certificate's own hash is used.
    hash_algo = cert.hash_algo
    if hash_algo in ("md5", "sha1"):
        hash_algo = "sha256"

    try:
        hash_obj = hashlib.new(hash_algo)
    except ValueError as e:
        raise ScramException(
            f"Hash algorithm {hash_algo} not supported by hashlib. {e}"
        ) from e
    hash_obj.update(cert_bin)
    return ("tls-server-end-point", hash_obj.digest())
113
114
class ScramMechanism:
    """A named SCRAM mechanism and its parameters.

    Attributes:
        name: the mechanism name, one of MECHANISMS.
        hf: hashlib constructor used by the mechanism.
        use_binding: True for the "-PLUS" (channel-binding) variants.
        iteration_count: default PBKDF2 iteration count for new credentials.
        strength: rank used by ScramClient to prefer stronger mechanisms.
    """

    # name -> (hash constructor, uses channel binding, default iterations, strength rank)
    MECH_LOOKUP = {
        "SCRAM-SHA-1": (hashlib.sha1, False, 4096, 0),
        "SCRAM-SHA-1-PLUS": (hashlib.sha1, True, 4096, 1),
        "SCRAM-SHA-256": (hashlib.sha256, False, 4096, 2),
        "SCRAM-SHA-256-PLUS": (hashlib.sha256, True, 4096, 3),
        "SCRAM-SHA-512": (hashlib.sha512, False, 4096, 4),
        "SCRAM-SHA-512-PLUS": (hashlib.sha512, True, 4096, 5),
        "SCRAM-SHA3-512": (hashlib.sha3_512, False, 10000, 6),
        "SCRAM-SHA3-512-PLUS": (hashlib.sha3_512, True, 10000, 7),
    }

    def __init__(self, mechanism="SCRAM-SHA-256"):
        """Look up *mechanism*; raise ScramException if it is not supported."""
        if mechanism not in MECHANISMS:
            raise ScramException(
                f"The mechanism name '{mechanism}' is not supported. The "
                f"supported mechanisms are {MECHANISMS}."
            )
        self.name = mechanism
        hash_fn, binding, rounds, rank = self.MECH_LOOKUP[mechanism]
        self.hf = hash_fn
        self.use_binding = binding
        self.iteration_count = rounds
        self.strength = rank

    def make_auth_info(self, password, iteration_count=None, salt=None):
        """Derive credentials to store for *password*.

        Returns:
            ``(salt, stored_key, server_key, iteration_count)``. A random salt
            is generated when *salt* is None, and the mechanism's default
            iteration count is used when *iteration_count* is None.
        """
        rounds = self.iteration_count if iteration_count is None else iteration_count
        used_salt, stored, server = _make_auth_info(
            self.hf, password, rounds, salt=salt
        )
        return used_salt, stored, server, rounds

    def make_stored_server_keys(self, salted_password):
        """Return ``(stored_key, server_key)`` derived from *salted_password*."""
        _client, stored, server = _c_key_stored_key_s_key(self.hf, salted_password)
        return stored, server

    def make_server(self, auth_fn, channel_binding=None, s_nonce=None):
        """Create a ScramServer for this mechanism."""
        return ScramServer(
            self, auth_fn, channel_binding=channel_binding, s_nonce=s_nonce
        )
157
158
def _make_auth_info(hf, password, i, salt=None):
    """Derive ``(salt, stored_key, server_key)`` for storing a password.

    Args:
        hf: hashlib constructor.
        password: the plain-text password (SASLprep is applied downstream).
        i: iteration count.
        salt: salt bytes; 16 random bytes from os.urandom are used when None.
    """
    chosen_salt = urandom(16) if salt is None else salt
    salted = _make_salted_password(hf, password, chosen_salt, i)
    _client, stored, server = _c_key_stored_key_s_key(hf, salted)
    return chosen_salt, stored, server
166
167
def _validate_channel_binding(channel_binding):
    """Check that *channel_binding* is None or a well-formed ``(type, data)`` tuple.

    Raises:
        ScramException: if it is not a tuple, does not have exactly two
            elements, its first element is not one of CHANNEL_TYPES, or its
            second element is not bytes.
    """
    if channel_binding is None:
        return

    if not isinstance(channel_binding, tuple):
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple."
        )

    if len(channel_binding) != 2:
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple of "
            "two elements (type, data)."
        )

    channel_type, channel_data = channel_binding
    if channel_type not in CHANNEL_TYPES:
        # The last fragment must be an f-string so the allowed types are
        # actually listed in the message.
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple "
            "with the first element a str specifying one of the channel "
            f"types {CHANNEL_TYPES}."
        )

    if not isinstance(channel_data, bytes):
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple "
            "with the second element a bytes object containing the bind data."
        )
196
197
class ScramClient:
    """Client side of a SCRAM authentication exchange (RFC 5802).

    The methods must be called in this order, otherwise ScramException is
    raised: get_client_first, set_server_first, get_client_final,
    set_server_final.
    """

    def __init__(
        self, mechanisms, username, password, channel_binding=None, c_nonce=None
    ):
        """Choose the strongest usable mechanism and prepare client state.

        Args:
            mechanisms: list or tuple of mechanism names (e.g. as offered by
                the server).
            username: the user to authenticate as.
            password: the user's password.
            channel_binding: None or a ``(type, data)`` tuple. Mechanisms that
                use channel binding ("-PLUS") are only considered when given.
            c_nonce: client nonce; a random one is generated when None.

        Raises:
            ScramException: if *mechanisms* is not a list/tuple or names an
                unsupported mechanism, *channel_binding* is malformed, no
                mechanism is usable, or channel-binding data was supplied but
                the chosen mechanism does not use it.
        """
        if not isinstance(mechanisms, (list, tuple)):
            raise ScramException(
                "The 'mechanisms' parameter must be a list or tuple of "
                "mechanism names."
            )

        _validate_channel_binding(channel_binding)

        # ScramMechanism() raises for any unsupported name in the list.
        mechs = [ScramMechanism(m) for m in mechanisms]
        # Binding mechanisms need channel-binding data; drop them without it.
        mechs = [m for m in mechs if channel_binding is not None or not m.use_binding]
        if not mechs:
            # ScramException (a subclass of Exception) keeps error handling
            # uniform with every other failure raised by this class.
            raise ScramException("There are no suitable mechanisms in the list.")

        mech = max(mechs, key=attrgetter("strength"))
        self.hf, self.use_binding = mech.hf, mech.use_binding
        self.mechanism_name = mech.name

        # The filter above means a binding mechanism is only chosen when
        # channel_binding is set, so the only possible mismatch left is data
        # supplied for a mechanism that does not use it.
        if not self.use_binding and channel_binding is not None:
            raise ScramException(
                "The channel_binding parameter must be None if channel "
                "binding is not required."
            )

        self.c_nonce = _make_nonce() if c_nonce is None else c_nonce
        self.username = username
        self.password = password
        self.channel_binding = channel_binding
        self.stage = None

    def _set_stage(self, next_stage):
        """Advance to *next_stage*, raising ScramException if out of order."""
        _check_stage(ClientStage, self.stage, next_stage)
        self.stage = next_stage

    def get_client_first(self):
        """Return the client-first message to send to the server."""
        self._set_stage(ClientStage.get_client_first)
        self.client_first_bare, client_first = _get_client_first(
            self.username, self.c_nonce, self.channel_binding
        )
        return client_first

    def set_server_first(self, message):
        """Consume the server-first message (nonce, salt, iteration count)."""
        self._set_stage(ClientStage.set_server_first)
        self.server_first = message
        self.auth_message, self.nonce, self.salt, self.iterations = _set_server_first(
            message, self.c_nonce, self.client_first_bare, self.channel_binding
        )

    def get_client_final(self):
        """Return the client-final message containing the client proof."""
        self._set_stage(ClientStage.get_client_final)
        self.server_signature, cfinal = _get_client_final(
            self.hf,
            self.password,
            self.salt,
            self.iterations,
            self.nonce,
            self.auth_message,
            self.channel_binding,
        )
        return cfinal

    def set_server_final(self, message):
        """Verify the server-final message; raise ScramException on failure."""
        self._set_stage(ClientStage.set_server_final)
        _set_server_final(message, self.server_signature)
277
278
def set_error(f):
    """Decorate a ScramServer method so SCRAM errors end the exchange cleanly.

    If the wrapped method raises a ScramException that carries a
    ``server_error``, the error is stored on ``self.error``. The stage is then
    set to ``ServerStage.set_client_final``, which makes ``get_server_final``
    the next legal call, so the server can reply with "e=<error>". The
    exception is always re-raised.
    """

    @wraps(f)
    def wrapper(self, *args, **kwds):
        try:
            return f(self, *args, **kwds)
        except ScramException as e:
            if e.server_error is not None:
                self.error = e.server_error
                self.stage = ServerStage.set_client_final
            # Bare raise re-raises the active exception with its traceback
            # untouched.
            raise

    return wrapper
291
292
class ScramServer:
    """Server side of a SCRAM authentication exchange (RFC 5802).

    Usually created with ScramMechanism.make_server. The methods must be
    called in this order: set_client_first, get_server_first,
    set_client_final, get_server_final. If one of the decorated methods raises
    a ScramException carrying a server_error, set_error records it and moves
    the stage so that get_server_final returns the "e=<error>" reply.
    """

    def __init__(self, mechanism, auth_fn, channel_binding=None, s_nonce=None):
        """Prepare server state.

        Args:
            mechanism: a ScramMechanism instance.
            auth_fn: callable taking the username and returning
                ``(salt, stored_key, server_key, iteration_count)``.
            channel_binding: None or ``(type, data)``; required exactly when
                the mechanism uses channel binding.
            s_nonce: server nonce; a random one is generated when None.

        Raises:
            ScramException: if *channel_binding* is malformed or does not
                match the mechanism's channel-binding requirement.
        """
        self.m = mechanism

        _validate_channel_binding(channel_binding)

        binding_required = mechanism.use_binding
        if binding_required and channel_binding is None:
            raise ScramException(
                "The mechanism requires channel binding, and so "
                "channel_binding can't be None."
            )
        if not binding_required and channel_binding is not None:
            raise ScramException(
                "The mechanism does not support channel binding, and so "
                "channel_binding must be None."
            )

        self.channel_binding = channel_binding
        self.s_nonce = s_nonce if s_nonce is not None else _make_nonce()
        self.auth_fn = auth_fn
        self.stage = None
        self.server_signature = None
        self.error = None

    def _set_stage(self, next_stage):
        """Advance to *next_stage*, raising ScramException if out of order."""
        _check_stage(ServerStage, self.stage, next_stage)
        self.stage = next_stage

    @set_error
    def set_client_first(self, client_first):
        """Consume the client-first message and fetch the user's credentials."""
        self._set_stage(ServerStage.set_client_first)
        nonce, user, bare = _set_client_first(
            client_first, self.s_nonce, self.channel_binding
        )
        self.nonce = nonce
        self.user = user
        self.client_first_bare = bare
        salt, stored_key, server_key, iterations = self.auth_fn(user)
        self.stored_key = stored_key
        self.server_key = server_key
        self.i = iterations
        # Kept base64-encoded because it goes straight into "s=" of server-first.
        self.salt = b64enc(salt)

    @set_error
    def get_server_first(self):
        """Return the server-first message (nonce, salt, iteration count)."""
        self._set_stage(ServerStage.get_server_first)
        auth_message, server_first = _get_server_first(
            self.nonce, self.salt, self.i, self.client_first_bare, self.channel_binding
        )
        self.auth_message = auth_message
        return server_first

    @set_error
    def set_client_final(self, client_final):
        """Verify the client-final message and remember the server signature."""
        self._set_stage(ServerStage.set_client_final)
        signature = _set_client_final(
            self.m.hf,
            client_final,
            self.s_nonce,
            self.stored_key,
            self.server_key,
            self.auth_message,
            self.channel_binding,
        )
        self.server_signature = signature

    @set_error
    def get_server_final(self):
        """Return "v=<signature>", or "e=<error>" if an error was recorded."""
        self._set_stage(ServerStage.get_server_final)
        return _get_server_final(self.server_signature, self.error)
358
359
def _make_nonce():
    """Return a fresh random nonce: 32 lowercase hex characters from uuid4."""
    return uuid4().hex
362
363
def _make_auth_message(nonce, client_first_bare, server_first, cbind_data):
    """Assemble the SCRAM AuthMessage.

    The result is client-first-message-bare, server-first-message and the
    client-final message without its proof ("c=<cbind>,r=<nonce>"), joined
    by commas.
    """
    encoded_cbind = b64enc(_make_cbind_input(cbind_data))
    without_proof = "c=" + encoded_cbind + ",r=" + nonce
    return client_first_bare + "," + server_first + "," + without_proof
368
369
def _make_salted_password(hf, password, salt, iterations):
    """Return SaltedPassword = Hi(Normalize(password), salt, iterations).

    Normalization is saslprep() followed by uenc() (presumably UTF-8
    encoding; confirm in scramp.utils).
    """
    normalized = uenc(saslprep(password))
    return hi(hf, normalized, salt, iterations)
372
373
def _c_key_stored_key_s_key(hf, salted_password):
    """Derive ``(client_key, stored_key, server_key)`` from a salted password.

    ClientKey = HMAC(SaltedPassword, "Client Key"), StoredKey = H(ClientKey),
    ServerKey = HMAC(SaltedPassword, "Server Key").
    """
    server_key = hmac(hf, salted_password, b"Server Key")
    client_key = hmac(hf, salted_password, b"Client Key")
    return client_key, h(hf, client_key), server_key
380
381
def _check_client_key(hf, stored_key, auth_msg, proof):
    """Verify a client proof against the stored key.

    ClientKey is recovered as ClientProof XOR HMAC(StoredKey, AuthMessage),
    and H(ClientKey) must equal StoredKey.

    Args:
        hf: hashlib constructor.
        stored_key: the user's StoredKey (bytes).
        auth_msg: the encoded AuthMessage.
        proof: the base64 text of the "p=" attribute received from the client.

    Raises:
        ScramException: SERVER_ERROR_INVALID_ENCODING if *proof* cannot be
            decoded, SERVER_ERROR_INVALID_PROOF if the keys don't match.
    """
    # Function-scope import: the module-level name ``hmac`` is the scramp
    # helper function, so the stdlib module must not shadow it.
    from hmac import compare_digest

    client_signature = hmac(hf, stored_key, auth_msg)
    try:
        client_proof = b64dec(proof)
    except ValueError as e:
        # assumes b64dec raises binascii.Error (a ValueError) on bad input —
        # confirm in scramp.utils; other errors propagate as before.
        raise ScramException(
            "The client proof is not valid base64.", SERVER_ERROR_INVALID_ENCODING
        ) from e
    client_key = xor(client_signature, client_proof)
    key = h(hf, client_key)

    # The proof is attacker-controlled: compare in constant time so response
    # timing does not reveal how much of the key matched.
    if not compare_digest(key, stored_key):
        raise ScramException("The client keys don't match.", SERVER_ERROR_INVALID_PROOF)
389
390
def _make_gs2_header(channel_binding):
    """Return the GS2 header.

    Returns "p=<type>,," when channel binding is used, otherwise "n,,". No
    authorization identity is ever included.
    """
    if channel_binding is not None:
        cb_type, _cb_data = channel_binding
        return f"p={cb_type},,"
    return "n,,"
397
398
def _make_cbind_input(channel_binding):
    """Return the cbind-input bytes: the ASCII GS2 header plus any binding data."""
    payload = _make_gs2_header(channel_binding).encode("ascii")
    if channel_binding is not None:
        _cb_type, cb_data = channel_binding
        payload = payload + cb_data
    return payload
406
407
def _parse_message(msg):
    """Split a SCRAM message into ``{attribute: value}``.

    Each comma-separated item is read as a one-character attribute name, one
    separator character (normally '=') and the value. Items shorter than two
    characters are ignored. A repeated attribute keeps its last value.
    """
    return {item[0]: item[2:] for item in msg.split(",") if len(item) > 1}


def _escape_saslname(name):
    """Escape a username for the "n=" attribute (RFC 5802, section 5.1).

    ',' becomes '=2C' and '=' becomes '=3D'. '=' is replaced first so the '='
    introduced by '=2C' is not escaped a second time.
    """
    return name.replace("=", "=3D").replace(",", "=2C")


def _unescape_saslname(name):
    """Reverse _escape_saslname on a received "n=" value.

    '=2C' is decoded before '=3D'. Decoding '=2C' only inserts ',', so it
    cannot create a new '=3D'. The opposite order would wrongly turn '=3D2C'
    (an escaped literal "=2C") into ','. Any other '=' is left untouched so
    that usernames from clients that never escaped still work (RFC 5802 would
    have the server reject them instead).
    """
    return name.replace("=2C", ",").replace("=3D", "=")


def _get_client_first(username, c_nonce, channel_binding):
    """Build the client-first message.

    Returns:
        ``(client_first_bare, client_first)``. The bare part is
        "n=<escaped saslprep(username)>,r=<c_nonce>", and the full message
        prefixes the GS2 header.

    Raises:
        ScramException: with SERVER_ERROR_INVALID_USERNAME_ENCODING if
            SASLprep rejects the username.
    """
    try:
        u = saslprep(username)
    except ScramException as e:
        raise ScramException(e.args[0], SERVER_ERROR_INVALID_USERNAME_ENCODING) from e

    # ',' and '=' must be escaped or they would corrupt the message syntax.
    bare = ",".join((f"n={_escape_saslname(u)}", f"r={c_nonce}"))
    gs2_header = _make_gs2_header(channel_binding)
    return bare, gs2_header + bare


def _set_client_first(client_first, s_nonce, channel_binding):
    """Server side: parse and check the client-first message.

    Args:
        client_first: the message received from the client.
        s_nonce: the server nonce to append to the client nonce.
        channel_binding: the server's channel binding, None or ``(type, data)``.

    Returns:
        ``(nonce, user, client_first_bare)``, where nonce is the client nonce
        followed by *s_nonce* and user is the unescaped username.

    Raises:
        ScramException: always with a server_error value, so a decorated
            ScramServer method can turn the failure into an "e=" reply.
    """
    try:
        first_comma = client_first.index(",")
        second_comma = client_first.index(",", first_comma + 1)
    except ValueError as e:
        raise ScramException(
            "The client-first message does not contain a GS2 header.",
            SERVER_ERROR_OTHER_ERROR,
        ) from e

    # The GS2 header is "<cbind-flag>,[authzid],"; only the flag is used.
    gs2_cbind_flag = client_first[:first_comma]
    if not gs2_cbind_flag:
        raise ScramException(
            "The client-first message has an empty GS2 channel binding flag.",
            SERVER_ERROR_OTHER_ERROR,
        )
    gs2_char = gs2_cbind_flag[0]

    if gs2_char == "y":
        # 'y' with a binding-capable server suggests a downgrade attempt.
        if channel_binding is not None:
            raise ScramException(
                "Received GS2 flag 'y' which indicates that the client "
                "doesn't think the server supports channel binding, but in "
                "fact it does.",
                SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING,
            )

    elif gs2_char == "n":
        if channel_binding is not None:
            raise ScramException(
                "Received GS2 flag 'n' which indicates that the client "
                "doesn't require channel binding, but the server does.",
                SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING,
            )

    elif gs2_char == "p":
        if channel_binding is None:
            raise ScramException(
                "Received GS2 flag 'p' which indicates that the client "
                "requires channel binding, but the server does not.",
                SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED,
            )

        channel_type, _ = channel_binding
        cb_name = gs2_cbind_flag.split("=")[-1]
        if cb_name != channel_type:
            raise ScramException(
                f"Received channel binding name {cb_name} but this server "
                f"supports the channel binding name {channel_type}.",
                SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE,
            )

    else:
        raise ScramException(
            f"Received GS2 flag {gs2_char} which isn't recognized.",
            SERVER_ERROR_OTHER_ERROR,
        )

    client_first_bare = client_first[second_comma + 1 :]
    msg = _parse_message(client_first_bare)
    try:
        c_nonce = msg["r"]
        user = msg["n"]
    except KeyError as e:
        raise ScramException(
            f"The client-first message is missing the attribute {e.args[0]}.",
            SERVER_ERROR_OTHER_ERROR,
        ) from e
    nonce = c_nonce + s_nonce

    return nonce, _unescape_saslname(user), client_first_bare
477
478
def _get_server_first(nonce, salt, iterations, client_first_bare, channel_binding):
    """Build the server-first message and the AuthMessage.

    Returns:
        ``(auth_message, server_first)`` where server_first is
        "r=<nonce>,s=<salt>,i=<iterations>".
    """
    server_first = f"r={nonce},s={salt},i={iterations}"
    auth_message = _make_auth_message(
        nonce, client_first_bare, server_first, channel_binding
    )
    return auth_message, server_first
483
484
def _set_server_first(server_first, c_nonce, client_first_bare, channel_binding):
    """Client side: consume the server-first message.

    Returns:
        ``(auth_message, nonce, salt, iterations)``. salt is still the base64
        text of the "s=" attribute; iterations is an int.

    Raises:
        ScramException: if the server sent an error ("e=") or the combined
            nonce does not start with the client nonce.
    """
    fields = _parse_message(server_first)
    if "e" in fields:
        raise ScramException(f"The server returned the error: {fields['e']}")

    nonce = fields["r"]
    salt = fields["s"]
    iterations = int(fields["i"])

    # The server must extend, not replace, the nonce we sent.
    if not nonce.startswith(c_nonce):
        raise ScramException("Client nonce doesn't match.", SERVER_ERROR_OTHER_ERROR)

    auth_message = _make_auth_message(
        nonce, client_first_bare, server_first, channel_binding
    )
    return auth_message, nonce, salt, iterations
500
501
def _get_client_final(
    hf, password, salt_str, iterations, nonce, auth_msg_str, cbind_data
):
    """Compute the client-final message and the expected server signature.

    Returns:
        ``(server_signature_b64, client_final)`` where client_final is
        "c=<cbind-input b64>,r=<nonce>,p=<client proof b64>".
    """
    salted = _make_salted_password(hf, password, b64dec(salt_str), iterations)
    client_key, stored_key, server_key = _c_key_stored_key_s_key(hf, salted)

    auth_msg = uenc(auth_msg_str)
    proof = xor(client_key, hmac(hf, stored_key, auth_msg))
    expected_server_sig = hmac(hf, server_key, auth_msg)

    channel = "c=" + b64enc(_make_cbind_input(cbind_data))
    client_final = ",".join((channel, "r=" + nonce, "p=" + b64enc(proof)))
    return b64enc(expected_server_sig), client_final
517
518
# server-error-value strings (RFC 5802, section 7). _get_server_final sends
# the recorded one to the client as "e=<value>".
SERVER_ERROR_INVALID_ENCODING = "invalid-encoding"
SERVER_ERROR_EXTENSIONS_NOT_SUPPORTED = "extensions-not-supported"
SERVER_ERROR_INVALID_PROOF = "invalid-proof"
SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH = "channel-bindings-dont-match"
SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING = "server-does-support-channel-binding"
# NOTE(review): not one of the RFC 5802 server-error values and not used in the
# visible part of this module; kept unchanged in case external code imports it.
SERVER_ERROR_SERVER_DOES_NOT_SUPPORT_CHANNEL_BINDING = (
    "server does not support channel binding"
)
SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED = "channel-binding-not-supported"
SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE = "unsupported-channel-binding-type"
SERVER_ERROR_UNKNOWN_USER = "unknown-user"
SERVER_ERROR_INVALID_USERNAME_ENCODING = "invalid-username-encoding"
SERVER_ERROR_NO_RESOURCES = "no-resources"
SERVER_ERROR_OTHER_ERROR = "other-error"
534
535
def _set_client_final(
    hf, client_final, s_nonce, stored_key, server_key, auth_msg_str, cbind_data
):
    """Server side: verify the client-final message.

    Args:
        hf: hashlib constructor for the mechanism.
        client_final: the received message "c=...,r=...,p=...".
        s_nonce: the server nonce, which must end the received nonce.
        stored_key: the user's StoredKey.
        server_key: the user's ServerKey.
        auth_msg_str: the AuthMessage built for the server-first message.
        cbind_data: the server's channel binding, None or ``(type, data)``.

    Returns:
        The base64-encoded ServerSignature.

    Raises:
        ScramException: always with a server_error value. This covers a
            missing attribute, undecodable channel-binding data, a channel
            binding or nonce mismatch, and an invalid proof.
    """
    auth_msg = uenc(auth_msg_str)

    msg = _parse_message(client_final)
    # Report missing attributes as SCRAM errors rather than a bare KeyError, so
    # a set_error-decorated caller records an error instead of leaving the
    # exchange half-advanced with no error to report.
    missing = [attr for attr in ("c", "r", "p") if attr not in msg]
    if missing:
        raise ScramException(
            "The client-final message is missing the attribute(s) "
            f"{', '.join(missing)}.",
            SERVER_ERROR_OTHER_ERROR,
        )
    nonce = msg["r"]
    proof = msg["p"]
    channel_binding = msg["c"]

    try:
        received_cbind = b64dec(channel_binding)
    except ValueError as e:
        # assumes b64dec raises binascii.Error (a ValueError) on bad input —
        # confirm in scramp.utils; other errors propagate as before.
        raise ScramException(
            "The channel binding is not valid base64.",
            SERVER_ERROR_INVALID_ENCODING,
        ) from e
    if received_cbind != _make_cbind_input(cbind_data):
        raise ScramException(
            "The channel bindings don't match.",
            SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH,
        )

    if not nonce.endswith(s_nonce):
        raise ScramException("Server nonce doesn't match.", SERVER_ERROR_OTHER_ERROR)

    _check_client_key(hf, stored_key, auth_msg, proof)

    sig = hmac(hf, server_key, auth_msg)
    return b64enc(sig)
558
559
def _get_server_final(server_signature, error):
    """Return "e=<error>" if an error was recorded, otherwise "v=<server_signature>"."""
    if error is not None:
        return f"e={error}"
    return f"v={server_signature}"
562
563
def _set_server_final(message, server_signature):
    """Client side: verify the server-final message.

    Args:
        message: the received server-final message ("v=..." or "e=...").
        server_signature: the base64 ServerSignature the client expects.

    Raises:
        ScramException: if the server reported an error, sent no verifier,
            or sent a signature that does not match.
    """
    msg = _parse_message(message)
    if "e" in msg:
        raise ScramException(f"The server returned the error: {msg['e']}")

    # Without this check a malformed reply would surface as a KeyError.
    if "v" not in msg:
        raise ScramException(
            "The server-final message contains no verifier (v) attribute.",
            SERVER_ERROR_OTHER_ERROR,
        )

    if server_signature != msg["v"]:
        raise ScramException(
            "The server signature doesn't match.", SERVER_ERROR_OTHER_ERROR
        )
573
574
def saslprep(source):
    """Prepare a string with the SASLprep profile of stringprep (RFC 4013).

    This module applies it to passwords before hashing and to client
    usernames.

    Args:
        source: the string to prepare.

    Returns:
        The prepared string, or "" if nothing is left after mapping and
        NFKC normalization.

    Raises:
        ScramException: with server_error SERVER_ERROR_INVALID_ENCODING if the
            result has an invalid bidirectional sequence or contains a
            prohibited character. The first failing check, in the order
            listed in the loop below, decides the message.
    """
    # mapping stage
    #   - map non-ascii spaces to U+0020 (stringprep C.1.2)
    #   - strip 'commonly mapped to nothing' chars (stringprep B.1)
    data = "".join(" " if in_table_c12(c) else c for c in source if not in_table_b1(c))

    # normalize to KC form
    data = unicodedata.normalize("NFKC", data)
    # Empty result: return early, which also keeps data[0] below from failing.
    if not data:
        return ""

    # check for invalid bi-directional strings.
    # stringprep requires the following:
    #   - chars in C.8 must be prohibited.
    #   - if any R/AL chars in string:
    #       - no L chars allowed in string
    #       - first and last must be R/AL chars
    # this checks if start/end are R/AL chars. if so, prohibited loop
    # will forbid all L chars. if not, prohibited loop will forbid all
    # R/AL chars instead. in both cases, prohibited loop takes care of C.8.
    is_ral_char = in_table_d1
    if is_ral_char(data[0]):
        if not is_ral_char(data[-1]):
            raise ScramException(
                "malformed bidi sequence", SERVER_ERROR_INVALID_ENCODING
            )
        # forbid L chars within R/AL sequence.
        is_forbidden_bidi_char = in_table_d2
    else:
        # forbid R/AL chars if start not setup correctly; L chars allowed.
        is_forbidden_bidi_char = is_ral_char

    # check for prohibited output
    # stringprep tables A.1, B.1, C.1.2, C.2 - C.9
    for c in data:
        # check for chars mapping stage should have removed
        # (internal invariants, not input validation; skipped under python -O)
        assert not in_table_b1(c), "failed to strip B.1 in mapping stage"
        assert not in_table_c12(c), "failed to replace C.1.2 in mapping stage"

        # check for forbidden chars; the first matching table determines
        # the error message
        for f, msg in (
            (in_table_a1, "unassigned code points forbidden"),
            (in_table_c21_c22, "control characters forbidden"),
            (in_table_c3, "private use characters forbidden"),
            (in_table_c4, "non-char code points forbidden"),
            (in_table_c5, "surrogate codes forbidden"),
            (in_table_c6, "non-plaintext chars forbidden"),
            (in_table_c7, "non-canonical chars forbidden"),
            (in_table_c8, "display-modifying/deprecated chars forbidden"),
            (in_table_c9, "tagged characters forbidden"),
            (is_forbidden_bidi_char, "forbidden bidi character"),
        ):
            if f(c):
                raise ScramException(msg, SERVER_ERROR_INVALID_ENCODING)

    return data
631