1# This file is dual licensed under the terms of the Apache License, Version
2# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3# for complete details.
4
5from __future__ import absolute_import, division, print_function
6
7import binascii
8import os
9import re
10import struct
11
12import six
13
14from cryptography import utils
15from cryptography.exceptions import UnsupportedAlgorithm
16from cryptography.hazmat.backends import _get_backend
17from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa
18from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
19from cryptography.hazmat.primitives.serialization import (
20    Encoding,
21    NoEncryption,
22    PrivateFormat,
23    PublicFormat,
24)
25
# bcrypt is an optional dependency: its kdf() derives the cipher key and IV
# from the password for encrypted private keys (see _init_cipher).  When it
# is missing, a stub is installed that raises UnsupportedAlgorithm when
# called, so only encrypted keys are affected.
try:
    from bcrypt import kdf as _bcrypt_kdf

    _bcrypt_supported = True
except ImportError:
    _bcrypt_supported = False

    def _bcrypt_kdf(*args, **kwargs):
        """Fallback used when bcrypt is not installed; always raises."""
        raise UnsupportedAlgorithm("Need bcrypt module")
35
36
37try:
38    from base64 import encodebytes as _base64_encode
39except ImportError:
40    from base64 import encodestring as _base64_encode
41
# SSH key type names, as they appear in key blobs and public key lines.
_SSH_ED25519 = b"ssh-ed25519"
_SSH_RSA = b"ssh-rsa"
_SSH_DSA = b"ssh-dss"
_ECDSA_NISTP256 = b"ecdsa-sha2-nistp256"
_ECDSA_NISTP384 = b"ecdsa-sha2-nistp384"
_ECDSA_NISTP521 = b"ecdsa-sha2-nistp521"
# Appended to a key type name to form the matching certificate type name.
_CERT_SUFFIX = b"-cert-v01@openssh.com"

# One-line public key: "<key type> <base64 body>"; anything after the body
# (e.g. a comment) is ignored because only a prefix match is required.
_SSH_PUBKEY_RC = re.compile(br"\A(\S+)[ \t]+(\S+)")
# Magic prefix of the decoded private key blob, including the trailing NUL.
_SK_MAGIC = b"openssh-key-v1\0"
# Armor lines around the base64-encoded private key blob.
_SK_START = b"-----BEGIN OPENSSH PRIVATE KEY-----"
_SK_END = b"-----END OPENSSH PRIVATE KEY-----"
# KDF and cipher names found in the private key header.
_BCRYPT = b"bcrypt"
_NONE = b"none"
# Cipher and bcrypt round count used when serializing with a password.
_DEFAULT_CIPHER = b"aes256-ctr"
_DEFAULT_ROUNDS = 16
# Longest password, in bytes, accepted by serialize_ssh_private_key.
_MAX_PASSWORD = 72

# A regex is used because it can search bytes-like input (including
# memoryview) directly; DOTALL lets the base64 body span multiple lines.
_PEM_RC = re.compile(_SK_START + b"(.*?)" + _SK_END, re.DOTALL)

# Padding bytes 1, 2, ..., 16: enough for the largest block length (16) in
# _SSH_CIPHERS.  Private sections are padded with a prefix of this sequence.
_PADDING = memoryview(bytearray(range(1, 1 + 16)))

# Ciphers supported for wrapping the private section:
#   name -> (algorithm class, key length, mode class, IV length)
# The IV length also serves as the block length for padding checks.
_SSH_CIPHERS = {
    b"aes256-ctr": (algorithms.AES, 32, modes.CTR, 16),
    b"aes256-cbc": (algorithms.AES, 32, modes.CBC, 16),
}

# Map cryptography's curve name (EllipticCurve.name) to the SSH key type.
_ECDSA_KEY_TYPE = {
    "secp256r1": _ECDSA_NISTP256,
    "secp384r1": _ECDSA_NISTP384,
    "secp521r1": _ECDSA_NISTP521,
}

# Big-endian unsigned 32-bit and 64-bit integers used by the wire format.
_U32 = struct.Struct(b">I")
_U64 = struct.Struct(b">Q")
81
82
def _ecdsa_key_type(public_key):
    """Return the SSH key type name for an EC public key.

    :param public_key: EC public key; its ``curve.name`` must be one of
        secp256r1, secp384r1 or secp521r1.
    :return: the SSH key type, e.g. ``b"ecdsa-sha2-nistp256"``.
    :raises ValueError: if the curve has no SSH key type.
    """
    curve = public_key.curve
    key_type = _ECDSA_KEY_TYPE.get(curve.name)
    if key_type is None:
        # The message mentions "private key" although this helper also
        # serves public-key serialization; kept unchanged for compatibility.
        raise ValueError(
            "Unsupported curve for ssh private key: %r" % curve.name
        )
    return key_type
91
92
def _ssh_pem_encode(data, prefix=_SK_START + b"\n", suffix=_SK_END + b"\n"):
    """Base64-encode *data* (with line breaks) between armor lines."""
    encoded_body = _base64_encode(data)
    pieces = (prefix, encoded_body, suffix)
    return b"".join(pieces)
95
96
97def _check_block_size(data, block_len):
98    """Require data to be full blocks"""
99    if not data or len(data) % block_len != 0:
100        raise ValueError("Corrupt data: missing padding")
101
102
103def _check_empty(data):
104    """All data should have been parsed."""
105    if data:
106        raise ValueError("Corrupt data: unparsed data")
107
108
def _init_cipher(ciphername, password, salt, rounds, backend):
    """Derive key and IV from *password* and return a configured Cipher.

    :raises ValueError: if *password* is empty or None.
    """
    if not password:
        raise ValueError("Key is password-protected.")

    algorithm_cls, key_len, mode_cls, iv_len = _SSH_CIPHERS[ciphername]
    # One KDF call yields key || IV.  The trailing True is presumably
    # bcrypt.kdf's ignore_few_rounds flag (low round counts such as
    # _DEFAULT_ROUNDS are intentional) -- confirm against bcrypt docs.
    material = _bcrypt_kdf(password, salt, key_len + iv_len, rounds, True)
    key, iv = material[:key_len], material[key_len:]
    return Cipher(algorithm_cls(key), mode_cls(iv), backend)
117
118
def _get_u32(data):
    """Split a big-endian uint32 off *data*; return (value, rest)."""
    if len(data) >= 4:
        (value,) = _U32.unpack(data[:4])
        return value, data[4:]
    raise ValueError("Invalid data")
124
125
def _get_u64(data):
    """Split a big-endian uint64 off *data*; return (value, rest)."""
    head, rest = data[:8], data[8:]
    if len(head) != 8:
        raise ValueError("Invalid data")
    (value,) = _U64.unpack(head)
    return value, rest
131
132
def _get_sshstr(data):
    """Split an SSH ``string`` (u32 length, then payload) off *data*.

    Returns (payload, rest); raises ValueError if *data* is too short.
    """
    length, rest = _get_u32(data)
    if len(rest) < length:
        raise ValueError("Invalid data")
    return rest[:length], rest[length:]
139
140
def _get_mpint(data):
    """Split an SSH ``mpint`` off *data*; return (non-negative int, rest).

    Values whose first byte has the sign bit set (i.e. negative numbers)
    are rejected with ValueError.
    """
    raw, rest = _get_sshstr(data)
    sign_bit_set = bool(raw) and six.indexbytes(raw, 0) >= 0x80
    if sign_bit_set:
        raise ValueError("Invalid data")
    value = utils.int_from_bytes(raw, "big")
    return value, rest
147
148
def _to_mpint(val):
    """Encode a non-negative int as SSH ``mpint`` payload bytes.

    Zero encodes as empty bytes.  Otherwise the value is big-endian with
    room for one extra sign bit, so a leading zero byte appears whenever
    the top bit would otherwise be set.
    """
    if val < 0:
        raise ValueError("negative mpint not allowed")
    if val == 0:
        return b""
    # bit_length // 8 + 1 bytes always leaves at least one spare top bit.
    return utils.int_to_bytes(val, val.bit_length() // 8 + 1)
157
158
159class _FragList(object):
160    """Build recursive structure without data copy."""
161
162    def __init__(self, init=None):
163        self.flist = []
164        if init:
165            self.flist.extend(init)
166
167    def put_raw(self, val):
168        """Add plain bytes"""
169        self.flist.append(val)
170
171    def put_u32(self, val):
172        """Big-endian uint32"""
173        self.flist.append(_U32.pack(val))
174
175    def put_sshstr(self, val):
176        """Bytes prefixed with u32 length"""
177        if isinstance(val, (bytes, memoryview, bytearray)):
178            self.put_u32(len(val))
179            self.flist.append(val)
180        else:
181            self.put_u32(val.size())
182            self.flist.extend(val.flist)
183
184    def put_mpint(self, val):
185        """Big-endian bigint prefixed with u32 length"""
186        self.put_sshstr(_to_mpint(val))
187
188    def size(self):
189        """Current number of bytes"""
190        return sum(map(len, self.flist))
191
192    def render(self, dstbuf, pos=0):
193        """Write into bytearray"""
194        for frag in self.flist:
195            flen = len(frag)
196            start, pos = pos, pos + flen
197            dstbuf[start:pos] = frag
198        return pos
199
200    def tobytes(self):
201        """Return as bytes"""
202        buf = memoryview(bytearray(self.size()))
203        self.render(buf)
204        return buf.tobytes()
205
206
207class _SSHFormatRSA(object):
208    """Format for RSA keys.
209
210    Public:
211        mpint e, n
212    Private:
213        mpint n, e, d, iqmp, p, q
214    """
215
216    def get_public(self, data):
217        """RSA public fields"""
218        e, data = _get_mpint(data)
219        n, data = _get_mpint(data)
220        return (e, n), data
221
222    def load_public(self, key_type, data, backend):
223        """Make RSA public key from data."""
224        (e, n), data = self.get_public(data)
225        public_numbers = rsa.RSAPublicNumbers(e, n)
226        public_key = public_numbers.public_key(backend)
227        return public_key, data
228
229    def load_private(self, data, pubfields, backend):
230        """Make RSA private key from data."""
231        n, data = _get_mpint(data)
232        e, data = _get_mpint(data)
233        d, data = _get_mpint(data)
234        iqmp, data = _get_mpint(data)
235        p, data = _get_mpint(data)
236        q, data = _get_mpint(data)
237
238        if (e, n) != pubfields:
239            raise ValueError("Corrupt data: rsa field mismatch")
240        dmp1 = rsa.rsa_crt_dmp1(d, p)
241        dmq1 = rsa.rsa_crt_dmq1(d, q)
242        public_numbers = rsa.RSAPublicNumbers(e, n)
243        private_numbers = rsa.RSAPrivateNumbers(
244            p, q, d, dmp1, dmq1, iqmp, public_numbers
245        )
246        private_key = private_numbers.private_key(backend)
247        return private_key, data
248
249    def encode_public(self, public_key, f_pub):
250        """Write RSA public key"""
251        pubn = public_key.public_numbers()
252        f_pub.put_mpint(pubn.e)
253        f_pub.put_mpint(pubn.n)
254
255    def encode_private(self, private_key, f_priv):
256        """Write RSA private key"""
257        private_numbers = private_key.private_numbers()
258        public_numbers = private_numbers.public_numbers
259
260        f_priv.put_mpint(public_numbers.n)
261        f_priv.put_mpint(public_numbers.e)
262
263        f_priv.put_mpint(private_numbers.d)
264        f_priv.put_mpint(private_numbers.iqmp)
265        f_priv.put_mpint(private_numbers.p)
266        f_priv.put_mpint(private_numbers.q)
267
268
269class _SSHFormatDSA(object):
270    """Format for DSA keys.
271
272    Public:
273        mpint p, q, g, y
274    Private:
275        mpint p, q, g, y, x
276    """
277
278    def get_public(self, data):
279        """DSA public fields"""
280        p, data = _get_mpint(data)
281        q, data = _get_mpint(data)
282        g, data = _get_mpint(data)
283        y, data = _get_mpint(data)
284        return (p, q, g, y), data
285
286    def load_public(self, key_type, data, backend):
287        """Make DSA public key from data."""
288        (p, q, g, y), data = self.get_public(data)
289        parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
290        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
291        self._validate(public_numbers)
292        public_key = public_numbers.public_key(backend)
293        return public_key, data
294
295    def load_private(self, data, pubfields, backend):
296        """Make DSA private key from data."""
297        (p, q, g, y), data = self.get_public(data)
298        x, data = _get_mpint(data)
299
300        if (p, q, g, y) != pubfields:
301            raise ValueError("Corrupt data: dsa field mismatch")
302        parameter_numbers = dsa.DSAParameterNumbers(p, q, g)
303        public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers)
304        self._validate(public_numbers)
305        private_numbers = dsa.DSAPrivateNumbers(x, public_numbers)
306        private_key = private_numbers.private_key(backend)
307        return private_key, data
308
309    def encode_public(self, public_key, f_pub):
310        """Write DSA public key"""
311        public_numbers = public_key.public_numbers()
312        parameter_numbers = public_numbers.parameter_numbers
313        self._validate(public_numbers)
314
315        f_pub.put_mpint(parameter_numbers.p)
316        f_pub.put_mpint(parameter_numbers.q)
317        f_pub.put_mpint(parameter_numbers.g)
318        f_pub.put_mpint(public_numbers.y)
319
320    def encode_private(self, private_key, f_priv):
321        """Write DSA private key"""
322        self.encode_public(private_key.public_key(), f_priv)
323        f_priv.put_mpint(private_key.private_numbers().x)
324
325    def _validate(self, public_numbers):
326        parameter_numbers = public_numbers.parameter_numbers
327        if parameter_numbers.p.bit_length() != 1024:
328            raise ValueError("SSH supports only 1024 bit DSA keys")
329
330
class _SSHFormatECDSA(object):
    """Format for ECDSA keys.

    Public:
        str curve
        bytes point
    Private:
        str curve
        bytes point
        mpint secret
    """

    def __init__(self, ssh_curve_name, curve):
        # ssh_curve_name: curve identifier expected in the key blob
        # (e.g. b"nistp256"); curve: the matching cryptography curve.
        self.ssh_curve_name = ssh_curve_name
        self.curve = curve

    def get_public(self, data):
        """ECDSA public fields.

        Returns ((curve_name, point), remaining data).

        :raises ValueError: on a curve name mismatch or an empty point.
        :raises NotImplementedError: if the point is not uncompressed.
        """
        curve, data = _get_sshstr(data)
        point, data = _get_sshstr(data)
        if curve != self.ssh_curve_name:
            raise ValueError("Curve name mismatch")
        # An empty point would make indexbytes() raise IndexError; report
        # it as malformed input like every other parse failure.
        if not point:
            raise ValueError("Invalid data")
        # 0x04 is the SEC1 prefix for an uncompressed point.
        if six.indexbytes(point, 0) != 4:
            raise NotImplementedError("Need uncompressed point")
        return (curve, point), data

    def load_public(self, key_type, data, backend):
        """Make ECDSA public key from data; return (key, remaining data)."""
        (curve_name, point), data = self.get_public(data)
        public_key = ec.EllipticCurvePublicKey.from_encoded_point(
            self.curve, point.tobytes()
        )
        return public_key, data

    def load_private(self, data, pubfields, backend):
        """Make ECDSA private key from data; return (key, remaining data).

        :raises ValueError: if the embedded public fields disagree with
            *pubfields*.
        """
        (curve_name, point), data = self.get_public(data)
        secret, data = _get_mpint(data)

        if (curve_name, point) != pubfields:
            raise ValueError("Corrupt data: ecdsa field mismatch")
        private_key = ec.derive_private_key(secret, self.curve, backend)
        return private_key, data

    def encode_public(self, public_key, f_pub):
        """Write curve name and uncompressed point to *f_pub*."""
        point = public_key.public_bytes(
            Encoding.X962, PublicFormat.UncompressedPoint
        )
        f_pub.put_sshstr(self.ssh_curve_name)
        f_pub.put_sshstr(point)

    def encode_private(self, private_key, f_priv):
        """Write the public fields followed by the private scalar."""
        public_key = private_key.public_key()
        private_numbers = private_key.private_numbers()

        self.encode_public(public_key, f_priv)
        f_priv.put_mpint(private_numbers.private_value)
390
391
class _SSHFormatEd25519(object):
    """Wire format of Ed25519 keys.

    Public fields:  string point (raw public key bytes)
    Private fields: string point, string (raw secret || raw public key)
    """

    def get_public(self, data):
        """Parse the public point; return ((point,), remaining data)."""
        point, rest = _get_sshstr(data)
        return (point,), rest

    def load_public(self, key_type, data, backend):
        """Build an Ed25519 public key; return (key, remaining data)."""
        fields, rest = self.get_public(data)
        raw_point = fields[0].tobytes()
        return ed25519.Ed25519PublicKey.from_public_bytes(raw_point), rest

    def load_private(self, data, pubfields, backend):
        """Build an Ed25519 private key; return (key, remaining data).

        Raises ValueError if the point embedded after the 32-byte secret,
        or the leading point, disagree with *pubfields*.
        """
        fields, data = self.get_public(data)
        keypair, data = _get_sshstr(data)
        point = fields[0]

        secret, trailing_point = keypair[:32], keypair[32:]
        if fields != pubfields or trailing_point != point:
            raise ValueError("Corrupt data: ed25519 field mismatch")
        key = ed25519.Ed25519PrivateKey.from_private_bytes(secret)
        return key, data

    def encode_public(self, public_key, f_pub):
        """Append the raw public key as an SSH string."""
        f_pub.put_sshstr(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

    def encode_private(self, private_key, f_priv):
        """Append the public point, then (secret || public point)."""
        public_key = private_key.public_key()
        keypair = _FragList()
        keypair.put_raw(
            private_key.private_bytes(
                Encoding.Raw, PrivateFormat.Raw, NoEncryption()
            )
        )
        keypair.put_raw(
            public_key.public_bytes(Encoding.Raw, PublicFormat.Raw)
        )

        self.encode_public(public_key, f_priv)
        f_priv.put_sshstr(keypair)
447
448
# SSH key type name -> handler that parses and encodes that key type.
# ECDSA handlers carry the SSH curve identifier and the matching curve.
_KEY_FORMATS = {
    _SSH_RSA: _SSHFormatRSA(),
    _SSH_DSA: _SSHFormatDSA(),
    _SSH_ED25519: _SSHFormatEd25519(),
    _ECDSA_NISTP256: _SSHFormatECDSA(b"nistp256", ec.SECP256R1()),
    _ECDSA_NISTP384: _SSHFormatECDSA(b"nistp384", ec.SECP384R1()),
    _ECDSA_NISTP521: _SSHFormatECDSA(b"nistp521", ec.SECP521R1()),
}
457
458
def _lookup_kformat(key_type):
    """Return the format handler registered for *key_type*.

    *key_type* may be bytes or any bytes-like object (e.g. a memoryview
    slice of parsed data); it is normalized to bytes before the lookup.

    Raises UnsupportedAlgorithm when no handler is registered.
    """
    name = (
        key_type
        if isinstance(key_type, bytes)
        else memoryview(key_type).tobytes()
    )
    kformat = _KEY_FORMATS.get(name)
    if kformat is None:
        raise UnsupportedAlgorithm("Unsupported key type: %r" % name)
    return kformat
466
467
def load_ssh_private_key(data, password, backend=None):
    """Load private key from OpenSSH custom encoding.

    :param data: bytes-like object containing a
        ``-----BEGIN OPENSSH PRIVATE KEY-----`` block; text outside the
        armor is ignored.
    :param password: bytes or None.  Required when the key is encrypted;
        ignored for unencrypted keys.
    :param backend: optional backend, resolved with _get_backend().
    :return: the RSA, DSA, EC or Ed25519 private key.
    :raises ValueError: on malformed or corrupt data, a missing password,
        or a checksum mismatch (the usual symptom of a wrong password).
    :raises UnsupportedAlgorithm: for unknown key types, ciphers or KDFs,
        or when bcrypt is needed but not installed.
    """
    utils._check_byteslike("data", data)
    backend = _get_backend(backend)
    if password is not None:
        utils._check_bytes("password", password)

    m = _PEM_RC.search(data)
    if not m:
        raise ValueError("Not OpenSSH private key format")
    p1 = m.start(1)
    p2 = m.end(1)
    data = binascii.a2b_base64(memoryview(data)[p1:p2])
    if not data.startswith(_SK_MAGIC):
        raise ValueError("Not OpenSSH private key format")
    # Work on a memoryview so the parse helpers slice without copying.
    data = memoryview(data)[len(_SK_MAGIC) :]

    # parse header
    ciphername, data = _get_sshstr(data)
    kdfname, data = _get_sshstr(data)
    kdfoptions, data = _get_sshstr(data)
    nkeys, data = _get_u32(data)
    if nkeys != 1:
        raise ValueError("Only one key supported")

    # load public key data
    pubdata, data = _get_sshstr(data)
    pub_key_type, pubdata = _get_sshstr(pubdata)
    kformat = _lookup_kformat(pub_key_type)
    pubfields, pubdata = kformat.get_public(pubdata)
    _check_empty(pubdata)

    # load secret data
    edata, data = _get_sshstr(data)
    _check_empty(data)

    if (ciphername, kdfname) != (_NONE, _NONE):
        ciphername = ciphername.tobytes()
        if ciphername not in _SSH_CIPHERS:
            raise UnsupportedAlgorithm("Unsupported cipher: %r" % ciphername)
        # Convert to bytes so the error shows the KDF name instead of a
        # memoryview repr such as "<memory at 0x...>".
        kdfname = kdfname.tobytes()
        if kdfname != _BCRYPT:
            raise UnsupportedAlgorithm("Unsupported KDF: %r" % kdfname)
        blklen = _SSH_CIPHERS[ciphername][3]
        _check_block_size(edata, blklen)
        # bcrypt KDF options: string salt, u32 rounds.
        salt, kbuf = _get_sshstr(kdfoptions)
        rounds, kbuf = _get_u32(kbuf)
        _check_empty(kbuf)
        ciph = _init_cipher(
            ciphername, password, salt.tobytes(), rounds, backend
        )
        edata = memoryview(ciph.decryptor().update(edata))
    else:
        blklen = 8
        _check_block_size(edata, blklen)
    # Two copies of the same check value; they differ after decryption
    # with the wrong key.
    ck1, edata = _get_u32(edata)
    ck2, edata = _get_u32(edata)
    if ck1 != ck2:
        raise ValueError("Corrupt data: broken checksum")

    # load per-key struct
    key_type, edata = _get_sshstr(edata)
    if key_type != pub_key_type:
        raise ValueError("Corrupt data: key type mismatch")
    private_key, edata = kformat.load_private(edata, pubfields, backend)
    comment, edata = _get_sshstr(edata)

    # yes, SSH does padding check *after* all other parsing is done.
    # need to follow as it writes zero-byte padding too.
    if edata != _PADDING[: len(edata)]:
        raise ValueError("Corrupt data: invalid padding")

    return private_key
540
541
def serialize_ssh_private_key(private_key, password=None):
    """Serialize private key with OpenSSH custom encoding.

    Produces an "openssh-key-v1" blob wrapped in
    ``-----BEGIN/END OPENSSH PRIVATE KEY-----`` armor.  When *password* is
    given, the private section is encrypted with aes256-ctr using a key
    derived by bcrypt's kdf (random 16-byte salt, 16 rounds); otherwise the
    cipher and KDF are both "none".

    :param private_key: EC (secp256r1/secp384r1/secp521r1), RSA, DSA
        (1024-bit, enforced by the DSA handler) or Ed25519 private key.
    :param password: optional bytes, checked with utils._check_bytes; at
        most 72 bytes.
    :return: PEM-armored bytes.
    :raises ValueError: for unsupported key types or curves, or a password
        longer than 72 bytes.
    """
    if password is not None:
        utils._check_bytes("password", password)
    if password and len(password) > _MAX_PASSWORD:
        raise ValueError(
            "Passwords longer than 72 bytes are not supported by "
            "OpenSSH private key format"
        )

    # Pick the SSH key type name; the format handler is looked up from it.
    if isinstance(private_key, ec.EllipticCurvePrivateKey):
        key_type = _ecdsa_key_type(private_key.public_key())
    elif isinstance(private_key, rsa.RSAPrivateKey):
        key_type = _SSH_RSA
    elif isinstance(private_key, dsa.DSAPrivateKey):
        key_type = _SSH_DSA
    elif isinstance(private_key, ed25519.Ed25519PrivateKey):
        key_type = _SSH_ED25519
    else:
        raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    # setup parameters: with a password, encrypt with the default cipher
    # and a bcrypt-derived key; kdfoptions then holds (salt, rounds).
    # Without one, cipher and KDF are "none" and kdfoptions stays empty.
    f_kdfoptions = _FragList()
    if password:
        ciphername = _DEFAULT_CIPHER
        blklen = _SSH_CIPHERS[ciphername][3]
        kdfname = _BCRYPT
        rounds = _DEFAULT_ROUNDS
        salt = os.urandom(16)
        f_kdfoptions.put_sshstr(salt)
        f_kdfoptions.put_u32(rounds)
        backend = _get_backend(None)
        ciph = _init_cipher(ciphername, password, salt, rounds, backend)
    else:
        ciphername = kdfname = _NONE
        blklen = 8
        ciph = None
    nkeys = 1
    # Random check value stored twice; load_ssh_private_key compares the
    # two copies after decryption to detect a wrong key or corrupt data.
    checkval = os.urandom(4)
    comment = b""

    # encode public and private parts together
    f_public_key = _FragList()
    f_public_key.put_sshstr(key_type)
    kformat.encode_public(private_key.public_key(), f_public_key)

    f_secrets = _FragList([checkval, checkval])
    f_secrets.put_sshstr(key_type)
    kformat.encode_private(private_key, f_secrets)
    f_secrets.put_sshstr(comment)
    # Pad with bytes 1, 2, 3, ... up to a multiple of blklen; an already
    # aligned section still receives a full block of padding.
    f_secrets.put_raw(_PADDING[: blklen - (f_secrets.size() % blklen)])

    # top-level structure
    f_main = _FragList()
    f_main.put_raw(_SK_MAGIC)
    f_main.put_sshstr(ciphername)
    f_main.put_sshstr(kdfname)
    f_main.put_sshstr(f_kdfoptions)
    f_main.put_u32(nkeys)
    f_main.put_sshstr(f_public_key)
    f_main.put_sshstr(f_secrets)

    # copy result into a bytearray.  The extra blklen bytes are slack for
    # update_into(), whose output buffer must be larger than its input.
    slen = f_secrets.size()
    mlen = f_main.size()
    buf = memoryview(bytearray(mlen + blklen))
    f_main.render(buf)
    # f_secrets is the last item of f_main, so it occupies buf[ofs:mlen].
    ofs = mlen - slen

    # encrypt in-place: only the secrets section; the header and public
    # key stay in the clear.
    if ciph is not None:
        ciph.encryptor().update_into(buf[ofs:mlen], buf[ofs:])

    txt = _ssh_pem_encode(buf[:mlen])
    # Zero the secrets region of the scratch buffer once it is encoded.
    buf[ofs:mlen] = bytearray(slen)
    return txt
619
620
def load_ssh_public_key(data, backend=None):
    """Parse an OpenSSH one-line public key ``<key-type> <base64> [...]``.

    Certificate types (``<key-type>-cert-v01@openssh.com``) are accepted.
    Their certificate fields are parsed only to check the blob is
    well-formed, and the embedded public key is returned.

    Raises ValueError for malformed input and UnsupportedAlgorithm for an
    unknown key type.
    """
    backend = _get_backend(backend)
    utils._check_byteslike("data", data)

    match = _SSH_PUBKEY_RC.match(data)
    if match is None:
        raise ValueError("Invalid line format")
    orig_key_type, key_body = match.group(1, 2)

    suffix_len = len(_CERT_SUFFIX)
    with_cert = orig_key_type[-suffix_len:] == _CERT_SUFFIX
    key_type = orig_key_type[:-suffix_len] if with_cert else orig_key_type
    kformat = _lookup_kformat(key_type)

    try:
        blob = memoryview(binascii.a2b_base64(key_body))
    except (TypeError, binascii.Error):
        raise ValueError("Invalid key format")

    inner_key_type, blob = _get_sshstr(blob)
    if inner_key_type != orig_key_type:
        raise ValueError("Invalid key format")
    if with_cert:
        _nonce, blob = _get_sshstr(blob)
    public_key, blob = kformat.load_public(key_type, blob, backend)
    if with_cert:
        # Certificate fields that follow the key, in wire order.
        for parse_field in (
            _get_u64,  # serial
            _get_u32,  # certificate type
            _get_sshstr,  # key id
            _get_sshstr,  # valid principals
            _get_u64,  # valid after
            _get_u64,  # valid before
            _get_sshstr,  # critical options
            _get_sshstr,  # extensions
            _get_sshstr,  # reserved
            _get_sshstr,  # signature key
            _get_sshstr,  # signature
        ):
            _field, blob = parse_field(blob)
    _check_empty(blob)
    return public_key
662
663
def serialize_ssh_public_key(public_key):
    """Return *public_key* as a one-line ``<key-type> <base64>`` string.

    Raises ValueError for unsupported key types or EC curves.
    """
    if isinstance(public_key, ec.EllipticCurvePublicKey):
        key_type = _ecdsa_key_type(public_key)
    else:
        key_type = None
        for key_cls, ssh_name in (
            (rsa.RSAPublicKey, _SSH_RSA),
            (dsa.DSAPublicKey, _SSH_DSA),
            (ed25519.Ed25519PublicKey, _SSH_ED25519),
        ):
            if isinstance(public_key, key_cls):
                key_type = ssh_name
                break
        if key_type is None:
            raise ValueError("Unsupported key type")
    kformat = _lookup_kformat(key_type)

    blob = _FragList()
    blob.put_sshstr(key_type)
    kformat.encode_public(public_key, blob)

    encoded = binascii.b2a_base64(blob.tobytes()).strip()
    return key_type + b" " + encoded
684