# pylint:disable=line-too-long
"""Parser for ssh public keys.

Supported key types: ssh-rsa, ssh-dss, ecdsa-sha* (nistp192/224/256/384/521 curves),
ssh-ed25519, and the sk-ecdsa-sha* / sk-ssh-ed25519* security-key variants.

Example (SSHKey is defined below; InvalidKeyError comes from this package's
exceptions module):

    import sys

    with open("ssh-pubkey-file.pem") as key_file:
        key_data = key_file.read()
    ssh_key = SSHKey(key_data)
    try:
        ssh_key.parse()
    except InvalidKeyError:
        print("Invalid key")
        sys.exit(1)
    print(ssh_key.bits)
"""

from .exceptions import (
    InvalidKeyError, InvalidKeyLengthError, InvalidOptionNameError, InvalidOptionsError, InvalidTypeError,
    MalformedDataError, MissingMandatoryOptionValueError, TooLongKeyError, TooShortKeyError, UnknownOptionNameError
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.dsa import DSAParameterNumbers, DSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from urllib.parse import urlparse

import base64
import binascii
import ecdsa
import hashlib
import re
import struct
import sys  # pylint:disable=unused-import
import warnings

__all__ = ["AuthorizedKeysFile", "SSHKey"]


class AuthorizedKeysFile:  # pylint:disable=too-few-public-methods
    """Represents a full authorized_keys file.

    Comments and empty lines are ignored."""

    def __init__(self, file_obj, **kwargs):
        """Parse every key line in *file_obj* into self.keys.

        file_obj: an iterable of text lines (e.g. an open file).
        kwargs: forwarded unchanged to SSHKey (e.g. strict, skip_option_parsing, disallow_options)."""
        self.keys = []
        self.parse(file_obj, **kwargs)

    def parse(self, file_obj, **kwargs):
        """Parse each non-empty, non-comment line of *file_obj* and append the keys to self.keys.

        Raises whatever SSHKey.parse() raises for the first invalid line."""
        for line in file_obj:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            # Build the key without data so the line is parsed exactly once, by the
            # call below that propagates errors (the constructor swallows them).
            ssh_key = SSHKey(**kwargs)
            ssh_key.parse(line)
            self.keys.append(ssh_key)


class SSHKey:  # pylint:disable=too-many-instance-attributes
    """Represents a single SSH public key.

    ssh_key = SSHKey(key_data, strict=True)
    ssh_key.parse()

    strict=True (default) only allows keys ssh-keygen generates and option names listed in
    OPTIONS_SPEC. Setting strict mode to false relaxes the size limits to the *_LENGTH_LOOSE
    constants below, which admits highly insecure keys (e.g. DSA keys far shorter than 1024
    bits, RSA keys down to 768 bits). OpenSSH itself accepts even weaker keys, such as 512-bit
    DSA and 64-bit RSA keys."""

    DSA_MIN_LENGTH_STRICT = 1024
    DSA_MAX_LENGTH_STRICT = 1024
    DSA_MIN_LENGTH_LOOSE = 1
    DSA_MAX_LENGTH_LOOSE = 3072

    DSA_N_LENGTH = 160

    # curve name (as found in the key blob) -> (ecdsa curve, hash function)
    ECDSA_CURVE_DATA = {
        b"nistp256": (ecdsa.curves.NIST256p, hashlib.sha256),
        b"nistp192": (ecdsa.curves.NIST192p, hashlib.sha256),
        b"nistp224": (ecdsa.curves.NIST224p, hashlib.sha256),
        b"nistp384": (ecdsa.curves.NIST384p, hashlib.sha384),
        b"nistp521": (ecdsa.curves.NIST521p, hashlib.sha512),
    }

    RSA_MIN_LENGTH_STRICT = 1024
    RSA_MAX_LENGTH_STRICT = 16384
    RSA_MIN_LENGTH_LOOSE = 768
    RSA_MAX_LENGTH_LOOSE = 16384

    # Valid as of OpenSSH_8.3
    # argument name, value is mandatory. Options are case-insensitive, but this list must be in lowercase.
    OPTIONS_SPEC = [
        ("agent-forwarding", False),
        ("cert-authority", False),
        ("command", True),
        ("environment", True),
        ("expiry-time", True),
        ("from", True),
        ("no-agent-forwarding", False),
        ("no-port-forwarding", False),
        ("no-pty", False),
        ("no-user-rc", False),
        ("no-x11-forwarding", False),
        ("permitlisten", True),
        ("permitopen", True),
        ("port-forwarding", False),
        ("principals", True),
        ("pty", False),
        ("no-touch-required", False),
        ("restrict", False),
        ("tunnel", True),
        ("user-rc", False),
        ("x11-forwarding", False),
    ]
    OPTION_NAME_RE = re.compile("^[A-Za-z0-9-]+$")

    # Size of the big-endian length prefix in front of every field of the key blob.
    INT_LEN = 4

    FIELDS = ["rsa", "dsa", "ecdsa", "bits", "comment", "options", "options_raw", "key_type"]

    def __init__(self, keydata=None, **kwargs):
        """Create a key object, optionally parsing *keydata* right away.

        Keyword arguments:
            strict (default True): enforce ssh-keygen key sizes and known option names.
            skip_option_parsing (default False): keep options only as the raw string in options_raw.
            disallow_options (default False): make parse() reject keys that carry options.

        When keydata is given it is parsed immediately, but InvalidKeyError and
        NotImplementedError are swallowed; call parse() to surface them."""
        self.keydata = keydata
        self._decoded_key = None
        self.rsa = None
        self.dsa = None
        self.ecdsa = None
        self.bits = None
        self.comment = None
        self.options = None
        self.options_raw = None
        self.key_type = None
        self.strict_mode = bool(kwargs.get("strict", True))
        self.skip_option_parsing = bool(kwargs.get("skip_option_parsing", False))
        self.disallow_options = bool(kwargs.get("disallow_options", False))
        if keydata:
            try:
                self.parse(keydata)
            except (InvalidKeyError, NotImplementedError):
                # Deliberate best effort: parse() re-raises these for callers that ask.
                pass

    def __str__(self):
        # key_type is None until a parse succeeds far enough to read it.
        key_type = self.key_type.decode() if self.key_type is not None else None
        return "Key type: %s, bits: %s, options: %s" % (key_type, self.bits, self.options)

    def reset(self):
        """Reset all data fields, and the cached decoded key blob, to None."""
        for field in self.FIELDS:
            setattr(self, field, None)
        # Not part of FIELDS, but a stale blob would make hash_*() describe the previous key.
        self._decoded_key = None

    def hash(self):
        """Calculate md5 fingerprint, without the "MD5:" prefix.

        Deprecated, use .hash_md5() instead."""
        warnings.warn("hash() is deprecated. Use hash_md5(), hash_sha256() or hash_sha512() instead.", stacklevel=2)
        # hash_md5() returns str, so the prefix must be removed with str arguments.
        return self.hash_md5().replace("MD5:", "")

    def hash_md5(self):
        """Calculate md5 fingerprint.

        Shamelessly copied from http://stackoverflow.com/questions/6682815/deriving-an-ssh-fingerprint-from-a-public-key-in-python

        For specification, see RFC4716, section 4.

        Returns a str such as "MD5:ab:cd:..."."""
        fp_plain = hashlib.md5(self._decoded_key).hexdigest()
        return "MD5:" + ':'.join(a + b for a, b in zip(fp_plain[::2], fp_plain[1::2]))

    def hash_sha256(self):
        """Calculate sha256 fingerprint ("SHA256:" + unpadded base64 digest)."""
        fp_plain = hashlib.sha256(self._decoded_key).digest()
        return (b"SHA256:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")

    def hash_sha512(self):
        """Calculate sha512 fingerprint ("SHA512:" + unpadded base64 digest)."""
        fp_plain = hashlib.sha512(self._decoded_key).digest()
        return (b"SHA512:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")

    def _unpack_by_int(self, data, current_position):
        """Read one length-prefixed field from *data* starting at *current_position*.

        The field is a 4-byte big-endian unsigned length followed by that many bytes.
        Returns a tuple (position just past the field, field contents).
        Raises MalformedDataError if the prefix or the payload is truncated."""
        # Unpack length of data field
        try:
            requested_data_length = struct.unpack('>I', data[current_position:current_position + self.INT_LEN])[0]
        except struct.error as ex:
            raise MalformedDataError("Unable to unpack %s bytes from the data" % self.INT_LEN) from ex

        # Move pointer to the beginning of the data field
        current_position += self.INT_LEN
        # The successful 4-byte unpack guarantees current_position <= len(data) here.
        remaining_data_length = len(data) - current_position

        if remaining_data_length < requested_data_length:
            raise MalformedDataError(
                "Requested %s bytes, but only %s bytes available." % (requested_data_length, remaining_data_length)
            )

        next_data = data[current_position:current_position + requested_data_length]
        # Move pointer to the end of the data field
        current_position += requested_data_length
        return current_position, next_data

    @classmethod
    def _parse_long(cls, data):
        """Interpret *data* (bytes) as an unsigned big-endian integer.

        SSH mpints are two's complement, but per RFC 4251 section 5 a positive value whose
        high bit is set gets a leading zero byte, so this unsigned read gives the right
        value for non-negative numbers. Empty input yields 0."""
        return int.from_bytes(data, "big")

    def _split_key(self, data):
        """Split an authorized_keys-style line into [key_type, base64_blob].

        If the line does not start with a known key-type prefix, everything before the
        first space outside double quotes is treated as options: stored in options_raw
        and, unless skip_option_parsing is set, parsed into options. A third
        whitespace-separated part is stored in comment."""
        options_raw = None
        # Terribly inefficient way to remove options, but hey, it works.
        if not data.startswith(("ssh-", "ecdsa-", "sk-")):
            quote_open = False
            for i, character in enumerate(data):
                if character == '"':  # only double quotes are allowed, no need to care about single quotes
                    quote_open = not quote_open
                if quote_open:
                    continue
                if character == " ":
                    # Key data begins after the first unquoted space
                    options_raw = data[:i]
                    data = data[i + 1:]
                    break
            else:
                raise MalformedDataError("Couldn't find beginning of the key data")
        key_parts = data.strip().split(None, 2)
        if len(key_parts) < 2:  # Key type and content are mandatory fields.
            raise InvalidKeyError("Unexpected key format: at least type and base64 encoded value is required")
        if len(key_parts) == 3:
            self.comment = key_parts[2]
            key_parts = key_parts[0:2]
        if options_raw:
            # Populate and parse options field.
            self.options_raw = options_raw
            if not self.skip_option_parsing:
                self.options = self.parse_options(self.options_raw)
        else:
            # Set empty defaults for fields
            self.options_raw = None
            self.options = {}
        return key_parts

    @classmethod
    def decode_key(cls, pubkey_content):
        """Decode base64 coded part of the key.

        Raises MalformedDataError if the text is not ASCII or not decodable base64."""
        try:
            decoded_key = base64.b64decode(pubkey_content.encode("ascii"))
        except (TypeError, UnicodeEncodeError, binascii.Error) as ex:
            raise MalformedDataError("Unable to decode the key") from ex
        return decoded_key

    @classmethod
    def _bits_in_number(cls, number):
        """Return the length of *number* in binary digits (1 for zero)."""
        return len(format(number, "b"))

    def parse_options(self, options):
        """Parse an authorized_keys options string into {name: [values, ...]}.

        Options are comma-separated; commas inside double quotes do not split. An option
        without "=" gets the value True; quotes are stripped from values. In strict mode
        every name must appear in OPTIONS_SPEC and mandatory values must be present.

        Raises InvalidOptionNameError, UnknownOptionNameError,
        MissingMandatoryOptionValueError, or InvalidOptionsError for unbalanced quotes."""
        quote_open = False
        parsed_options = {}

        def parse_add_single_option(opt):
            """Parses and validates a single option, and adds it to parsed_options field."""
            if "=" in opt:
                opt_name, opt_value = opt.split("=", 1)
                opt_value = opt_value.replace('"', '')
            else:
                opt_name = opt
                opt_value = True
            if " " in opt_name or not self.OPTION_NAME_RE.match(opt_name):
                raise InvalidOptionNameError("%s is not a valid option name." % opt_name)
            if self.strict_mode:
                for valid_opt_name, value_required in self.OPTIONS_SPEC:
                    if opt_name.lower() == valid_opt_name:
                        if value_required and opt_value is True:
                            raise MissingMandatoryOptionValueError("%s is missing a mandatory value." % opt_name)
                        break
                else:
                    raise UnknownOptionNameError("%s is an unrecognized option name." % opt_name)
            parsed_options.setdefault(opt_name, []).append(opt_value)

        start_of_current_opt = 0
        for i, character in enumerate(options):
            if character == '"':  # only double quotes are allowed, no need to care about single quotes
                quote_open = not quote_open
            if quote_open:
                continue
            if character == ",":
                parse_add_single_option(options[start_of_current_opt:i])
                start_of_current_opt = i + 1
        # The last option has no terminating comma, so parse the remaining segment for any
        # non-empty input regardless of its length. After a trailing comma the segment is
        # empty and is rejected as an invalid option name.
        if options:
            parse_add_single_option(options[start_of_current_opt:])
        if quote_open:
            raise InvalidOptionsError("Unbalanced quotes.")
        return parsed_options

    def _process_ssh_rsa(self, data):
        """Parses ssh-rsa public keys (fields e, n); returns bytes consumed from *data*."""
        current_position, raw_e = self._unpack_by_int(data, 0)
        current_position, raw_n = self._unpack_by_int(data, current_position)

        unpacked_e = self._parse_long(raw_e)
        unpacked_n = self._parse_long(raw_n)

        self.rsa = RSAPublicNumbers(unpacked_e, unpacked_n).public_key(default_backend())
        self.bits = self.rsa.key_size

        if self.strict_mode:
            min_length = self.RSA_MIN_LENGTH_STRICT
            max_length = self.RSA_MAX_LENGTH_STRICT
        else:
            min_length = self.RSA_MIN_LENGTH_LOOSE
            max_length = self.RSA_MAX_LENGTH_LOOSE
        if self.bits < min_length:
            raise TooShortKeyError(
                "%s key data can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, self.bits)
            )
        if self.bits > max_length:
            raise TooLongKeyError(
                "%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, self.bits)
            )
        return current_position

    def _process_ssh_dss(self, data):
        """Parses ssh-dss public keys (fields p, q, g, y); returns bytes consumed from *data*."""
        data_fields = {}
        current_position = 0
        for item in ("p", "q", "g", "y"):
            current_position, value = self._unpack_by_int(data, current_position)
            data_fields[item] = self._parse_long(value)

        q_bits = self._bits_in_number(data_fields["q"])
        p_bits = self._bits_in_number(data_fields["p"])
        if q_bits != self.DSA_N_LENGTH:
            # self.bits is not set yet at this point, so report the computed p length.
            raise InvalidKeyError("Incorrect DSA key parameters: bits(p)=%s, q=%s" % (p_bits, q_bits))
        if self.strict_mode:
            min_length = self.DSA_MIN_LENGTH_STRICT
            max_length = self.DSA_MAX_LENGTH_STRICT
        else:
            min_length = self.DSA_MIN_LENGTH_LOOSE
            max_length = self.DSA_MAX_LENGTH_LOOSE
        if p_bits < min_length:
            raise TooShortKeyError(
                "%s key can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, p_bits)
            )
        if p_bits > max_length:
            raise TooLongKeyError(
                "%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, p_bits)
            )

        dsa_parameters = DSAParameterNumbers(data_fields["p"], data_fields["q"], data_fields["g"])
        self.dsa = DSAPublicNumbers(data_fields["y"], dsa_parameters).public_key(default_backend())
        self.bits = self.dsa.key_size

        return current_position

    def _process_ecdsa_sha(self, data):
        """Parses ecdsa-sha public keys (curve name, point); returns bytes consumed from *data*.

        Raises NotImplementedError for curves not in ECDSA_CURVE_DATA."""
        current_position, curve_information = self._unpack_by_int(data, 0)
        if curve_information not in self.ECDSA_CURVE_DATA:
            raise NotImplementedError("Invalid curve type: %s" % curve_information)
        curve, hash_algorithm = self.ECDSA_CURVE_DATA[curve_information]

        current_position, key_data = self._unpack_by_int(data, current_position)
        try:
            # data starts with \x04, which should be discarded.
            ecdsa_key = ecdsa.VerifyingKey.from_string(key_data[1:], curve, hash_algorithm)
        except AssertionError as ex:
            raise InvalidKeyError("Invalid ecdsa key") from ex
        self.bits = int(curve_information.replace(b"nistp", b""))
        self.ecdsa = ecdsa_key
        return current_position

    def _process_ed25519(self, data):
        """Parses ed25519 keys; returns bytes consumed from *data*.

        There is no (apparent) way to validate ed25519 keys. This only
        checks data length (256 bits), but does not try to validate
        the key in any way."""
        current_position, verifying_key = self._unpack_by_int(data, 0)
        self.bits = len(verifying_key) * 8
        if self.bits != 256:
            raise InvalidKeyLengthError("ed25519 keys must be 256 bits (was %s bits)" % self.bits)
        return current_position

    # Backward-compatible alias for the historical (misspelled) method name.
    _process_ed25516 = _process_ed25519

    def _validate_application_string(self, application):
        """Validates Application string.

        Has to be a URL starting with "ssh:". See ssh-keygen(1).
        *application* is bytes; raises InvalidKeyError otherwise."""

        try:
            parsed_url = urlparse(application)
        except ValueError as err:
            raise InvalidKeyError("Application string: %s" % err) from err
        if parsed_url.scheme != b"ssh":
            raise InvalidKeyError('Application string must begin with "ssh:"')

    def _process_sk_ecdsa_sha(self, data):
        """Parses sk_ecdsa-sha public keys (ecdsa fields + application string)."""
        current_position = self._process_ecdsa_sha(data)
        current_position, application = self._unpack_by_int(data, current_position)
        self._validate_application_string(application)
        return current_position

    def _process_sk_ed25519(self, data):
        """Parses sk_ed25519 public keys (ed25519 field + application string)."""
        current_position = self._process_ed25519(data)
        current_position, application = self._unpack_by_int(data, current_position)
        self._validate_application_string(application)
        return current_position

    def _process_key(self, data):
        """Dispatch on self.key_type; returns the number of bytes of *data* consumed.

        Raises NotImplementedError for unsupported key types."""
        if self.key_type == b"ssh-rsa":
            return self._process_ssh_rsa(data)
        if self.key_type == b"ssh-dss":
            return self._process_ssh_dss(data)
        if self.key_type.strip().startswith(b"ecdsa-sha"):
            return self._process_ecdsa_sha(data)
        if self.key_type == b"ssh-ed25519":
            return self._process_ed25519(data)
        if self.key_type.strip().startswith(b"sk-ecdsa-sha"):
            return self._process_sk_ecdsa_sha(data)
        if self.key_type.strip().startswith(b"sk-ssh-ed25519"):
            return self._process_sk_ed25519(data)
        raise NotImplementedError("Invalid key type: %s" % self.key_type.decode())

    def parse(self, keydata=None):
        """Validates SSH public key.

        Throws exception for invalid keys. Otherwise returns None.

        keydata: key text to parse. When given, all fields are reset first. When
        omitted, the data passed to the constructor (self.keydata) is parsed again.

        Populates key_type and bits, plus comment and options when present.

        For rsa keys, see field "rsa" for raw public key data.
        For dsa keys, see field "dsa".
        For ecdsa keys, see field "ecdsa"."""
        if keydata is None:
            if self.keydata is None:
                raise ValueError("Key data must be supplied either in constructor or to parse()")
            keydata = self.keydata
        else:
            self.reset()
            self.keydata = keydata

        if keydata.startswith("---- BEGIN SSH2 PUBLIC KEY ----"):
            # SSH2 key format
            key_type = None  # There is no redundant key-type field - skip comparing plain-text and encoded data.
            pubkey_content = "".join([line for line in keydata.split("\n") if ":" not in line and "----" not in line])
        else:
            key_parts = self._split_key(keydata)
            key_type = key_parts[0]
            pubkey_content = key_parts[1]

        self._decoded_key = self.decode_key(pubkey_content)

        # Check key type
        current_position, unpacked_key_type = self._unpack_by_int(self._decoded_key, 0)
        try:
            unpacked_key_type_text = unpacked_key_type.decode()
        except UnicodeDecodeError as ex:
            # Report corrupt blobs through the library's own exception hierarchy.
            raise MalformedDataError("Unable to decode the key type") from ex
        if key_type is not None and key_type != unpacked_key_type_text:
            raise InvalidTypeError("Keytype mismatch: %s != %s" % (key_type, unpacked_key_type_text))

        self.key_type = unpacked_key_type

        key_data_length = self._process_key(self._decoded_key[current_position:])
        current_position = current_position + key_data_length

        if current_position != len(self._decoded_key):
            raise MalformedDataError("Leftover data: %s bytes" % (len(self._decoded_key) - current_position))

        # options stays None when skip_option_parsing is set, so also check the raw string.
        if self.disallow_options and (self.options or self.options_raw):
            raise InvalidOptionsError("Options are disallowed.")