1# pylint:disable=line-too-long
"""Parser for ssh public keys. Currently supports ssh-rsa, ssh-dss (DSA), ecdsa-sha*, ssh-ed25519, sk-ecdsa-sha* and sk-ssh-ed25519 keys.
3
4import sys
5
6
7key_data = open("ssh-pubkey-file.pem").read()
8ssh_key = SSHKey(key_data)
9try:
10    ssh_key.parse()
11except InvalidKeyError:
12    print("Invalid key")
13    sys.exit(1)
14print(ssh_key.bits)"""
15
16from .exceptions import (
17    InvalidKeyError, InvalidKeyLengthError, InvalidOptionNameError, InvalidOptionsError, InvalidTypeError,
18    MalformedDataError, MissingMandatoryOptionValueError, TooLongKeyError, TooShortKeyError, UnknownOptionNameError
19)
20from cryptography.hazmat.backends import default_backend
21from cryptography.hazmat.primitives.asymmetric.dsa import DSAParameterNumbers, DSAPublicNumbers
22from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
23from urllib.parse import urlparse
24
25import base64
26import binascii
27import ecdsa
28import hashlib
29import re
30import struct
31import sys
32import warnings
33
34__all__ = ["AuthorizedKeysFile", "SSHKey"]
35
36
class AuthorizedKeysFile:  # pylint:disable=too-few-public-methods
    """Collection of the keys found in an authorized_keys file.

    Blank lines and lines starting with "#" are skipped; every other line is
    parsed as one key and stored, in file order, in the "keys" list."""

    def __init__(self, file_obj, **kwargs):
        self.keys = []
        self.parse(file_obj, **kwargs)

    def parse(self, file_obj, **kwargs):
        """Parse every key line of file_obj and append the results to self.keys.

        file_obj is any iterable yielding text lines. kwargs are handed to
        SSHKey unchanged. Exceptions raised while parsing a key propagate to the caller."""
        for raw_line in file_obj:
            entry = raw_line.strip()
            if entry and not entry.startswith("#"):
                parsed_key = SSHKey(entry, **kwargs)
                parsed_key.parse()
                self.keys.append(parsed_key)
56
57
58class SSHKey:  # pylint:disable=too-many-instance-attributes
    """Represents a single SSH public key.
60
61    ssh_key = SSHKey(key_data, strict=True)
62    ssh_key.parse()
63
64    strict=True (default) only allows keys ssh-keygen generates. Setting strict mode to false allows
65    all keys OpenSSH actually accepts, including highly insecure ones. For example, OpenSSH accepts
66    512-bit DSA keys and 64-bit RSA keys which are highly insecure."""
67
68    DSA_MIN_LENGTH_STRICT = 1024
69    DSA_MAX_LENGTH_STRICT = 1024
70    DSA_MIN_LENGTH_LOOSE = 1
71    DSA_MAX_LENGTH_LOOSE = 3072
72
73    DSA_N_LENGTH = 160
74
75    ECDSA_CURVE_DATA = {
76        b"nistp256": (ecdsa.curves.NIST256p, hashlib.sha256),
77        b"nistp192": (ecdsa.curves.NIST192p, hashlib.sha256),
78        b"nistp224": (ecdsa.curves.NIST224p, hashlib.sha256),
79        b"nistp384": (ecdsa.curves.NIST384p, hashlib.sha384),
80        b"nistp521": (ecdsa.curves.NIST521p, hashlib.sha512),
81    }
82
83    RSA_MIN_LENGTH_STRICT = 1024
84    RSA_MAX_LENGTH_STRICT = 16384
85    RSA_MIN_LENGTH_LOOSE = 768
86    RSA_MAX_LENGTH_LOOSE = 16384
87
    # Valid as of OpenSSH_8.3
    # Each entry is (option name, whether a value is mandatory). Options are case-insensitive,
    # but the names in this list must be in lowercase.
90    OPTIONS_SPEC = [
91        ("agent-forwarding", False),
92        ("cert-authority", False),
93        ("command", True),
94        ("environment", True),
95        ("expiry-time", True),
96        ("from", True),
97        ("no-agent-forwarding", False),
98        ("no-port-forwarding", False),
99        ("no-pty", False),
100        ("no-user-rc", False),
101        ("no-x11-forwarding", False),
102        ("permitlisten", True),
103        ("permitopen", True),
104        ("port-forwarding", False),
105        ("principals", True),
106        ("pty", False),
107        ("no-touch-required", False),
108        ("restrict", False),
109        ("tunnel", True),
110        ("user-rc", False),
111        ("x11-forwarding", False),
112    ]
113    OPTION_NAME_RE = re.compile("^[A-Za-z0-9-]+$")
114
115    INT_LEN = 4
116
117    FIELDS = ["rsa", "dsa", "ecdsa", "bits", "comment", "options", "options_raw", "key_type"]
118
119    def __init__(self, keydata=None, **kwargs):
120        self.keydata = keydata
121        self._decoded_key = None
122        self.rsa = None
123        self.dsa = None
124        self.ecdsa = None
125        self.bits = None
126        self.comment = None
127        self.options = None
128        self.options_raw = None
129        self.key_type = None
130        self.strict_mode = bool(kwargs.get("strict", True))
131        self.skip_option_parsing = bool(kwargs.get("skip_option_parsing", False))
132        self.disallow_options = bool(kwargs.get("disallow_options", False))
133        if keydata:
134            try:
135                self.parse(keydata)
136            except (InvalidKeyError, NotImplementedError):
137                pass
138
139    def __str__(self):
140        return "Key type: %s, bits: %s, options: %s" % (self.key_type.decode(), self.bits, self.options)
141
142    def reset(self):
143        """Reset all data fields."""
144        for field in self.FIELDS:
145            setattr(self, field, None)
146
147    def hash(self):
148        """Calculate md5 fingerprint.
149
150        Deprecated, use .hash_md5() instead."""
151        warnings.warn("hash() is deprecated. Use hash_md5(), hash_sha256() or hash_sha512() instead.")
152        return self.hash_md5().replace(b"MD5:", b"")
153
154    def hash_md5(self):
155        """Calculate md5 fingerprint.
156
157        Shamelessly copied from http://stackoverflow.com/questions/6682815/deriving-an-ssh-fingerprint-from-a-public-key-in-python
158
159        For specification, see RFC4716, section 4."""
160        fp_plain = hashlib.md5(self._decoded_key).hexdigest()
161        return "MD5:" + ':'.join(a + b for a, b in zip(fp_plain[::2], fp_plain[1::2]))
162
163    def hash_sha256(self):
164        """Calculate sha256 fingerprint."""
165        fp_plain = hashlib.sha256(self._decoded_key).digest()
166        return (b"SHA256:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")
167
168    def hash_sha512(self):
169        """Calculates sha512 fingerprint."""
170        fp_plain = hashlib.sha512(self._decoded_key).digest()
171        return (b"SHA512:" + base64.b64encode(fp_plain).replace(b"=", b"")).decode("utf-8")
172
173    def _unpack_by_int(self, data, current_position):
174        """Returns a tuple with (location of next data field, contents of requested data field)."""
175        # Unpack length of data field
176        try:
177            requested_data_length = struct.unpack('>I', data[current_position:current_position + self.INT_LEN])[0]
178        except struct.error as ex:
179            raise MalformedDataError("Unable to unpack %s bytes from the data" % self.INT_LEN) from ex
180
181        # Move pointer to the beginning of the data field
182        current_position += self.INT_LEN
183        remaining_data_length = len(data[current_position:])
184
185        if remaining_data_length < requested_data_length:
186            raise MalformedDataError(
187                "Requested %s bytes, but only %s bytes available." % (requested_data_length, remaining_data_length)
188            )
189
190        next_data = data[current_position:current_position + requested_data_length]
191        # Move pointer to the end of the data field
192        current_position += requested_data_length
193        return current_position, next_data
194
195    @classmethod
196    def _parse_long(cls, data):
197        """Calculate two's complement."""
198        if sys.version < '3':
199            # this does not exist in python 3 - undefined-variable disabled to make pylint happier.
200            ret = long(0)  # pylint:disable=undefined-variable
201            for byte in data:
202                ret = (ret << 8) + ord(byte)
203        else:
204            ret = 0
205            for byte in data:
206                ret = (ret << 8) + byte
207        return ret
208
209    def _split_key(self, data):
210        options_raw = None
211        # Terribly inefficient way to remove options, but hey, it works.
212        if not data.startswith("ssh-") and not data.startswith("ecdsa-") and not data.startswith("sk-"):
213            quote_open = False
214            for i, character in enumerate(data):
215                if character == '"':  # only double quotes are allowed, no need to care about single quotes
216                    quote_open = not quote_open
217                if quote_open:
218                    continue
219                if character == " ":
220                    # Data begins after the first space
221                    options_raw = data[:i]
222                    data = data[i + 1:]
223                    break
224            else:
225                raise MalformedDataError("Couldn't find beginning of the key data")
226        key_parts = data.strip().split(None, 2)
227        if len(key_parts) < 2:  # Key type and content are mandatory fields.
228            raise InvalidKeyError("Unexpected key format: at least type and base64 encoded value is required")
229        if len(key_parts) == 3:
230            self.comment = key_parts[2]
231            key_parts = key_parts[0:2]
232        if options_raw:
233            # Populate and parse options field.
234            self.options_raw = options_raw
235            if not self.skip_option_parsing:
236                self.options = self.parse_options(self.options_raw)
237        else:
238            # Set empty defaults for fields
239            self.options_raw = None
240            self.options = {}
241        return key_parts
242
243    @classmethod
244    def decode_key(cls, pubkey_content):
245        """Decode base64 coded part of the key."""
246        try:
247            decoded_key = base64.b64decode(pubkey_content.encode("ascii"))
248        except (TypeError, binascii.Error) as ex:
249            raise MalformedDataError("Unable to decode the key") from ex
250        return decoded_key
251
252    @classmethod
253    def _bits_in_number(cls, number):
254        return len(format(number, "b"))
255
256    def parse_options(self, options):
257        """Parses ssh options string."""
258        quote_open = False
259        parsed_options = {}
260
261        def parse_add_single_option(opt):
262            """Parses and validates a single option, and adds it to parsed_options field."""
263            if "=" in opt:
264                opt_name, opt_value = opt.split("=", 1)
265                opt_value = opt_value.replace('"', '')
266            else:
267                opt_name = opt
268                opt_value = True
269            if " " in opt_name or not self.OPTION_NAME_RE.match(opt_name):
270                raise InvalidOptionNameError("%s is not a valid option name." % opt_name)
271            if self.strict_mode:
272                for valid_opt_name, value_required in self.OPTIONS_SPEC:
273                    if opt_name.lower() == valid_opt_name:
274                        if value_required and opt_value is True:
275                            raise MissingMandatoryOptionValueError("%s is missing a mandatory value." % opt_name)
276                        break
277                else:
278                    raise UnknownOptionNameError("%s is an unrecognized option name." % opt_name)
279            if opt_name not in parsed_options:
280                parsed_options[opt_name] = []
281            parsed_options[opt_name].append(opt_value)
282
283        start_of_current_opt = 0
284        i = 1  # Need to be set for empty options strings
285        for i, character in enumerate(options):
286            if character == '"':  # only double quotes are allowed, no need to care about single quotes
287                quote_open = not quote_open
288            if quote_open:
289                continue
290            if character == ",":
291                opt = options[start_of_current_opt:i]
292                parse_add_single_option(opt)
293                start_of_current_opt = i + 1
294                # Data begins after the first space
295        if start_of_current_opt + 1 != i:
296            opt = options[start_of_current_opt:]
297            parse_add_single_option(opt)
298        if quote_open:
299            raise InvalidOptionsError("Unbalanced quotes.")
300        return parsed_options
301
302    def _process_ssh_rsa(self, data):
303        """Parses ssh-rsa public keys."""
304        current_position, raw_e = self._unpack_by_int(data, 0)
305        current_position, raw_n = self._unpack_by_int(data, current_position)
306
307        unpacked_e = self._parse_long(raw_e)
308        unpacked_n = self._parse_long(raw_n)
309
310        self.rsa = RSAPublicNumbers(unpacked_e, unpacked_n).public_key(default_backend())
311        self.bits = self.rsa.key_size
312
313        if self.strict_mode:
314            min_length = self.RSA_MIN_LENGTH_STRICT
315            max_length = self.RSA_MAX_LENGTH_STRICT
316        else:
317            min_length = self.RSA_MIN_LENGTH_LOOSE
318            max_length = self.RSA_MAX_LENGTH_LOOSE
319        if self.bits < min_length:
320            raise TooShortKeyError(
321                "%s key data can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, self.bits)
322            )
323        if self.bits > max_length:
324            raise TooLongKeyError(
325                "%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, self.bits)
326            )
327        return current_position
328
329    def _process_ssh_dss(self, data):
330        """Parses ssh-dsa public keys."""
331        data_fields = {}
332        current_position = 0
333        for item in ("p", "q", "g", "y"):
334            current_position, value = self._unpack_by_int(data, current_position)
335            data_fields[item] = self._parse_long(value)
336
337        q_bits = self._bits_in_number(data_fields["q"])
338        p_bits = self._bits_in_number(data_fields["p"])
339        if q_bits != self.DSA_N_LENGTH:
340            raise InvalidKeyError("Incorrect DSA key parameters: bits(p)=%s, q=%s" % (self.bits, q_bits))
341        if self.strict_mode:
342            min_length = self.DSA_MIN_LENGTH_STRICT
343            max_length = self.DSA_MAX_LENGTH_STRICT
344        else:
345            min_length = self.DSA_MIN_LENGTH_LOOSE
346            max_length = self.DSA_MAX_LENGTH_LOOSE
347        if p_bits < min_length:
348            raise TooShortKeyError(
349                "%s key can not be shorter than %s bits (was %s)" % (self.key_type.decode(), min_length, p_bits)
350            )
351        if p_bits > max_length:
352            raise TooLongKeyError(
353                "%s key data can not be longer than %s bits (was %s)" % (self.key_type.decode(), max_length, p_bits)
354            )
355
356        dsa_parameters = DSAParameterNumbers(data_fields["p"], data_fields["q"], data_fields["g"])
357        self.dsa = DSAPublicNumbers(data_fields["y"], dsa_parameters).public_key(default_backend())
358        self.bits = self.dsa.key_size
359
360        return current_position
361
362    def _process_ecdsa_sha(self, data):
363        """Parses ecdsa-sha public keys."""
364        current_position, curve_information = self._unpack_by_int(data, 0)
365        if curve_information not in self.ECDSA_CURVE_DATA:
366            raise NotImplementedError("Invalid curve type: %s" % curve_information)
367        curve, hash_algorithm = self.ECDSA_CURVE_DATA[curve_information]
368
369        current_position, key_data = self._unpack_by_int(data, current_position)
370        try:
371            # data starts with \x04, which should be discarded.
372            ecdsa_key = ecdsa.VerifyingKey.from_string(key_data[1:], curve, hash_algorithm)
373        except AssertionError as ex:
374            raise InvalidKeyError("Invalid ecdsa key") from ex
375        self.bits = int(curve_information.replace(b"nistp", b""))
376        self.ecdsa = ecdsa_key
377        return current_position
378
379    def _process_ed25516(self, data):
380        """Parses ed25516 keys.
381
382        There is no (apparent) way to validate ed25519 keys. This only
383        checks data length (256 bits), but does not try to validate
384        the key in any way."""
385
386        current_position, verifying_key = self._unpack_by_int(data, 0)
387        verifying_key_length = len(verifying_key) * 8
388        verifying_key = self._parse_long(verifying_key)
389
390        if verifying_key < 0:
391            raise InvalidKeyError("ed25519 verifying key must be >0.")
392
393        self.bits = verifying_key_length
394        if self.bits != 256:
395            raise InvalidKeyLengthError("ed25519 keys must be 256 bits (was %s bits)" % self.bits)
396        return current_position
397
398    def _validate_application_string(self, application):
399        """Validates Application string.
400
401        Has to be an URL starting with "ssh:". See ssh-keygen(1)."""
402
403        try:
404            parsed_url = urlparse(application)
405        except ValueError as err:
406            raise InvalidKeyError("Application string: %s" % err) from err
407        if parsed_url.scheme != b"ssh":
408            raise InvalidKeyError('Application string must begin with "ssh:"')
409
410    def _process_sk_ecdsa_sha(self, data):
411        """Parses sk_ecdsa-sha public keys."""
412        current_position = self._process_ecdsa_sha(data)
413        current_position, application = self._unpack_by_int(data, current_position)
414        self._validate_application_string(application)
415        return current_position
416
417    def _process_sk_ed25519(self, data):
418        """Parses sk_ed25519 public keys."""
419        current_position = self._process_ed25516(data)
420        current_position, application = self._unpack_by_int(data, current_position)
421        self._validate_application_string(application)
422        return current_position
423
424    def _process_key(self, data):
425        if self.key_type == b"ssh-rsa":
426            return self._process_ssh_rsa(data)
427        if self.key_type == b"ssh-dss":
428            return self._process_ssh_dss(data)
429        if self.key_type.strip().startswith(b"ecdsa-sha"):
430            return self._process_ecdsa_sha(data)
431        if self.key_type == b"ssh-ed25519":
432            return self._process_ed25516(data)
433        if self.key_type.strip().startswith(b"sk-ecdsa-sha"):
434            return self._process_sk_ecdsa_sha(data)
435        if self.key_type.strip().startswith(b"sk-ssh-ed25519"):
436            return self._process_sk_ed25519(data)
437        raise NotImplementedError("Invalid key type: %s" % self.key_type.decode())
438
439    def parse(self, keydata=None):
440        """Validates SSH public key.
441
442        Throws exception for invalid keys. Otherwise returns None.
443
444        Populates key_type, bits and bits fields.
445
446        For rsa keys, see field "rsa" for raw public key data.
447        For dsa keys, see field "dsa".
448        For ecdsa keys, see field "ecdsa"."""
449        if keydata is None:
450            if self.keydata is None:
451                raise ValueError("Key data must be supplied either in constructor or to parse()")
452            keydata = self.keydata
453        else:
454            self.reset()
455            self.keydata = keydata
456
457        if keydata.startswith("---- BEGIN SSH2 PUBLIC KEY ----"):
458            # SSH2 key format
459            key_type = None  # There is no redundant key-type field - skip comparing plain-text and encoded data.
460            pubkey_content = "".join([line for line in keydata.split("\n") if ":" not in line and "----" not in line])
461        else:
462            key_parts = self._split_key(keydata)
463            key_type = key_parts[0]
464            pubkey_content = key_parts[1]
465
466        self._decoded_key = self.decode_key(pubkey_content)
467
468        # Check key type
469        current_position, unpacked_key_type = self._unpack_by_int(self._decoded_key, 0)
470        if key_type is not None and key_type != unpacked_key_type.decode():
471            raise InvalidTypeError("Keytype mismatch: %s != %s" % (key_type, unpacked_key_type.decode()))
472
473        self.key_type = unpacked_key_type
474
475        key_data_length = self._process_key(self._decoded_key[current_position:])
476        current_position = current_position + key_data_length
477
478        if current_position != len(self._decoded_key):
479            raise MalformedDataError("Leftover data: %s bytes" % (len(self._decoded_key) - current_position))
480
481        if self.disallow_options and self.options:
482            raise InvalidOptionsError("Options are disallowed.")
483