import hashlib
import unicodedata
from enum import IntEnum, unique
from functools import wraps
from hmac import compare_digest
from operator import attrgetter
from os import urandom
from stringprep import (
    in_table_a1,
    in_table_b1,
    in_table_c12,
    in_table_c21_c22,
    in_table_c3,
    in_table_c4,
    in_table_c5,
    in_table_c6,
    in_table_c7,
    in_table_c8,
    in_table_c9,
    in_table_d1,
    in_table_d2,
)
from uuid import uuid4

from asn1crypto.x509 import Certificate

from scramp.utils import b64dec, b64enc, h, hi, hmac, uenc, xor

# https://tools.ietf.org/html/rfc5802
# https://www.rfc-editor.org/rfc/rfc7677.txt


@unique
class ClientStage(IntEnum):
    """The ordered sequence of methods a ScramClient must be driven through."""

    get_client_first = 1
    set_server_first = 2
    get_client_final = 3
    set_server_final = 4


@unique
class ServerStage(IntEnum):
    """The ordered sequence of methods a ScramServer must be driven through."""

    set_client_first = 1
    get_server_first = 2
    set_client_final = 3
    get_server_final = 4


def _check_stage(Stages, current_stage, next_stage):
    """Raise ScramException unless ``next_stage`` directly follows ``current_stage``.

    :param Stages: the stage enum (ClientStage or ServerStage), used for messages.
    :param current_stage: the last completed stage, or None if none has run yet.
    :param next_stage: the stage the caller is about to enter.
    """
    if current_stage is None:
        if next_stage != 1:
            raise ScramException(f"The method {Stages(1).name} must be called first.")
    elif current_stage == 4:
        raise ScramException("The authentication sequence has already finished.")
    elif next_stage != current_stage + 1:
        raise ScramException(
            f"The next method to be called is "
            f"{Stages(current_stage + 1).name}, not this method."
        )


class ScramException(Exception):
    """Error raised by this module.

    ``server_error`` optionally carries an RFC 5802 ``server-error`` value that a
    ScramServer will report back to the client in its final message.
    """

    def __init__(self, message, server_error=None):
        super().__init__(message)
        self.server_error = server_error

    def __str__(self):
        s_str = "" if self.server_error is None else f" {self.server_error}"
        return super().__str__() + s_str


MECHANISMS = (
    "SCRAM-SHA-1",
    "SCRAM-SHA-1-PLUS",
    "SCRAM-SHA-256",
    "SCRAM-SHA-256-PLUS",
    "SCRAM-SHA-512",
    "SCRAM-SHA-512-PLUS",
    "SCRAM-SHA3-512",
    "SCRAM-SHA3-512-PLUS",
)


CHANNEL_TYPES = (
    "tls-server-end-point",
    "tls-unique",
    "tls-unique-for-telnet",
)


def make_channel_binding(name, ssl_socket):
    """Build a ``(name, data)`` channel-binding tuple from an SSL socket.

    :param name: "tls-unique" or "tls-server-end-point".
    :param ssl_socket: the connected SSL socket to derive the binding from.
    :raises ScramException: for an unrecognized name, or when the certificate's
        hash algorithm is not available in hashlib.
    """
    if name == "tls-unique":
        return (name, ssl_socket.get_channel_binding(name))
    elif name == "tls-server-end-point":
        cert_bin = ssl_socket.getpeercert(binary_form=True)
        cert = Certificate.load(cert_bin)

        # Find the hash algorithm to use according to
        # https://tools.ietf.org/html/rfc5929#section-4
        hash_algo = cert.hash_algo
        if hash_algo in ("md5", "sha1"):
            hash_algo = "sha256"

        try:
            hash_obj = hashlib.new(hash_algo)
        except ValueError as e:
            raise ScramException(
                f"Hash algorithm {hash_algo} not supported by hashlib. {e}"
            ) from e
        hash_obj.update(cert_bin)
        return ("tls-server-end-point", hash_obj.digest())
    else:
        raise ScramException(f"Channel binding name {name} not recognized.")


class ScramMechanism:
    """One SCRAM mechanism: hash function, binding flag, iterations and strength.

    ``MECH_LOOKUP`` maps a mechanism name to
    ``(hash_function, uses_channel_binding, default_iteration_count, strength)``;
    ``strength`` is used by ScramClient to pick the strongest offered mechanism.
    """

    MECH_LOOKUP = {
        "SCRAM-SHA-1": (hashlib.sha1, False, 4096, 0),
        "SCRAM-SHA-1-PLUS": (hashlib.sha1, True, 4096, 1),
        "SCRAM-SHA-256": (hashlib.sha256, False, 4096, 2),
        "SCRAM-SHA-256-PLUS": (hashlib.sha256, True, 4096, 3),
        "SCRAM-SHA-512": (hashlib.sha512, False, 4096, 4),
        "SCRAM-SHA-512-PLUS": (hashlib.sha512, True, 4096, 5),
        "SCRAM-SHA3-512": (hashlib.sha3_512, False, 10000, 6),
        "SCRAM-SHA3-512-PLUS": (hashlib.sha3_512, True, 10000, 7),
    }

    def __init__(self, mechanism="SCRAM-SHA-256"):
        if mechanism not in MECHANISMS:
            raise ScramException(
                f"The mechanism name '{mechanism}' is not supported. The "
                f"supported mechanisms are {MECHANISMS}."
            )
        self.name = mechanism
        (
            self.hf,
            self.use_binding,
            self.iteration_count,
            self.strength,
        ) = self.MECH_LOOKUP[mechanism]

    def make_auth_info(self, password, iteration_count=None, salt=None):
        """Return ``(salt, stored_key, server_key, iteration_count)`` for a password.

        ``iteration_count`` defaults to the mechanism's default; ``salt``
        defaults to 16 random bytes.
        """
        if iteration_count is None:
            iteration_count = self.iteration_count
        salt, stored_key, server_key = _make_auth_info(
            self.hf, password, iteration_count, salt=salt
        )
        return salt, stored_key, server_key, iteration_count

    def make_stored_server_keys(self, salted_password):
        """Return ``(stored_key, server_key)`` derived from a salted password."""
        _, stored_key, server_key = _c_key_stored_key_s_key(self.hf, salted_password)
        return stored_key, server_key

    def make_server(self, auth_fn, channel_binding=None, s_nonce=None):
        """Create a ScramServer that uses this mechanism."""
        return ScramServer(
            self, auth_fn, channel_binding=channel_binding, s_nonce=s_nonce
        )


def _make_auth_info(hf, password, i, salt=None):
    """Return ``(salt, stored_key, server_key)``; ``salt`` defaults to 16 random bytes."""
    if salt is None:
        salt = urandom(16)

    salted_password = _make_salted_password(hf, password, salt, i)
    _, stored_key, server_key = _c_key_stored_key_s_key(hf, salted_password)
    return salt, stored_key, server_key


def _validate_channel_binding(channel_binding):
    """Raise ScramException unless ``channel_binding`` is None or a valid tuple.

    A valid tuple is ``(channel_type, data)`` where ``channel_type`` is one of
    CHANNEL_TYPES and ``data`` is a bytes object.
    """
    if channel_binding is None:
        return

    if not isinstance(channel_binding, tuple):
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple."
        )

    if len(channel_binding) != 2:
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple of "
            "two elements (type, data)."
        )

    channel_type, channel_data = channel_binding
    if channel_type not in CHANNEL_TYPES:
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple "
            "with the first element a str specifying one of the channel "
            f"types {CHANNEL_TYPES}."
        )

    if not isinstance(channel_data, bytes):
        raise ScramException(
            "The channel_binding parameter must either be None or a tuple "
            "with the second element a bytes object containing the bind data."
        )


class ScramClient:
    """Client side of a SCRAM exchange.

    Picks the strongest usable mechanism from ``mechanisms`` (``-PLUS``
    mechanisms are only usable when ``channel_binding`` is given) and must then
    be driven through the ClientStage methods in order.
    """

    def __init__(
        self, mechanisms, username, password, channel_binding=None, c_nonce=None
    ):
        if not isinstance(mechanisms, (list, tuple)):
            raise ScramException(
                "The 'mechanisms' parameter must be a list or tuple of "
                "mechanism names."
            )

        _validate_channel_binding(channel_binding)

        # Every name is validated (ScramMechanism raises on unknown names)
        # before binding-requiring mechanisms are filtered out.
        mechs = [ScramMechanism(m) for m in mechanisms]
        mechs = [m for m in mechs if channel_binding is not None or not m.use_binding]
        if not mechs:
            raise ScramException("There are no suitable mechanisms in the list.")

        mech = max(mechs, key=attrgetter("strength"))
        self.hf, self.use_binding = mech.hf, mech.use_binding
        self.mechanism_name = mech.name

        if self.use_binding:
            if channel_binding is None:
                raise ScramException(
                    "The channel_binding parameter can't be None if channel "
                    "binding is required."
                )
        else:
            if channel_binding is not None:
                raise ScramException(
                    "The channel_binding parameter must be None if channel "
                    "binding is not required."
                )

        self.c_nonce = _make_nonce() if c_nonce is None else c_nonce
        self.username = username
        self.password = password
        self.channel_binding = channel_binding
        self.stage = None

    def _set_stage(self, next_stage):
        _check_stage(ClientStage, self.stage, next_stage)
        self.stage = next_stage

    def get_client_first(self):
        """Return the client-first message to send to the server."""
        self._set_stage(ClientStage.get_client_first)
        self.client_first_bare, client_first = _get_client_first(
            self.username, self.c_nonce, self.channel_binding
        )
        return client_first

    def set_server_first(self, message):
        """Process the server-first message received from the server."""
        self._set_stage(ClientStage.set_server_first)
        self.server_first = message
        self.auth_message, self.nonce, self.salt, self.iterations = _set_server_first(
            message, self.c_nonce, self.client_first_bare, self.channel_binding
        )

    def get_client_final(self):
        """Return the client-final message (including the client proof)."""
        self._set_stage(ClientStage.get_client_final)
        self.server_signature, cfinal = _get_client_final(
            self.hf,
            self.password,
            self.salt,
            self.iterations,
            self.nonce,
            self.auth_message,
            self.channel_binding,
        )
        return cfinal

    def set_server_final(self, message):
        """Verify the server-final message; raise ScramException on failure."""
        self._set_stage(ClientStage.set_server_final)
        _set_server_final(message, self.server_signature)


def set_error(f):
    """Decorate a ScramServer method so SCRAM errors become reportable.

    When the wrapped method raises a ScramException carrying a ``server_error``,
    that value is stored on ``self.error`` and the stage jumps to
    ``set_client_final`` so that ``get_server_final`` can be called next to
    produce the ``e=`` message. The exception is always re-raised.
    """

    @wraps(f)
    def wrapper(self, *args, **kwds):
        try:
            return f(self, *args, **kwds)
        except ScramException as e:
            if e.server_error is not None:
                self.error = e.server_error
                self.stage = ServerStage.set_client_final
            raise

    return wrapper


class ScramServer:
    """Server side of a SCRAM exchange for a single ScramMechanism.

    ``auth_fn(username)`` must return ``(salt, stored_key, server_key,
    iteration_count)`` for the user, e.g. as produced by
    ``ScramMechanism.make_auth_info``.
    """

    def __init__(self, mechanism, auth_fn, channel_binding=None, s_nonce=None):
        self.m = mechanism

        _validate_channel_binding(channel_binding)

        if mechanism.use_binding:
            if channel_binding is None:
                raise ScramException(
                    "The mechanism requires channel binding, and so "
                    "channel_binding can't be None."
                )
        else:
            if channel_binding is not None:
                raise ScramException(
                    "The mechanism does not support channel binding, and so "
                    "channel_binding must be None."
                )

        self.channel_binding = channel_binding

        self.s_nonce = _make_nonce() if s_nonce is None else s_nonce
        self.auth_fn = auth_fn
        self.stage = None
        self.server_signature = None
        self.error = None

    def _set_stage(self, next_stage):
        _check_stage(ServerStage, self.stage, next_stage)
        self.stage = next_stage

    @set_error
    def set_client_first(self, client_first):
        """Process the client-first message and look up the user's credentials."""
        self._set_stage(ServerStage.set_client_first)
        self.nonce, self.user, self.client_first_bare = _set_client_first(
            client_first, self.s_nonce, self.channel_binding
        )
        salt, self.stored_key, self.server_key, self.i = self.auth_fn(self.user)
        self.salt = b64enc(salt)

    @set_error
    def get_server_first(self):
        """Return the server-first message to send to the client."""
        self._set_stage(ServerStage.get_server_first)
        self.auth_message, server_first = _get_server_first(
            self.nonce, self.salt, self.i, self.client_first_bare, self.channel_binding
        )
        return server_first

    @set_error
    def set_client_final(self, client_final):
        """Verify the client-final message (channel binding, nonce and proof)."""
        self._set_stage(ServerStage.set_client_final)
        self.server_signature = _set_client_final(
            self.m.hf,
            client_final,
            self.s_nonce,
            self.stored_key,
            self.server_key,
            self.auth_message,
            self.channel_binding,
        )

    @set_error
    def get_server_final(self):
        """Return ``v=<signature>`` on success, or ``e=<error>`` after a failure."""
        self._set_stage(ServerStage.get_server_final)
        return _get_server_final(self.server_signature, self.error)


def _make_nonce():
    """Return a random nonce: a UUID4 rendered as 32 hex characters."""
    return str(uuid4()).replace("-", "")


def _make_auth_message(nonce, client_first_bare, server_first, cbind_data):
    """Build the AuthMessage string that both proofs/signatures are computed over."""
    cbind_input = b64enc(_make_cbind_input(cbind_data))
    msg = client_first_bare, server_first, "c=" + cbind_input, "r=" + nonce
    return ",".join(msg)


def _make_salted_password(hf, password, salt, iterations):
    """Return Hi(SASLprep(password), salt, iterations) using hash function ``hf``."""
    return hi(hf, uenc(saslprep(password)), salt, iterations)


def _c_key_stored_key_s_key(hf, salted_password):
    """Return ``(client_key, stored_key, server_key)`` for a salted password."""
    client_key = hmac(hf, salted_password, b"Client Key")
    stored_key = h(hf, client_key)
    server_key = hmac(hf, salted_password, b"Server Key")

    return client_key, stored_key, server_key


def _check_client_key(hf, stored_key, auth_msg, proof):
    """Raise ScramException unless the base64 ``proof`` matches ``stored_key``."""
    client_signature = hmac(hf, stored_key, auth_msg)
    client_key = xor(client_signature, b64dec(proof))
    key = h(hf, client_key)

    # Constant-time comparison so the check doesn't leak how many bytes matched.
    if not compare_digest(key, stored_key):
        raise ScramException("The client keys don't match.", SERVER_ERROR_INVALID_PROOF)


def _make_gs2_header(channel_binding):
    """Return the GS2 header: ``n,,`` without binding, ``p=<type>,,`` with it."""
    if channel_binding is None:
        return "n,,"
    else:
        channel_type, _ = channel_binding
        return f"p={channel_type},,"


def _make_cbind_input(channel_binding):
    """Return the GS2 header bytes, followed by the binding data if any."""
    gs2_header = _make_gs2_header(channel_binding).encode("ascii")
    if channel_binding is None:
        return gs2_header
    else:
        _, cbind_data = channel_binding
        return gs2_header + cbind_data


def _parse_message(msg):
    """Parse ``k=v,k=v`` into a dict keyed by the single-char attribute name.

    Entries shorter than two characters are skipped; a repeated key keeps its
    last value.
    """
    return {e[0]: e[2:] for e in msg.split(",") if len(e) > 1}


def _escape_username(name):
    """Escape ``=`` and ``,`` in a username as ``=3D`` / ``=2C`` (RFC 5802 5.1)."""
    # '=' must be replaced first, otherwise the '=' of '=2C' would be re-escaped.
    return name.replace("=", "=3D").replace(",", "=2C")


def _unescape_username(name):
    """Reverse _escape_username.

    :raises ScramException: if an ``=`` is not followed by ``2C`` or ``3D``, which
        RFC 5802 section 5.1 requires the server to reject.
    """
    head, *rest = name.split("=")
    parts = [head]
    for part in rest:
        if part.startswith("2C"):
            parts.append("," + part[2:])
        elif part.startswith("3D"):
            parts.append("=" + part[2:])
        else:
            raise ScramException(
                "The username contains an '=' that isn't followed by '2C' or '3D'.",
                SERVER_ERROR_INVALID_USERNAME_ENCODING,
            )
    return "".join(parts)


def _get_client_first(username, c_nonce, channel_binding):
    """Return ``(client_first_bare, client_first)`` for the given user and nonce."""
    try:
        u = saslprep(username)
    except ScramException as e:
        raise ScramException(
            e.args[0], SERVER_ERROR_INVALID_USERNAME_ENCODING
        ) from e

    bare = ",".join((f"n={_escape_username(u)}", f"r={c_nonce}"))
    gs2_header = _make_gs2_header(channel_binding)
    return bare, gs2_header + bare


def _set_client_first(client_first, s_nonce, channel_binding):
    """Parse a client-first message on the server side.

    Checks the GS2 channel-binding flag against this server's configuration and
    returns ``(combined_nonce, username, client_first_bare)``.
    """
    first_comma = client_first.index(",")
    second_comma = client_first.index(",", first_comma + 1)
    gs2_header = client_first[:second_comma].split(",")
    gs2_cbind_flag = gs2_header[0]
    gs2_char = gs2_cbind_flag[0]

    if gs2_char == "y":
        if channel_binding is not None:
            raise ScramException(
                "Received GS2 flag 'y' which indicates that the client "
                "doesn't think the server supports channel binding, but in "
                "fact it does.",
                SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING,
            )

    elif gs2_char == "n":
        if channel_binding is not None:
            raise ScramException(
                "Received GS2 flag 'n' which indicates that the client "
                "doesn't require channel binding, but the server does.",
                SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING,
            )

    elif gs2_char == "p":
        if channel_binding is None:
            raise ScramException(
                "Received GS2 flag 'p' which indicates that the client "
                "requires channel binding, but the server does not.",
                SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED,
            )

        channel_type, _ = channel_binding
        cb_name = gs2_cbind_flag.split("=")[-1]
        if cb_name != channel_type:
            raise ScramException(
                f"Received channel binding name {cb_name} but this server "
                f"supports the channel binding name {channel_type}.",
                SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE,
            )

    else:
        raise ScramException(
            f"Received GS2 flag {gs2_char} which isn't recognized.",
            SERVER_ERROR_OTHER_ERROR,
        )

    client_first_bare = client_first[second_comma + 1 :]
    msg = _parse_message(client_first_bare)
    c_nonce = msg["r"]
    nonce = c_nonce + s_nonce
    user = _unescape_username(msg["n"])

    return nonce, user, client_first_bare


def _get_server_first(nonce, salt, iterations, client_first_bare, channel_binding):
    """Return ``(auth_message, server_first)``; ``salt`` is already base64 text."""
    sfirst = ",".join((f"r={nonce}", f"s={salt}", f"i={iterations}"))
    auth_msg = _make_auth_message(nonce, client_first_bare, sfirst, channel_binding)
    return auth_msg, sfirst


def _set_server_first(server_first, c_nonce, client_first_bare, channel_binding):
    """Parse a server-first message on the client side.

    Returns ``(auth_message, nonce, salt_b64, iterations)``.

    :raises ScramException: if the server sent ``e=`` or the nonce doesn't start
        with the client's nonce.
    """
    msg = _parse_message(server_first)
    if "e" in msg:
        raise ScramException(f"The server returned the error: {msg['e']}")
    nonce = msg["r"]
    salt = msg["s"]
    iterations = int(msg["i"])

    if not nonce.startswith(c_nonce):
        raise ScramException("Client nonce doesn't match.", SERVER_ERROR_OTHER_ERROR)

    auth_msg = _make_auth_message(
        nonce, client_first_bare, server_first, channel_binding
    )
    return auth_msg, nonce, salt, iterations


def _get_client_final(
    hf, password, salt_str, iterations, nonce, auth_msg_str, cbind_data
):
    """Return ``(expected_server_signature_b64, client_final_message)``."""
    salt = b64dec(salt_str)
    salted_password = _make_salted_password(hf, password, salt, iterations)
    client_key, stored_key, server_key = _c_key_stored_key_s_key(hf, salted_password)

    auth_msg = uenc(auth_msg_str)

    client_signature = hmac(hf, stored_key, auth_msg)
    client_proof = xor(client_key, client_signature)
    server_signature = hmac(hf, server_key, auth_msg)
    cbind_input = _make_cbind_input(cbind_data)
    msg = ["c=" + b64enc(cbind_input), "r=" + nonce, "p=" + b64enc(client_proof)]
    return b64enc(server_signature), ",".join(msg)


# server-error values reported in the server-final message (RFC 5802 section 7).
SERVER_ERROR_INVALID_ENCODING = "invalid-encoding"
SERVER_ERROR_EXTENSIONS_NOT_SUPPORTED = "extensions-not-supported"
SERVER_ERROR_INVALID_PROOF = "invalid-proof"
SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH = "channel-bindings-dont-match"
SERVER_ERROR_SERVER_DOES_SUPPORT_CHANNEL_BINDING = "server-does-support-channel-binding"
SERVER_ERROR_SERVER_DOES_NOT_SUPPORT_CHANNEL_BINDING = (
    "server does not support channel binding"
)
SERVER_ERROR_CHANNEL_BINDING_NOT_SUPPORTED = "channel-binding-not-supported"
SERVER_ERROR_UNSUPPORTED_CHANNEL_BINDING_TYPE = "unsupported-channel-binding-type"
SERVER_ERROR_UNKNOWN_USER = "unknown-user"
SERVER_ERROR_INVALID_USERNAME_ENCODING = "invalid-username-encoding"
SERVER_ERROR_NO_RESOURCES = "no-resources"
SERVER_ERROR_OTHER_ERROR = "other-error"


def _set_client_final(
    hf, client_final, s_nonce, stored_key, server_key, auth_msg_str, cbind_data
):
    """Verify a client-final message on the server side.

    Checks the channel binding, that the nonce ends with the server nonce, and
    the client proof. Returns the base64 server signature.
    """
    auth_msg = uenc(auth_msg_str)

    msg = _parse_message(client_final)
    nonce = msg["r"]
    proof = msg["p"]
    channel_binding = msg["c"]
    if not b64dec(channel_binding) == _make_cbind_input(cbind_data):
        raise ScramException(
            "The channel bindings don't match.",
            SERVER_ERROR_CHANNEL_BINDINGS_DONT_MATCH,
        )

    if not nonce.endswith(s_nonce):
        raise ScramException("Server nonce doesn't match.", SERVER_ERROR_OTHER_ERROR)

    _check_client_key(hf, stored_key, auth_msg, proof)

    sig = hmac(hf, server_key, auth_msg)
    return b64enc(sig)


def _get_server_final(server_signature, error):
    """Return ``v=<signature>`` when there is no error, otherwise ``e=<error>``."""
    return f"v={server_signature}" if error is None else f"e={error}"


def _set_server_final(message, server_signature):
    """Verify a server-final message on the client side.

    :raises ScramException: if the server sent ``e=`` or its signature differs
        from the expected one.
    """
    msg = _parse_message(message)
    if "e" in msg:
        raise ScramException(f"The server returned the error: {msg['e']}")

    # Compare as bytes in constant time; compare_digest rejects non-ASCII str,
    # and the received value comes from the peer.
    if not compare_digest(
        server_signature.encode("utf-8"), msg["v"].encode("utf-8")
    ):
        raise ScramException(
            "The server signature doesn't match.", SERVER_ERROR_OTHER_ERROR
        )


def saslprep(source):
    """Apply the SASLprep profile (RFC 4013) of stringprep to ``source``.

    :raises ScramException: with SERVER_ERROR_INVALID_ENCODING if the result
        contains prohibited characters or a malformed bidi sequence.
    """
    # mapping stage
    #   - map non-ascii spaces to U+0020 (stringprep C.1.2)
    #   - strip 'commonly mapped to nothing' chars (stringprep B.1)
    data = "".join(" " if in_table_c12(c) else c for c in source if not in_table_b1(c))

    # normalize to KC form
    data = unicodedata.normalize("NFKC", data)
    if not data:
        return ""

    # check for invalid bi-directional strings.
    # stringprep requires the following:
    #   - chars in C.8 must be prohibited.
    #   - if any R/AL chars in string:
    #       - no L chars allowed in string
    #       - first and last must be R/AL chars
    # this checks if start/end are R/AL chars. if so, prohibited loop
    # will forbid all L chars. if not, prohibited loop will forbid all
    # R/AL chars instead. in both cases, prohibited loop takes care of C.8.
    is_ral_char = in_table_d1
    if is_ral_char(data[0]):
        if not is_ral_char(data[-1]):
            raise ScramException(
                "malformed bidi sequence", SERVER_ERROR_INVALID_ENCODING
            )
        # forbid L chars within R/AL sequence.
        is_forbidden_bidi_char = in_table_d2
    else:
        # forbid R/AL chars if start not setup correctly; L chars allowed.
        is_forbidden_bidi_char = is_ral_char

    # check for prohibited output
    # stringprep tables A.1, B.1, C.1.2, C.2 - C.9
    for c in data:
        # check for chars mapping stage should have removed
        assert not in_table_b1(c), "failed to strip B.1 in mapping stage"
        assert not in_table_c12(c), "failed to replace C.1.2 in mapping stage"

        # check for forbidden chars
        for f, msg in (
            (in_table_a1, "unassigned code points forbidden"),
            (in_table_c21_c22, "control characters forbidden"),
            (in_table_c3, "private use characters forbidden"),
            (in_table_c4, "non-char code points forbidden"),
            (in_table_c5, "surrogate codes forbidden"),
            (in_table_c6, "non-plaintext chars forbidden"),
            (in_table_c7, "non-canonical chars forbidden"),
            (in_table_c8, "display-modifying/deprecated chars forbidden"),
            (in_table_c9, "tagged characters forbidden"),
            (is_forbidden_bidi_char, "forbidden bidi character"),
        ):
            if f(c):
                raise ScramException(msg, SERVER_ERROR_INVALID_ENCODING)

    return data