1""" fields.py
2"""
3from __future__ import absolute_import, division
4
import abc
import binascii
import collections
import copy
import hashlib
import itertools
import math
import os

try:
    from collections.abc import MutableMapping
except ImportError:  # pragma: no cover
    # Python 2 only: the container ABCs still live directly in ``collections``
    from collections import MutableMapping
13
14from pyasn1.codec.der import decoder
15from pyasn1.codec.der import encoder
16from pyasn1.type.univ import Integer
17from pyasn1.type.univ import Sequence
18from pyasn1.type.namedtype import NamedTypes, NamedType
19
20from cryptography.exceptions import InvalidSignature
21
22from cryptography.hazmat.backends import default_backend
23
24from cryptography.hazmat.primitives import hashes
25from cryptography.hazmat.primitives import serialization
26
27from cryptography.hazmat.primitives.asymmetric import rsa
28from cryptography.hazmat.primitives.asymmetric import dsa
29from cryptography.hazmat.primitives.asymmetric import ec
30from cryptography.hazmat.primitives.asymmetric import ed25519
31from cryptography.hazmat.primitives.asymmetric import padding
32from cryptography.hazmat.primitives.asymmetric import x25519
33
34from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
35
36from cryptography.hazmat.primitives.keywrap import aes_key_wrap
37from cryptography.hazmat.primitives.keywrap import aes_key_unwrap
38
39from cryptography.hazmat.primitives.padding import PKCS7
40
41from .subpackets import Signature as SignatureSP
42from .subpackets import UserAttribute
43from .subpackets import signature
44from .subpackets import userattribute
45
46from .types import MPI
47from .types import MPIs
48
49from ..constants import EllipticCurveOID
50from ..constants import ECPointFormat
51from ..constants import HashAlgorithm
52from ..constants import PubKeyAlgorithm
53from ..constants import String2KeyType
54from ..constants import S2KGNUExtension
55from ..constants import SymmetricKeyAlgorithm
56
57from ..decorators import sdproperty
58
59from ..errors import PGPDecryptionError
60from ..errors import PGPError
61from ..errors import PGPIncompatibleECPointFormat
62
63from ..symenc import _decrypt
64from ..symenc import _encrypt
65
66from ..types import Field
67
__all__ = [
    # subpacket containers
    'SubPackets',
    'UserAttributeSubPackets',
    # signature fields
    'Signature',
    'OpaqueSignature',
    'RSASignature',
    'DSASignature',
    'ECDSASignature',
    'EdDSASignature',
    # public key fields
    'PubKey',
    'OpaquePubKey',
    'RSAPub',
    'DSAPub',
    'ElGPub',
    'ECPoint',
    'ECDSAPub',
    'EdDSAPub',
    'ECDHPub',
    # S2K and KDF parameters
    'String2Key',
    'ECKDF',
    # private key fields
    'PrivKey',
    'OpaquePrivKey',
    'RSAPriv',
    'DSAPriv',
    'ElGPriv',
    'ECDSAPriv',
    'EdDSAPriv',
    'ECDHPriv',
    # ciphertext fields
    'CipherText',
    'RSACipherText',
    'ElGCipherText',
    'ECDHCipherText',
]
99
100
class SubPackets(MutableMapping, Field):
    """Hashed and unhashed subpacket areas of a signature packet.

    Subpackets are kept in two ordered dicts keyed by ``(classname, seqid)``,
    where ``seqid`` counts up from 0 per class name so that several subpackets
    of the same type can coexist.

    Key conventions used by item access and membership tests:

    * ``'Name'``       -- subpackets of class ``Name`` in either area
    * ``'h_Name'``     -- subpackets of class ``Name`` in the hashed area only
    * ``(name, seq)``  -- one exact entry, looked up in both areas

    Iterating a ``SubPackets`` yields the subpacket objects themselves
    (hashed area first, then unhashed), not their keys.
    """
    # module in which ``addnew`` looks up subpacket classes by name
    _spmodule = signature

    def __init__(self):
        super(SubPackets, self).__init__()
        self._hashed_sp = collections.OrderedDict()
        self._unhashed_sp = collections.OrderedDict()

    def __bytearray__(self):
        """Serialize the hashed area followed by the unhashed area."""
        _bytes = bytearray()
        _bytes += self.__hashbytearray__()
        _bytes += self.__unhashbytearray__()
        return _bytes

    def __hashbytearray__(self):
        """Serialize the hashed area: a 2-octet total length, then each subpacket."""
        _bytes = bytearray()
        _bytes += self.int_to_bytes(sum(len(sp) for sp in self._hashed_sp.values()), 2)
        for hsp in self._hashed_sp.values():
            _bytes += hsp.__bytearray__()
        return _bytes

    def __unhashbytearray__(self):
        """Serialize the unhashed area: a 2-octet total length, then each subpacket."""
        _bytes = bytearray()
        _bytes += self.int_to_bytes(sum(len(sp) for sp in self._unhashed_sp.values()), 2)
        for uhsp in self._unhashed_sp.values():
            _bytes += uhsp.__bytearray__()
        return _bytes

    def __len__(self):  # pragma: no cover
        # the +4 accounts for the two 2-octet area length fields
        # NOTE(review): this sums ``header.length`` while the serializers above sum
        # ``len(sp)``; confirm both agree on the size of a serialized subpacket
        return sum(sp.header.length for sp in itertools.chain(self._hashed_sp.values(), self._unhashed_sp.values())) + 4

    def __iter__(self):
        # yields subpacket objects (not keys), hashed area first
        for sp in itertools.chain(self._hashed_sp.values(), self._unhashed_sp.values()):
            yield sp

    def __setitem__(self, key, val):
        """Store subpacket *val*.

        *key* is the subpacket's class name, optionally prefixed with ``'h_'``
        to store it in the hashed area.  It is stored under ``(classname, seqid)``
        where ``seqid`` is the first free sequence id (starting at 0, or at the
        id given when *key* is a ``(name, seqid)`` tuple).
        """
        i = 0
        if isinstance(key, tuple):  # pragma: no cover
            key, i = key

        d = self._unhashed_sp
        if key.startswith('h_'):
            d, key = self._hashed_sp, key[2:]

        # find the next free sequence id for this class name
        while (key, i) in d:
            i += 1

        d[(key, i)] = val

    def __getitem__(self, key):
        """Look up subpackets.

        A ``(name, seqid)`` tuple returns that single subpacket (hashed area
        preferred) or None.  A string returns a list of all matching subpackets,
        restricted to the hashed area when prefixed with ``'h_'``.
        """
        if isinstance(key, tuple):  # pragma: no cover
            return self._hashed_sp.get(key, self._unhashed_sp.get(key))

        if key.startswith('h_'):
            return [v for k, v in self._hashed_sp.items() if key[2:] == k[0]]

        else:
            return [v for k, v in itertools.chain(self._hashed_sp.items(), self._unhashed_sp.items()) if key == k[0]]

    def __delitem__(self, key):
        ##TODO: this
        raise NotImplementedError

    def __contains__(self, key):
        """Membership test using the same key conventions as ``__getitem__``."""
        if isinstance(key, tuple):
            return key in self._hashed_sp or key in self._unhashed_sp

        if isinstance(key, str) and key.startswith('h_'):
            # 'h_Name' only matches subpackets stored in the hashed area
            return any(name == key[2:] for name, _ in self._hashed_sp)

        return any(name == key for name, _ in itertools.chain(self._hashed_sp, self._unhashed_sp))

    def __copy__(self):
        """Return a copy with fresh area dicts (the subpacket objects are shared).

        The copy is created with ``self.__class__`` so subclasses such as
        ``UserAttributeSubPackets`` keep their own serialization and ``_spmodule``.
        """
        sp = self.__class__()
        sp._hashed_sp = self._hashed_sp.copy()
        sp._unhashed_sp = self._unhashed_sp.copy()

        return sp

    def addnew(self, spname, hashed=False, **kwargs):
        """Create a new ``spname`` subpacket from ``_spmodule`` and store it.

        Keyword arguments naming existing attributes of the new subpacket are
        assigned to it; others are silently ignored.  With ``hashed=True`` it is
        stored in the hashed area.
        """
        nsp = getattr(self._spmodule, spname)()
        for p, v in kwargs.items():
            if hasattr(nsp, p):
                setattr(nsp, p, v)
        nsp.update_hlen()
        if hashed:
            self['h_' + spname] = nsp

        else:
            self[spname] = nsp

    def update_hlen(self):
        """Recompute the header length of every stored subpacket."""
        for sp in self:
            sp.update_hlen()

    def parse(self, packet):
        """Parse both subpacket areas from the front of *packet*, consuming it.

        Each area starts with a 2-octet length; subpackets are parsed until at
        least that many octets have been consumed.
        """
        hl = self.bytes_to_int(packet[:2])
        del packet[:2]

        # we do it this way because we can't ensure that subpacket headers are sized appropriately
        # for their contents, but we can at least output that correctly
        # so instead of tracking how many bytes we can now output, we track how many bytes have we parsed so far
        plen = len(packet)
        while plen - len(packet) < hl:
            sp = SignatureSP(packet)
            self['h_' + sp.__class__.__name__] = sp

        uhl = self.bytes_to_int(packet[:2])
        del packet[:2]

        plen = len(packet)
        while plen - len(packet) < uhl:
            sp = SignatureSP(packet)
            self[sp.__class__.__name__] = sp
216
217
class UserAttributeSubPackets(SubPackets):
    """
    Subpackets of a User Attribute packet.

    The layout matches the unhashed area of a signature's subpackets, except
    that there is no length prefix.  ``parse`` therefore reads exactly one
    subpacket and stores it in ``self._unhashed_sp``.  ``__bytearray__`` writes
    the stored subpackets back to back.
    """
    _spmodule = userattribute

    def __bytearray__(self):
        # no length field: just the concatenated subpackets
        return bytearray().join(sub.__bytearray__() for sub in self._unhashed_sp.values())

    def __len__(self):  # pragma: no cover
        return sum(map(len, self._unhashed_sp.values()))

    def parse(self, packet):
        # Only a single subpacket is parsed. The original author notes never
        # having seen a User Attribute packet with more than one, as only one
        # subpacket type is defined.
        subpacket = UserAttribute(packet)
        self[subpacket.__class__.__name__] = subpacket
241
242
class Signature(MPIs):
    """Base class for the algorithm-specific fields of a signature packet.

    Concrete subclasses name their MPI components in ``__mpis__``.  Every
    component starts out as ``MPI(0)`` until it is parsed from a packet or
    filled in from a signer's output via ``from_signer``.
    """

    def __init__(self):
        # same cooperative-init pattern as PubKey, which shares the MPIs base
        super(Signature, self).__init__()
        for i in self.__mpis__:
            setattr(self, i, MPI(0))

    def __bytearray__(self):
        """Serialize each component yielded by iterating ``self`` in MPI wire format."""
        _bytes = bytearray()
        for i in self:
            _bytes += i.to_mpibytes()
        return _bytes

    # Declared as a method, not a property: every subclass implements ``__sig__``
    # as a plain method, and callers invoke it.
    @abc.abstractmethod
    def __sig__(self):
        """Return the signature bytes in the format expected by the signature verifier."""

    @abc.abstractmethod
    def from_signer(self, sig):
        """Populate this instance's fields from the raw signature *sig* produced by a signer."""
261
262
class OpaqueSignature(Signature):
    """Signature fields for an algorithm without specific support.

    The raw octets are kept verbatim in ``data``.  They are written back
    unchanged and handed to the verifier as-is.
    """

    def __init__(self):
        super(OpaqueSignature, self).__init__()
        self.data = bytearray()

    def __bytearray__(self):
        return self.data

    def __len__(self):
        # The serialized form is exactly ``data``. The inherited length is
        # presumably derived from the MPI components named in ``__mpis__``, and
        # this class declares none (TODO confirm in MPIs).
        return len(self.data)

    def __sig__(self):
        return self.data

    def parse(self, packet):
        # NOTE(review): keeps a reference to the remaining buffer without consuming it
        self.data = packet

    def from_signer(self, sig):
        self.data = bytearray(sig)
279
280
class RSASignature(Signature):
    """RSA signature field: the single MPI ``md_mod_n``."""
    __mpis__ = ('md_mod_n', )

    def __sig__(self):
        # the verifier wants the bare value: drop the 2-octet MPI length prefix
        encoded = self.md_mod_n.to_mpibytes()
        return encoded[2:]

    def parse(self, packet):
        self.md_mod_n = MPI(packet)

    def from_signer(self, sig):
        value = self.bytes_to_int(sig)
        self.md_mod_n = MPI(value)
292
293
class DSASignature(Signature):
    """DSA signature fields: the MPIs ``r`` and ``s``."""
    __mpis__ = ('r', 's')

    def __sig__(self):
        """Return ``r`` and ``s`` DER-encoded as an ASN.1 SEQUENCE of two INTEGERs."""
        # encode the signature data as an ASN.1 sequence of integers in DER format
        seq = Sequence(componentType=NamedTypes(*[NamedType(n, Integer()) for n in self.__mpis__]))
        for n in self.__mpis__:
            seq.setComponentByName(n, getattr(self, n))

        return encoder.encode(seq)

    def from_signer(self, sig):
        """Populate ``r`` and ``s`` from a DER-encoded SEQUENCE of two INTEGERs.

        *sig* may be ``bytes`` or a ``bytearray`` and is never modified.

        Raises NotImplementedError if *sig* does not start with a SEQUENCE tag,
        and ValueError if an element is not an INTEGER.
        """
        ##TODO: just use pyasn1 for this
        def _der_intf(_asn):
            # consume one DER INTEGER (tag, length, content) from the front of _asn
            if _asn[0] != 0x02:  # pragma: no cover
                raise ValueError("Expected: Integer (0x02). Got: 0x{:02X}".format(_asn[0]))
            del _asn[0]

            if _asn[0] & 0x80:  # pragma: no cover
                # long-form length: low 7 bits give the number of length octets
                llen = _asn[0] & 0x7F
                del _asn[0]

                flen = self.bytes_to_int(_asn[:llen])
                del _asn[:llen]

            else:
                flen = _asn[0] & 0x7F
                del _asn[0]

            i = self.bytes_to_int(_asn[:flen])
            del _asn[:flen]
            return i

        # Always decode from a private copy. The parser below consumes its input
        # with ``del``, which would otherwise destroy a caller-supplied bytearray.
        sig = bytearray(sig)

        # this is a very limited asn1 decoder - it is only intended to decode a DER encoded sequence of integers
        if not sig[0] == 0x30:
            raise NotImplementedError("Expected: Sequence (0x30). Got: 0x{:02X}".format(sig[0]))
        del sig[0]

        # skip the sequence length field
        if sig[0] & 0x80:  # pragma: no cover
            llen = sig[0] & 0x7F
            del sig[:llen + 1]

        else:
            del sig[0]

        self.r = MPI(_der_intf(sig))
        self.s = MPI(_der_intf(sig))

    def parse(self, packet):
        self.r = MPI(packet)
        self.s = MPI(packet)
349
350
class ECDSASignature(DSASignature):
    """ECDSA signature fields; same ``r``/``s`` layout as DSA."""

    def from_signer(self, sig):
        # pyasn1 returns (decoded value, remaining substrate); the remainder is ignored
        decoded = decoder.decode(sig)[0]
        r_val, s_val = decoded[0], decoded[1]
        self.r = MPI(r_val)
        self.s = MPI(s_val)
356
357
class EdDSASignature(DSASignature):
    """EdDSA signature fields: the signer output is ``r`` and ``s`` concatenated."""

    def __sig__(self):
        # TODO: change this length when EdDSA can be used with another curve (Ed448)
        width = (EllipticCurveOID.Ed25519.key_size + 7) // 8
        r_bytes = self.int_to_bytes(self.r, width)
        s_bytes = self.int_to_bytes(self.s, width)
        return r_bytes + s_bytes

    def from_signer(self, sig):
        # the two halves are equal-length big-endian integers
        half, odd = divmod(len(sig), 2)
        if odd:
            raise PGPError("malformed EdDSA signature")
        self.r = MPI(self.bytes_to_int(sig[:half]))
        self.s = MPI(self.bytes_to_int(sig[half:]))
371
372
class PubKey(MPIs):
    """Base class for the public algorithm-specific fields of a key packet.

    Subclasses list their field names in ``__pubfields__``.  Each field is
    initialized to ``MPI(0)``, unless the entry is a ``(name, default)`` tuple.
    """
    __pubfields__ = ()

    @property
    def __mpis__(self):
        # the MPI components of a public key are exactly its public fields
        return iter(self.__pubfields__)

    def __init__(self):
        super(PubKey, self).__init__()
        for entry in self.__pubfields__:
            name, default = entry if isinstance(entry, tuple) else (entry, MPI(0))  # pragma: no cover
            setattr(self, name, default)

    @abc.abstractmethod
    def __pubkey__(self):
        """Return the matching public key object from the cryptography library."""

    def __len__(self):
        total = 0
        for name in self.__pubfields__:
            total += len(getattr(self, name))
        return total

    def __bytearray__(self):
        out = bytearray()
        for name in self.__pubfields__:
            out += getattr(self, name).to_mpibytes()
        return out

    def publen(self):
        """Length of the public fields (same as ``len(self)``)."""
        return len(self)

    def verify(self, subj, sigbytes, hash_alg):
        # overridden by algorithms that support verification
        return NotImplemented  # pragma: no cover
409
410
class OpaquePubKey(PubKey):  # pragma: no cover
    """Public key material for an algorithm without specific support.

    The raw octets are kept verbatim in ``data`` and written back unchanged.
    """

    def __init__(self):
        super(OpaquePubKey, self).__init__()
        self.data = bytearray()

    def __iter__(self):
        yield self.data

    def __len__(self):
        # PubKey.__len__ sums over __pubfields__, which is empty here and would
        # report 0 (and so publen() == 0). The serialized form is exactly ``data``.
        return len(self.data)

    def __pubkey__(self):
        return NotImplemented

    def __bytearray__(self):
        return self.data

    def parse(self, packet):
        ##TODO: this needs to be length-bounded to the end of the packet
        self.data = packet
428
429
class RSAPub(PubKey):
    """RSA public key fields: modulus ``n`` and public exponent ``e``."""
    __pubfields__ = ('n', 'e')

    def __pubkey__(self):
        numbers = rsa.RSAPublicNumbers(self.e, self.n)
        return numbers.public_key(default_backend())

    def verify(self, subj, sigbytes, hash_alg):
        """Return True if *sigbytes* is a valid PKCS#1 v1.5 signature over *subj*."""
        # left-pad the signature with zero octets up to the modulus length
        missing = self.n.byte_length() - len(sigbytes)
        padded = (b'\x00' * missing) + sigbytes
        try:
            self.__pubkey__().verify(padded, subj, padding.PKCS1v15(), hash_alg)
        except InvalidSignature:
            return False
        else:
            return True

    def parse(self, packet):
        for name in ('n', 'e'):
            setattr(self, name, MPI(packet))
448
449
class DSAPub(PubKey):
    """DSA public key fields: domain parameters ``p``, ``q``, ``g`` and public value ``y``."""
    __pubfields__ = ('p', 'q', 'g', 'y')

    def __pubkey__(self):
        domain = dsa.DSAParameterNumbers(self.p, self.q, self.g)
        numbers = dsa.DSAPublicNumbers(self.y, domain)
        return numbers.public_key(default_backend())

    def verify(self, subj, sigbytes, hash_alg):
        """Return True if *sigbytes* verifies over *subj* with *hash_alg*."""
        key = self.__pubkey__()
        try:
            key.verify(sigbytes, subj, hash_alg)
        except InvalidSignature:
            return False
        else:
            return True

    def parse(self, packet):
        for name in ('p', 'q', 'g', 'y'):
            setattr(self, name, MPI(packet))
469
470
class ElGPub(PubKey):
    """ElGamal public key fields ``p``, ``g`` and ``y`` (no cryptography-library key object)."""
    __pubfields__ = ('p', 'g', 'y')

    def __pubkey__(self):
        raise NotImplementedError()

    def parse(self, packet):
        for name in ('p', 'g', 'y'):
            setattr(self, name, MPI(packet))
481
482
class ECPoint:
    """Elliptic-curve point carried in key material as a single MPI.

    The first octet of the MPI body selects the ``ECPointFormat``:

    * ``Standard`` -- X and Y follow, each ``bytelen`` octets; ``x``/``y`` are MPIs.
    * ``Native``   -- the point's raw octets follow; ``x`` holds them as ``bytes``
      and ``y`` is ``None``.

    Any other format raises ``NotImplementedError``.
    """

    def __init__(self, packet=None):
        """Parse a point from the front of *packet*.

        With no packet, build an empty instance to be filled by ``from_values``
        or ``__copy__``.
        """
        if packet is not None:
            self._load(packet)

    def _load(self, packet):
        # decode the MPI, then drop its 2-octet length prefix to get the raw body
        body = bytearray(MPI(packet).to_mpibytes()[2:])
        self.format = ECPointFormat(body[0])
        coords = body[1:]

        if self.format == ECPointFormat.Standard:
            half, odd = divmod(len(coords), 2)
            if odd:
                raise PGPError("malformed EC point")
            self.bytelen = half
            self.x = MPI(MPIs.bytes_to_int(coords[:half]))
            self.y = MPI(MPIs.bytes_to_int(coords[half:]))
        elif self.format == ECPointFormat.Native:
            # bytelen is unused for Native points; set so __copy__ has something to copy
            self.bytelen = 0
            self.x = bytes(coords)
            self.y = None
        else:
            raise NotImplementedError("No curve is supposed to use only X or Y coordinates")

    @classmethod
    def from_values(cls, bitlen, pform, x, y=None):
        """Build a point from its components; *bitlen* is rounded up to whole octets for ``bytelen``."""
        point = cls()
        point.format = pform
        point.bytelen = (bitlen + 7) // 8
        point.x, point.y = x, y
        return point

    def __len__(self):
        """Length in MPI wire format: 2-octet length prefix + format octet + body."""
        if self.format == ECPointFormat.Standard:
            body_len = 2 * self.bytelen
        elif self.format == ECPointFormat.Native:
            body_len = len(self.x)
        else:
            raise NotImplementedError("No curve is supposed to use only X or Y coordinates")
        return body_len + 3

    def to_mpibytes(self):
        """Return the point MPI-encoded, exactly as written into a packet."""
        raw = bytearray()
        raw.append(self.format)
        if self.format == ECPointFormat.Native:
            raw += self.x
        elif self.format == ECPointFormat.Standard:
            raw += MPIs.int_to_bytes(self.x, self.bytelen)
            raw += MPIs.int_to_bytes(self.y, self.bytelen)
        else:
            raise NotImplementedError("No curve is supposed to use only X or Y coordinates")
        return MPI(MPIs.bytes_to_int(raw)).to_mpibytes()

    def __bytearray__(self):
        return self.to_mpibytes()

    def __copy__(self):
        dup = self.__class__()
        dup.bytelen = self.bytelen
        dup.format = self.format
        dup.x, dup.y = copy.copy(self.x), copy.copy(self.y)
        return dup
545
546
class ECDSAPub(PubKey):
    """ECDSA public key fields: a curve OID and a Standard-format point ``p``."""
    __pubfields__ = ('p',)

    def __init__(self):
        super(ECDSAPub, self).__init__()
        self.oid = None

    def __len__(self):
        # the OID is serialized without its DER tag octet, hence the -1
        oid_len = len(encoder.encode(self.oid.value)) - 1
        return len(self.p) + oid_len

    def __pubkey__(self):
        numbers = ec.EllipticCurvePublicNumbers(self.p.x, self.p.y, self.oid.curve())
        return numbers.public_key(default_backend())

    def __bytearray__(self):
        # curve OID without its DER tag octet, followed by the point MPI
        out = bytearray(encoder.encode(self.oid.value)[1:])
        out += self.p.to_mpibytes()
        return out

    def __copy__(self):
        dup = super(ECDSAPub, self).__copy__()
        dup.oid = self.oid
        return dup

    def verify(self, subj, sigbytes, hash_alg):
        """Return True if the DER-encoded ECDSA signature *sigbytes* verifies over *subj*."""
        key = self.__pubkey__()
        try:
            key.verify(sigbytes, subj, ec.ECDSA(hash_alg))
        except InvalidSignature:
            return False
        else:
            return True

    def parse(self, packet):
        # Curve OID: a length octet, then the OID body (a DER OBJECT IDENTIFIER
        # minus its 0x06 tag). Re-attach the tag so pyasn1 can decode it.
        oid_len = packet[0]
        del packet[0]
        der_oid = bytearray([0x06, oid_len]) + bytearray(packet[:oid_len])
        decoded_oid = decoder.decode(bytes(der_oid))[0]
        self.oid = EllipticCurveOID(decoded_oid)
        del packet[:oid_len]

        self.p = ECPoint(packet)
        if self.p.format != ECPointFormat.Standard:
            raise PGPIncompatibleECPointFormat("Only Standard format is valid for ECDSA")
591
592
class EdDSAPub(PubKey):
    """EdDSA public key fields: a curve OID and a Native-format point ``p``."""
    __pubfields__ = ('p', )

    def __init__(self):
        super(EdDSAPub, self).__init__()
        self.oid = None

    def __len__(self):
        # the OID is serialized without its DER tag octet, hence the -1
        oid_len = len(encoder.encode(self.oid.value)) - 1
        return len(self.p) + oid_len

    def __bytearray__(self):
        out = bytearray(encoder.encode(self.oid.value)[1:])
        out += self.p.to_mpibytes()
        return out

    def __pubkey__(self):
        # for a Native point, p.x holds the raw public key octets
        return ed25519.Ed25519PublicKey.from_public_bytes(self.p.x)

    def __copy__(self):
        dup = super(EdDSAPub, self).__copy__()
        dup.oid = self.oid
        return dup

    def verify(self, subj, sigbytes, hash_alg):
        """Return True if *sigbytes* verifies over the *hash_alg* digest of *subj*."""
        # GnuPG requires a pre-hashing with EdDSA
        # https://tools.ietf.org/html/draft-ietf-openpgp-rfc4880bis-06#section-14.8
        hasher = hashes.Hash(hash_alg, backend=default_backend())
        hasher.update(subj)
        prehashed = hasher.finalize()
        try:
            self.__pubkey__().verify(sigbytes, prehashed)
        except InvalidSignature:
            return False
        else:
            return True

    def parse(self, packet):
        # curve OID: a length octet + OID body; re-attach the 0x06 tag for pyasn1
        oid_len = packet[0]
        del packet[0]
        der_oid = bytearray([0x06, oid_len]) + bytearray(packet[:oid_len])
        decoded_oid = decoder.decode(bytes(der_oid))[0]
        self.oid = EllipticCurveOID(decoded_oid)
        del packet[:oid_len]

        self.p = ECPoint(packet)
        if self.p.format != ECPointFormat.Native:
            raise PGPIncompatibleECPointFormat("Only Native format is valid for EdDSA")
642
643
class ECDHPub(PubKey):
    """ECDH public key fields: curve OID, point ``p`` and KDF parameters ``kdf``."""
    __pubfields__ = ('p',)

    def __init__(self):
        super(ECDHPub, self).__init__()
        self.oid = None
        self.kdf = ECKDF()

    def __len__(self):
        # the OID is serialized without its DER tag octet, hence the -1
        oid_len = len(encoder.encode(self.oid.value)) - 1
        return len(self.p) + len(self.kdf) + oid_len

    def __pubkey__(self):
        if self.oid != EllipticCurveOID.Curve25519:
            numbers = ec.EllipticCurvePublicNumbers(self.p.x, self.p.y, self.oid.curve())
            return numbers.public_key(default_backend())
        # Curve25519 points are Native: p.x is the raw public key
        return x25519.X25519PublicKey.from_public_bytes(self.p.x)

    def __bytearray__(self):
        out = bytearray(encoder.encode(self.oid.value)[1:])
        out += self.p.to_mpibytes()
        out += self.kdf.__bytearray__()
        return out

    def __copy__(self):
        dup = super(ECDHPub, self).__copy__()
        dup.oid = self.oid
        dup.kdf = copy.copy(self.kdf)
        return dup

    def parse(self, packet):
        """
        Algorithm-Specific Fields for ECDH keys (RFC 6637):

          o  a variable-length field containing a curve OID, formatted
             as follows:

             -  a one-octet size of the following field; values 0 and
                0xFF are reserved for future extensions

             -  the octets representing a curve OID, defined in
                Section 11

             -  MPI of an EC point representing a public key

          o  a variable-length field containing KDF parameters,
             formatted as follows:

             -  a one-octet size of the following fields; values 0 and
                0xff are reserved for future extensions

             -  a one-octet value 01, reserved for future extensions

             -  a one-octet hash function ID used with a KDF

             -  a one-octet algorithm ID for the symmetric algorithm
                used to wrap the symmetric key used for the message
                encryption; see Section 8 for details

        Curve25519 keys must use a Native point; all other curves a Standard one.
        """
        # curve OID: re-attach the DER 0x06 tag so pyasn1 can decode it
        oid_len = packet[0]
        del packet[0]
        der_oid = bytearray([0x06, oid_len]) + bytearray(packet[:oid_len])
        decoded_oid = decoder.decode(bytes(der_oid))[0]

        self.oid = EllipticCurveOID(decoded_oid)
        del packet[:oid_len]

        self.p = ECPoint(packet)
        is_x25519 = self.oid == EllipticCurveOID.Curve25519
        if is_x25519 and self.p.format != ECPointFormat.Native:
            raise PGPIncompatibleECPointFormat("Only Native format is valid for Curve25519")
        if not is_x25519 and self.p.format != ECPointFormat.Standard:
            raise PGPIncompatibleECPointFormat("Only Standard format is valid for this curve")
        self.kdf.parse(packet)
720
721
722class String2Key(Field):
723    """
724    3.7.  String-to-Key (S2K) Specifiers
725
726    String-to-key (S2K) specifiers are used to convert passphrase strings
727    into symmetric-key encryption/decryption keys.  They are used in two
728    places, currently: to encrypt the secret part of private keys in the
729    private keyring, and to convert passphrases to encryption keys for
730    symmetrically encrypted messages.
731
732    3.7.1.  String-to-Key (S2K) Specifier Types
733
734    There are three types of S2K specifiers currently supported, and
735    some reserved values:
736
737       ID          S2K Type
738       --          --------
739       0           Simple S2K
740       1           Salted S2K
741       2           Reserved value
742       3           Iterated and Salted S2K
743       100 to 110  Private/Experimental S2K
744
745    These are described in Sections 3.7.1.1 - 3.7.1.3.
746
747    3.7.1.1.  Simple S2K
748
749    This directly hashes the string to produce the key data.  See below
750    for how this hashing is done.
751
752       Octet 0:        0x00
753       Octet 1:        hash algorithm
754
755    Simple S2K hashes the passphrase to produce the session key.  The
756    manner in which this is done depends on the size of the session key
757    (which will depend on the cipher used) and the size of the hash
758    algorithm's output.  If the hash size is greater than the session key
759    size, the high-order (leftmost) octets of the hash are used as the
760    key.
761
762    If the hash size is less than the key size, multiple instances of the
763    hash context are created -- enough to produce the required key data.
764    These instances are preloaded with 0, 1, 2, ... octets of zeros (that
765    is to say, the first instance has no preloading, the second gets
766    preloaded with 1 octet of zero, the third is preloaded with two
767    octets of zeros, and so forth).
768
769    As the data is hashed, it is given independently to each hash
770    context.  Since the contexts have been initialized differently, they
771    will each produce different hash output.  Once the passphrase is
772    hashed, the output data from the multiple hashes is concatenated,
773    first hash leftmost, to produce the key data, with any excess octets
774    on the right discarded.
775
776    3.7.1.2.  Salted S2K
777
778    This includes a "salt" value in the S2K specifier -- some arbitrary
779    data -- that gets hashed along with the passphrase string, to help
780    prevent dictionary attacks.
781
782       Octet 0:        0x01
783       Octet 1:        hash algorithm
784       Octets 2-9:     8-octet salt value
785
786    Salted S2K is exactly like Simple S2K, except that the input to the
787    hash function(s) consists of the 8 octets of salt from the S2K
788    specifier, followed by the passphrase.
789
790    3.7.1.3.  Iterated and Salted S2K
791
792    This includes both a salt and an octet count.  The salt is combined
793    with the passphrase and the resulting value is hashed repeatedly.
794    This further increases the amount of work an attacker must do to try
795    dictionary attacks.
796
797       Octet  0:        0x03
798       Octet  1:        hash algorithm
799       Octets 2-9:      8-octet salt value
800       Octet  10:       count, a one-octet, coded value
801
802    The count is coded into a one-octet number using the following
803    formula:
804
805       #define EXPBIAS 6
806           count = ((Int32)16 + (c & 15)) << ((c >> 4) + EXPBIAS);
807
808    The above formula is in C, where "Int32" is a type for a 32-bit
809    integer, and the variable "c" is the coded count, Octet 10.
810
811    Iterated-Salted S2K hashes the passphrase and salt data multiple
812    times.  The total number of octets to be hashed is specified in the
813    encoded count in the S2K specifier.  Note that the resulting count
814    value is an octet count of how many octets will be hashed, not an
815    iteration count.
816
817    Initially, one or more hash contexts are set up as with the other S2K
818    algorithms, depending on how many octets of key data are needed.
819    Then the salt, followed by the passphrase data, is repeatedly hashed
820    until the number of octets specified by the octet count has been
821    hashed.  The one exception is that if the octet count is less than
822    the size of the salt plus passphrase, the full salt plus passphrase
823    will be hashed even though that is greater than the octet count.
824    After the hashing is done, the data is unloaded from the hash
825    context(s) as with the other S2K algorithms.
826    """
    @sdproperty
    def encalg(self):
        """Symmetric key algorithm, stored as a ``SymmetricKeyAlgorithm``.

        Its ``block_size`` determines how many IV octets ``parse`` reads.
        """
        return self._encalg

    # Setter for ints / SymmetricKeyAlgorithm values; the value is coerced to
    # SymmetricKeyAlgorithm. (Presumably sdproperty dispatches assignments on the
    # value's type; confirm in ..decorators.)
    @encalg.register(int)
    @encalg.register(SymmetricKeyAlgorithm)
    def encalg_int(self, val):
        self._encalg = SymmetricKeyAlgorithm(val)

    @sdproperty
    def specifier(self):
        """S2K specifier type, stored as a ``String2KeyType``.

        Determines which fields (halg, salt, count, GNU extension) are parsed
        and serialized.
        """
        return self._specifier

    # setter: coerce ints / String2KeyType values to String2KeyType
    @specifier.register(int)
    @specifier.register(String2KeyType)
    def specifier_int(self, val):
        self._specifier = String2KeyType(val)

    @sdproperty
    def gnuext(self):
        """GNU S2K extension number, stored as an ``S2KGNUExtension``.

        Only serialized when ``specifier`` is ``GNUExtension``.
        """
        return self._gnuext

    # setter: coerce ints / S2KGNUExtension values to S2KGNUExtension
    @gnuext.register(int)
    @gnuext.register(S2KGNUExtension)
    def gnuext_int(self, val):
        self._gnuext = S2KGNUExtension(val)

    @sdproperty
    def halg(self):
        """Hash algorithm used by the S2K, stored as a ``HashAlgorithm``.

        Serialized for every non-GNU specifier.
        """
        return self._halg

    # setter: coerce ints / HashAlgorithm values to HashAlgorithm
    @halg.register(int)
    @halg.register(HashAlgorithm)
    def halg_int(self, val):
        self._halg = HashAlgorithm(val)
862
863    @sdproperty
864    def count(self):
865        return (16 + (self._count & 15)) << ((self._count >> 4) + 6)
866
867    @count.register(int)
868    def count_int(self, val):
869        if val < 0 or val > 255:  # pragma: no cover
870            raise ValueError("count must be between 0 and 256")
871        self._count = val
872
873    def __init__(self):
874        super(String2Key, self).__init__()
875        self.usage = 0
876        self.encalg = 0
877        self.specifier = 0
878        self.iv = None
879
880        # specifier-specific fields
881        # simple, salted, iterated
882        self.halg = 0
883
884        # salted, iterated
885        self.salt = bytearray()
886
887        # iterated
888        self.count = 0
889
890        # GNU extension default type: ignored if specifier != GNUExtension
891        self.gnuext = 1
892
893        # GNU extension smartcard
894        self.scserial = None
895
896    def __bytearray__(self):
897        _bytes = bytearray()
898        _bytes.append(self.usage)
899        if bool(self):
900            _bytes.append(self.encalg)
901            _bytes.append(self.specifier)
902            if self.specifier == String2KeyType.GNUExtension:
903                return self._experimental_bytearray(_bytes)
904            if self.specifier >= String2KeyType.Simple:
905                _bytes.append(self.halg)
906            if self.specifier >= String2KeyType.Salted:
907                _bytes += self.salt
908            if self.specifier == String2KeyType.Iterated:
909                _bytes.append(self._count)
910            if self.iv is not None:
911                _bytes += self.iv
912        return _bytes
913
914    def _experimental_bytearray(self, _bytes):
915        if self.specifier == String2KeyType.GNUExtension:
916            _bytes += b'\x00GNU'
917            _bytes.append(self.gnuext)
918            if self.scserial:
919                _bytes.append(len(self.scserial))
920                _bytes += self.scserial
921        return _bytes
922
923    def __len__(self):
924        return len(self.__bytearray__())
925
926    def __bool__(self):
927        return self.usage in [254, 255]
928
929    def __nonzero__(self):
930        return self.__bool__()
931
932    def __copy__(self):
933        s2k = String2Key()
934        s2k.usage = self.usage
935        s2k.encalg = self.encalg
936        s2k.specifier = self.specifier
937        s2k.gnuext = self.gnuext
938        s2k.iv = self.iv
939        s2k.halg = self.halg
940        s2k.salt = copy.copy(self.salt)
941        s2k.count = self._count
942        s2k.scserial = self.scserial
943        return s2k
944
945    def parse(self, packet, iv=True):
946        self.usage = packet[0]
947        del packet[0]
948
949        if bool(self):
950            self.encalg = packet[0]
951            del packet[0]
952
953            self.specifier = packet[0]
954            del packet[0]
955
956            if self.specifier == String2KeyType.GNUExtension:
957                return self._experimental_parse(packet, iv)
958
959            if self.specifier >= String2KeyType.Simple:
960                # this will always be true
961                self.halg = packet[0]
962                del packet[0]
963
964            if self.specifier >= String2KeyType.Salted:
965                self.salt = packet[:8]
966                del packet[:8]
967
968            if self.specifier == String2KeyType.Iterated:
969                self.count = packet[0]
970                del packet[0]
971
972            if iv:
973                self.iv = packet[:(self.encalg.block_size // 8)]
974                del packet[:(self.encalg.block_size // 8)]
975
976    def _experimental_parse(self, packet, iv=True):
977        """
978        https://git.gnupg.org/cgi-bin/gitweb.cgi?p=gnupg.git;a=blob;f=doc/DETAILS;h=3046523da62c576cf6a765a8b0829876cfdc6b3b;hb=b0f0791e4ade845b2a0e2a94dbda4f3bf1ceb039#l1346
979
980        GNU extensions to the S2K algorithm
981
982        1 octet  - S2K Usage: either 254 or 255.
983        1 octet  - S2K Cipher Algo: 0
984        1 octet  - S2K Specifier: 101
985        4 octets - "\x00GNU"
986        1 octet  - GNU S2K Extension Number.
987
988        If such a GNU extension is used neither an IV nor any kind of
989        checksum is used.  The defined GNU S2K Extension Numbers are:
990
991        - 1 :: Do not store the secret part at all.  No specific data
992               follows.
993
994        - 2 :: A stub to access smartcards.  This data follows:
995               - One octet with the length of the following serial number.
996               - The serial number. Regardless of what the length octet
997                 indicates no more than 16 octets are stored.
998        """
999        if self.specifier == String2KeyType.GNUExtension:
1000            if packet[:4] != b'\x00GNU':
1001                raise PGPError("Invalid S2K GNU extension magic value")
1002            del packet[:4]
1003            self.gnuext = packet[0]
1004            del packet[0]
1005
1006            if self.gnuext == S2KGNUExtension.Smartcard:
1007                slen = min(packet[0], 16)
1008                del packet[0]
1009                self.scserial = packet[:slen]
1010                del packet[:slen]
1011
1012    def derive_key(self, passphrase):
1013        ##TODO: raise an exception if self.usage is not 254 or 255
1014        keylen = self.encalg.key_size
1015        hashlen = self.halg.digest_size * 8
1016
1017        ctx = int(math.ceil((keylen / hashlen)))
1018
1019        # Simple S2K - always done
1020        hsalt = b''
1021        ##TODO: we could accept a passphrase that is optionally already `bytes`
1022        hpass = passphrase.encode('utf-8')
1023
1024        # salted, iterated S2K
1025        if self.specifier >= String2KeyType.Salted:
1026            hsalt = bytes(self.salt)
1027
1028        count = len(hsalt + hpass)
1029        if self.specifier == String2KeyType.Iterated and self.count > len(hsalt + hpass):
1030            count = self.count
1031
1032        hcount = (count // len(hsalt + hpass))
1033        hleft = count - (hcount * len(hsalt + hpass))
1034
1035        hashdata = ((hsalt + hpass) * hcount) + (hsalt + hpass)[:hleft]
1036
1037        h = []
1038        for i in range(0, ctx):
1039            _h = self.halg.hasher
1040            _h.update(b'\x00' * i)
1041            _h.update(hashdata)
1042            h.append(_h)
1043
1044        # GC some stuff
1045        del hsalt
1046        del hpass
1047        del hashdata
1048
1049        # and return the key!
1050        return b''.join(hc.digest() for hc in h)[:(keylen // 8)]
1051
1052
class ECKDF(Field):
    """
    KDF parameters of an ECDH public key (RFC 6637):

    o  a variable-length field containing KDF parameters,
       formatted as follows:

       -  a one-octet size of the following fields; values 0 and
          0xff are reserved for future extensions

       -  a one-octet value 01, reserved for future extensions

       -  a one-octet hash function ID used with a KDF

       -  a one-octet algorithm ID for the symmetric algorithm
          used to wrap the symmetric key used for the message
          encryption; see Section 8 for details
    """
    @sdproperty
    def halg(self):
        return self._halg

    @halg.register(int)
    @halg.register(HashAlgorithm)
    def halg_int(self, val):
        self._halg = HashAlgorithm(val)

    @sdproperty
    def encalg(self):
        return self._encalg

    @encalg.register(int)
    @encalg.register(SymmetricKeyAlgorithm)
    def encalg_int(self, val):
        self._encalg = SymmetricKeyAlgorithm(val)

    def __init__(self):
        super(ECKDF, self).__init__()
        self.halg = 0
        self.encalg = 0

    def __bytearray__(self):
        _bytes = bytearray()
        # size of the three octets that follow
        _bytes.append(len(self) - 1)
        _bytes.append(0x01)
        _bytes.append(self.halg)
        _bytes.append(self.encalg)
        return _bytes

    def __len__(self):
        return 4

    def parse(self, packet):
        """
        Consume the four-octet KDF parameter field from the front of ``packet``.

        :raises PGPError: if the size octet is not 3 or the reserved octet is not 1
        """
        # packet[0] is the size of the following fields (3) and packet[1] is the
        # reserved value 1. This is validated explicitly rather than with
        # ``assert``, which is stripped under ``python -O``, because it is untrusted input.
        if packet[:2] != b'\x03\x01':
            raise PGPError("Unsupported ECDH KDF parameters: expected size 3 and reserved value 1")
        del packet[:2]

        self.halg = packet[0]
        del packet[0]

        self.encalg = packet[0]
        del packet[0]

    def derive_key(self, s, curve, pkalg, fingerprint):
        """
        Derive the key-wrapping key from the ECDH shared secret ``s``.

        This wraps the Concatenation KDF provided by cryptography. It uses
        ``self.halg`` as the hash, ``self.encalg.key_size // 8`` output octets,
        and the RFC 6637 ``Param`` string as the other-info input.

        :param s: the shared secret produced by the key exchange
        :param curve: curve OID; ``curve.value`` is DER-encoded and its leading
                      tag octet dropped to form curve_OID_len || curve_OID
        :param pkalg: public-key algorithm ID
        :param fingerprint: recipient fingerprint as hex text (spaces are ignored)
        :returns: the derived key-wrapping key
        """
        # assemble the additional data as defined in RFC 6637:
        #  Param = curve_OID_len || curve_OID || public_key_alg_ID || 03 || 01 || KDF_hash_ID
        #          || KEK_alg_ID for AESKeyWrap || "Anonymous Sender    " || recipient_fingerprint
        data = bytearray()
        data += encoder.encode(curve.value)[1:]
        data.append(pkalg)
        data += b'\x03\x01'
        data.append(self.halg)
        data.append(self.encalg)
        data += b'Anonymous Sender    '
        data += binascii.unhexlify(fingerprint.replace(' ', ''))

        ckdf = ConcatKDFHash(algorithm=getattr(hashes, self.halg.name)(),
                             length=self.encalg.key_size // 8,
                             otherinfo=bytes(data),
                             backend=default_backend())
        return ckdf.derive(s)
1132
1133
class PrivKey(PubKey):
    """
    Base class for secret key material.

    On top of the public fields handled by ``PubKey``, a secret key carries:

    - ``s2k``: the String-to-Key usage octet and, when the usage octet is
      254 or 255, the S2K specifier protecting the secret fields;
    - the secret MPIs named in ``__privfields__``, meaningful while the key is
      unprotected or after a subclass's ``decrypt_keyblob``;
    - ``encbytes``: the encrypted secret fields when the key is protected;
    - ``chksum``: the plaintext checksum, serialized only when the S2K usage
      octet is 0.
    """
    __privfields__ = ()

    @property
    def __mpis__(self):
        # names of the public MPIs, followed by the names of the secret MPIs
        for i in super(PrivKey, self).__mpis__:
            yield i

        for i in self.__privfields__:
            yield i

    def __init__(self):
        super(PrivKey, self).__init__()

        self.s2k = String2Key()
        self.encbytes = bytearray()
        self.chksum = bytearray()

        # every secret MPI starts out as zero
        for field in self.__privfields__:
            setattr(self, field, MPI(0))

    def __bytearray__(self):
        _bytes = bytearray()
        _bytes += super(PrivKey, self).__bytearray__()

        _bytes += self.s2k.__bytearray__()
        if self.s2k:
            # protected: the secret MPIs (and their hash/checksum) live inside encbytes
            _bytes += self.encbytes

        else:
            for field in self.__privfields__:
                _bytes += getattr(self, field).to_mpibytes()

        if self.s2k.usage == 0:
            _bytes += self.chksum

        return _bytes

    def __len__(self):
        """
        Return the serialized length in octets.

        This mirrors ``__bytearray__``. In particular, ``chksum`` only counts when
        the S2K usage octet is 0, because that is the only case in which it is
        written. ``chksum`` may still be non-empty for other usages, for example
        a checksum left over when ``encrypt_keyblob`` switches the usage to 254.
        """
        l = super(PrivKey, self).__len__() + len(self.s2k)
        if self.s2k:
            l += len(self.encbytes)

        else:
            l += sum(len(getattr(self, i)) for i in self.__privfields__)

        if self.s2k.usage == 0:
            l += len(self.chksum)

        return l

    def __copy__(self):
        pk = super(PrivKey, self).__copy__()
        pk.s2k = copy.copy(self.s2k)
        pk.encbytes = copy.copy(self.encbytes)
        pk.chksum = copy.copy(self.chksum)
        return pk

    @abc.abstractmethod
    def __privkey__(self):
        """return the requisite *PrivateKey class from the cryptography library"""

    @abc.abstractmethod
    def _generate(self, key_size):
        """Generate a new PrivKey"""

    def _compute_chksum(self):
        "Calculate the key checksum"

    def publen(self):
        """Return the length of the public portion only."""
        return super(PrivKey, self).__len__()

    def encrypt_keyblob(self, passphrase, enc_alg, hash_alg):
        """
        Protect the secret MPIs with ``passphrase``.

        Sets up an iterated-and-salted S2K with usage 254 and encrypts the
        serialized secret MPIs followed by their SHA-1 hash into ``encbytes``.
        It then resets the secret MPIs to zero.
        """
        # PGPy will only ever use iterated and salted S2k mode
        self.s2k.usage = 254
        self.s2k.encalg = enc_alg
        self.s2k.specifier = String2KeyType.Iterated
        self.s2k.iv = enc_alg.gen_iv()
        self.s2k.halg = hash_alg
        self.s2k.salt = bytearray(os.urandom(8))
        self.s2k.count = hash_alg.tuned_count

        # now that String-to-Key is ready to go, derive sessionkey from passphrase
        # and then unreference passphrase
        sessionkey = self.s2k.derive_key(passphrase)
        del passphrase

        pt = bytearray()
        for pf in self.__privfields__:
            pt += getattr(self, pf).to_mpibytes()

        # append a SHA-1 hash of the plaintext so far to the plaintext
        pt += hashlib.new('sha1', pt).digest()

        # encrypt
        self.encbytes = bytearray(_encrypt(bytes(pt), bytes(sessionkey), enc_alg, bytes(self.s2k.iv)))

        # delete pt and clear self
        del pt
        self.clear()

    @abc.abstractmethod
    def decrypt_keyblob(self, passphrase):
        """
        Decrypt ``encbytes`` with a key derived from ``passphrase``.

        Returns the plaintext, including its trailing SHA-1 hash (usage 254) or
        two-octet checksum (usage 255), as a bytearray for subclasses to parse.
        Returns None if the key is not protected.

        :raises PGPDecryptionError: if the hash or checksum does not match
        """
        if not self.s2k:  # pragma: no cover
            # not encrypted
            return

        # Encryption/decryption of the secret data is done in CFB mode using
        # the key created from the passphrase and the Initial Vector from the
        # packet.  A different mode is used with V3 keys (which are only RSA)
        # than with other key formats.  (...)
        #
        # With V4 keys, a simpler method is used.  All secret MPI values are
        # encrypted in CFB mode, including the MPI bitcount prefix.

        # derive the session key from our passphrase, and then unreference passphrase
        sessionkey = self.s2k.derive_key(passphrase)
        del passphrase

        # attempt to decrypt this key
        pt = _decrypt(bytes(self.encbytes), bytes(sessionkey), self.s2k.encalg, bytes(self.s2k.iv))

        # check the hash to see if we decrypted successfully or not
        if self.s2k.usage == 254 and not pt[-20:] == hashlib.new('sha1', pt[:-20]).digest():
            # if the usage byte is 254, key material is followed by a 20-octet sha-1 hash of the rest
            # of the key material block
            raise PGPDecryptionError("Passphrase was incorrect!")

        if self.s2k.usage == 255 and not self.bytes_to_int(pt[-2:]) == (sum(bytearray(pt[:-2])) % 65536):  # pragma: no cover
            # if the usage byte is 255, key material is followed by a 2-octet checksum of the rest
            # of the key material block
            raise PGPDecryptionError("Passphrase was incorrect!")

        return bytearray(pt)

    def sign(self, sigdata, hash_alg):
        return NotImplemented  # pragma: no cover

    def clear(self):
        """delete and re-initialize all private components to zero"""
        for field in self.__privfields__:
            delattr(self, field)
            setattr(self, field, MPI(0))
1274
1275
class OpaquePrivKey(PrivKey, OpaquePubKey):  # pragma: no cover
    """Placeholder secret key material: none of the key operations are implemented."""

    def __privkey__(self):
        # no cryptography-library key object can be produced here
        return NotImplemented

    def _generate(self, key_size):
        raise NotImplementedError()

    def decrypt_keyblob(self, passphrase):
        # nothing to decrypt into: report the operation as unsupported
        return NotImplemented
1286
1287
class RSAPriv(PrivKey, RSAPub):
    """RSA secret key material: the MPIs d, p, q and u (u = p^-1 mod q)."""
    __privfields__ = ('d', 'p', 'q', 'u')

    def __privkey__(self):
        # the CRT helper values are recomputed from d, p and q rather than taken from u
        return rsa.RSAPrivateNumbers(self.p, self.q, self.d,
                                     rsa.rsa_crt_dmp1(self.d, self.p),
                                     rsa.rsa_crt_dmq1(self.d, self.q),
                                     rsa.rsa_crt_iqmp(self.p, self.q),
                                     rsa.RSAPublicNumbers(self.e, self.n)).private_key(default_backend())

    def _compute_chksum(self):
        """Store the sum of all secret MPI octets, mod 65536, as a 2-octet checksum."""
        chs = sum(sum(bytearray(c.to_mpibytes())) for c in (self.d, self.p, self.q, self.u)) % 65536
        self.chksum = bytearray(self.int_to_bytes(chs, 2))

    def _generate(self, key_size):
        """Generate a new RSA key of ``key_size`` bits with public exponent 65537."""
        if any(c != 0 for c in self):  # pragma: no cover
            raise PGPError("key is already populated")

        # generate some big numbers!
        pk = rsa.generate_private_key(65537, key_size, default_backend())
        pkn = pk.private_numbers()

        self.n = MPI(pkn.public_numbers.n)
        self.e = MPI(pkn.public_numbers.e)
        self.d = MPI(pkn.d)
        self.p = MPI(pkn.p)
        self.q = MPI(pkn.q)
        # from the RFC:
        # "- MPI of u, the multiplicative inverse of p, mod q."
        # or, simply, p^-1 mod q
        # rsa.rsa_crt_iqmp(p, q) normally computes q^-1 mod p,
        # so if we swap the values around we get the answer we want
        self.u = MPI(rsa.rsa_crt_iqmp(pkn.q, pkn.p))

        del pkn
        del pk

        self._compute_chksum()

    def parse(self, packet):
        """
        Parse RSA secret key material from the front of ``packet``.

        For an unprotected key, the MPIs d, p, q and u are read, followed by a
        2-octet checksum when the usage octet is 0. For a protected key, the
        remainder of ``packet`` becomes ``encbytes``.
        """
        super(RSAPriv, self).parse(packet)
        self.s2k.parse(packet)

        if not self.s2k:
            self.d = MPI(packet)
            self.p = MPI(packet)
            self.q = MPI(packet)
            self.u = MPI(packet)

            if self.s2k.usage == 0:
                self.chksum = packet[:2]
                del packet[:2]

        else:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            self.encbytes = packet

    def decrypt_keyblob(self, passphrase):
        """
        Decrypt the protected secret fields and load d, p, q and u from them.

        For usage 254/255, whatever follows the four MPIs in the plaintext
        (the hash or checksum) is stored in ``chksum``.
        """
        kb = super(RSAPriv, self).decrypt_keyblob(passphrase)
        del passphrase

        self.d = MPI(kb)
        self.p = MPI(kb)
        self.q = MPI(kb)
        self.u = MPI(kb)

        if self.s2k.usage in [254, 255]:
            self.chksum = kb
            del kb

    def sign(self, sigdata, hash_alg):
        """Sign ``sigdata`` with RSA PKCS#1 v1.5 padding."""
        return self.__privkey__().sign(sigdata, padding.PKCS1v15(), hash_alg)
1360
1361
class DSAPriv(PrivKey, DSAPub):
    """DSA secret key material: the secret exponent MPI x."""
    __privfields__ = ('x',)

    def __privkey__(self):
        params = dsa.DSAParameterNumbers(self.p, self.q, self.g)
        pn = dsa.DSAPublicNumbers(self.y, params)
        return dsa.DSAPrivateNumbers(self.x, pn).private_key(default_backend())

    def _compute_chksum(self):
        """Store the sum of the octets of x, mod 65536, as a 2-octet checksum."""
        chs = sum(bytearray(self.x.to_mpibytes())) % 65536
        self.chksum = bytearray(self.int_to_bytes(chs, 2))

    def _generate(self, key_size):
        """Generate a new DSA key of ``key_size`` bits."""
        if any(c != 0 for c in self):  # pragma: no cover
            raise PGPError("key is already populated")

        # generate some big numbers!
        pk = dsa.generate_private_key(key_size, default_backend())
        pkn = pk.private_numbers()

        self.p = MPI(pkn.public_numbers.parameter_numbers.p)
        self.q = MPI(pkn.public_numbers.parameter_numbers.q)
        self.g = MPI(pkn.public_numbers.parameter_numbers.g)
        self.y = MPI(pkn.public_numbers.y)
        self.x = MPI(pkn.x)

        del pkn
        del pk

        self._compute_chksum()

    def parse(self, packet):
        """
        Parse DSA secret key material from the front of ``packet``.

        For an unprotected key, x is read, followed by a 2-octet checksum when
        the usage octet is 0. For a protected key (usage 254/255), the remainder
        of ``packet`` becomes ``encbytes``. The hash or checksum of a protected
        key is inside the encrypted data and is verified by ``decrypt_keyblob``.
        """
        super(DSAPriv, self).parse(packet)
        self.s2k.parse(packet)

        if not self.s2k:
            self.x = MPI(packet)

            if self.s2k.usage == 0:
                self.chksum = packet[:2]
                del packet[:2]

        else:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            # Nothing more may be stripped here: encbytes is the same object as
            # packet, and for usage 255 the checksum is encrypted along with x.
            self.encbytes = packet

    def decrypt_keyblob(self, passphrase):
        """
        Decrypt the protected secret field and load x from it.

        For usage 254/255, whatever follows x in the plaintext (the hash or
        checksum) is stored in ``chksum``.
        """
        kb = super(DSAPriv, self).decrypt_keyblob(passphrase)
        del passphrase

        self.x = MPI(kb)

        if self.s2k.usage in [254, 255]:
            self.chksum = kb
            del kb

    def sign(self, sigdata, hash_alg):
        return self.__privkey__().sign(sigdata, hash_alg)
1419
1420
class ElGPriv(PrivKey, ElGPub):
    """ElGamal secret key material (MPI x); key generation and use are not implemented."""
    __privfields__ = ('x', )

    def __privkey__(self):
        raise NotImplementedError()

    def _compute_chksum(self):
        """Store the sum of the octets of x, mod 65536, as a 2-octet checksum."""
        chs = sum(bytearray(self.x.to_mpibytes())) % 65536
        self.chksum = bytearray(self.int_to_bytes(chs, 2))

    def _generate(self, key_size):
        raise NotImplementedError(PubKeyAlgorithm.ElGamal)

    def parse(self, packet):
        """
        Parse ElGamal secret key material from the front of ``packet``.

        For an unprotected key, x is read, followed by a 2-octet checksum when
        the usage octet is 0. For a protected key (usage 254/255), the remainder
        of ``packet`` becomes ``encbytes``. The hash or checksum of a protected
        key is inside the encrypted data and is verified by ``decrypt_keyblob``.
        """
        super(ElGPriv, self).parse(packet)
        self.s2k.parse(packet)

        if not self.s2k:
            self.x = MPI(packet)

            if self.s2k.usage == 0:
                self.chksum = packet[:2]
                del packet[:2]

        else:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            # Nothing more may be stripped here: encbytes is the same object as
            # packet, and for usage 255 the checksum is encrypted along with x.
            self.encbytes = packet

    def decrypt_keyblob(self, passphrase):
        """
        Decrypt the protected secret field and load x from it.

        For usage 254/255, whatever follows x in the plaintext (the hash or
        checksum) is stored in ``chksum``.
        """
        kb = super(ElGPriv, self).decrypt_keyblob(passphrase)
        del passphrase

        self.x = MPI(kb)

        if self.s2k.usage in [254, 255]:
            self.chksum = kb
            del kb
1456            del kb
1457
1458
class ECDSAPriv(PrivKey, ECDSAPub):
    """ECDSA secret key material: the secret scalar MPI s."""
    __privfields__ = ('s', )

    def __privkey__(self):
        # combine the public point (p.x, p.y) on self.oid's curve with the secret s
        ecp = ec.EllipticCurvePublicNumbers(self.p.x, self.p.y, self.oid.curve())
        return ec.EllipticCurvePrivateNumbers(self.s, ecp).private_key(default_backend())

    def _compute_chksum(self):
        """Store the sum of the octets of s, mod 65536, as a 2-octet checksum."""
        chs = sum(bytearray(self.s.to_mpibytes())) % 65536
        self.chksum = bytearray(self.int_to_bytes(chs, 2))

    def _generate(self, oid):
        """
        Generate a new key on the curve identified by ``oid``.

        :raises ValueError: if the curve cannot be generated (``can_gen`` is false)
        """
        if any(c != 0 for c in self):  # pragma: no cover
            raise PGPError("Key is already populated!")

        self.oid = EllipticCurveOID(oid)

        if not self.oid.can_gen:
            # use the normalized OID: the raw `oid` argument need not have a `.name`
            raise ValueError("Curve not currently supported: {}".format(self.oid.name))

        pk = ec.generate_private_key(self.oid.curve(), default_backend())
        pubn = pk.public_key().public_numbers()
        self.p = ECPoint.from_values(self.oid.key_size, ECPointFormat.Standard, MPI(pubn.x), MPI(pubn.y))
        self.s = MPI(pk.private_numbers().private_value)
        self._compute_chksum()

    def parse(self, packet):
        """
        Parse ECDSA secret key material from the front of ``packet``.

        For an unprotected key, s is read, followed by a 2-octet checksum when
        the usage octet is 0. For a protected key, the remainder of ``packet``
        becomes ``encbytes``.
        """
        super(ECDSAPriv, self).parse(packet)
        self.s2k.parse(packet)

        if not self.s2k:
            self.s = MPI(packet)

            if self.s2k.usage == 0:
                self.chksum = packet[:2]
                del packet[:2]
        else:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            self.encbytes = packet

    def decrypt_keyblob(self, passphrase):
        """Decrypt the protected secret field and load s from it."""
        kb = super(ECDSAPriv, self).decrypt_keyblob(passphrase)
        del passphrase
        self.s = MPI(kb)

    def sign(self, sigdata, hash_alg):
        return self.__privkey__().sign(sigdata, ec.ECDSA(hash_alg))
1506
1507
class EdDSAPriv(PrivKey, EdDSAPub):
    """EdDSA secret key material (Ed25519 only); the secret is kept in MPI s."""
    __privfields__ = ('s', )

    def __privkey__(self):
        # s, encoded with int_to_bytes' default byte order into the curve's octet size
        nbytes = (self.oid.key_size + 7) // 8
        seed = self.int_to_bytes(self.s, nbytes)
        return ed25519.Ed25519PrivateKey.from_private_bytes(seed)

    def _compute_chksum(self):
        """Store the sum of the octets of s, mod 65536, as a 2-octet checksum."""
        total = sum(bytearray(self.s.to_mpibytes()))
        self.chksum = bytearray(self.int_to_bytes(total % 65536, 2))

    def _generate(self, oid):
        """
        Generate a new Ed25519 key.

        :raises ValueError: if ``oid`` is not Ed25519
        """
        if any(c != 0 for c in self):  # pragma: no cover
            raise PGPError("Key is already populated!")

        self.oid = EllipticCurveOID(oid)
        if self.oid != EllipticCurveOID.Ed25519:
            raise ValueError("EdDSA only supported with {}".format(EllipticCurveOID.Ed25519))

        raw = serialization.Encoding.Raw
        sk = ed25519.Ed25519PrivateKey.generate()
        pub_bytes = sk.public_key().public_bytes(encoding=raw, format=serialization.PublicFormat.Raw)
        self.p = ECPoint.from_values(self.oid.key_size, ECPointFormat.Native, pub_bytes)

        secret_bytes = sk.private_bytes(
            encoding=raw,
            format=serialization.PrivateFormat.Raw,
            encryption_algorithm=serialization.NoEncryption(),
        )
        self.s = MPI(self.bytes_to_int(secret_bytes))
        self._compute_chksum()

    def parse(self, packet):
        """
        Parse EdDSA secret key material from the front of ``packet``.

        A protected key keeps the remainder of ``packet`` as ``encbytes``. An
        unprotected one reads s, plus a 2-octet checksum for usage 0.
        """
        super(EdDSAPriv, self).parse(packet)
        self.s2k.parse(packet)

        if self.s2k:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            self.encbytes = packet
            return

        self.s = MPI(packet)
        if self.s2k.usage == 0:
            self.chksum = packet[:2]
            del packet[:2]

    def decrypt_keyblob(self, passphrase):
        """Decrypt the protected secret field and load s from it."""
        plaintext = super(EdDSAPriv, self).decrypt_keyblob(passphrase)
        del passphrase
        self.s = MPI(plaintext)

    def sign(self, sigdata, hash_alg):
        # GnuPG requires a pre-hashing with EdDSA
        # https://tools.ietf.org/html/draft-ietf-openpgp-rfc4880bis-06#section-14.8
        hasher = hashes.Hash(hash_alg, backend=default_backend())
        hasher.update(sigdata)
        prehashed = hasher.finalize()
        return self.__privkey__().sign(prehashed)
1563
1564
class ECDHPriv(ECDSAPriv, ECDHPub):
    """ECDH secret key material: the secret scalar MPI s on top of the ECDHPub fields."""
    def __bytearray__(self):
        _b = ECDHPub.__bytearray__(self)
        _b += self.s2k.__bytearray__()
        if not self.s2k:
            _b += self.s.to_mpibytes()
            if self.s2k.usage == 0:
                _b += self.chksum
        else:
            _b += self.encbytes
        return _b

    def __len__(self):
        """
        Return the serialized length in octets, mirroring ``__bytearray__``.

        ``chksum`` only counts for an unprotected key with usage 0, the only
        case in which it is written. It can still hold a stale value otherwise,
        e.g. the checksum from ``_generate`` after the key is protected.
        """
        l = ECDHPub.__len__(self) + len(self.s2k)
        if self.s2k:
            l += len(self.encbytes)
        else:
            l += sum(len(getattr(self, i)) for i in self.__privfields__)
            if self.s2k.usage == 0:
                l += len(self.chksum)
        return l

    def __privkey__(self):
        if self.oid == EllipticCurveOID.Curve25519:
            # NOTE: openssl and GPG don't use the same endianness for Curve25519 secret value
            s = self.int_to_bytes(self.s, (self.oid.key_size + 7) // 8, 'little')
            return x25519.X25519PrivateKey.from_private_bytes(s)
        else:
            return ECDSAPriv.__privkey__(self)

    def _generate(self, oid):
        """
        Generate a new key on ``oid`` and set the KDF parameters to the curve's defaults.

        Curve25519 keys are generated with X25519. Other curves defer to
        ``ECDSAPriv._generate``.
        """
        _oid = EllipticCurveOID(oid)
        if _oid == EllipticCurveOID.Curve25519:
            if any(c != 0 for c in self):  # pragma: no cover
                raise PGPError("Key is already populated!")
            self.oid = _oid
            pk = x25519.X25519PrivateKey.generate()
            x = pk.public_key().public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
            self.p = ECPoint.from_values(self.oid.key_size, ECPointFormat.Native, x)
            # NOTE: openssl and GPG don't use the same endianness for Curve25519 secret value
            self.s = MPI(self.bytes_to_int(pk.private_bytes(
                encoding=serialization.Encoding.Raw,
                format=serialization.PrivateFormat.Raw,
                encryption_algorithm=serialization.NoEncryption()
            ), 'little'))
            self._compute_chksum()
        else:
            ECDSAPriv._generate(self, oid)
        self.kdf.halg = self.oid.kdf_halg
        self.kdf.encalg = self.oid.kek_alg

    def publen(self):
        return ECDHPub.__len__(self)

    def parse(self, packet):
        """
        Parse ECDH secret key material from the front of ``packet``.

        For an unprotected key, s is read, followed by a 2-octet checksum when
        the usage octet is 0. For a protected key, the remainder of ``packet``
        becomes ``encbytes``.
        """
        ECDHPub.parse(self, packet)
        self.s2k.parse(packet)

        if not self.s2k:
            self.s = MPI(packet)
            if self.s2k.usage == 0:
                self.chksum = packet[:2]
                del packet[:2]
        else:
            ##TODO: this needs to be bounded to the length of the encrypted key material
            self.encbytes = packet

    def sign(self, sigdata, hash_alg):
        raise PGPError("Cannot sign with an ECDH key")
1632
1633
class CipherText(MPIs):
    """
    Base class for the MPI fields of a public-key ciphertext.

    Subclasses name their fields in ``__mpis__`` and provide ``encrypt``
    (an alternate constructor) and ``decrypt``.
    """
    def __init__(self):
        super(CipherText, self).__init__()
        # every declared MPI field starts out as zero
        for name in self.__mpis__:
            setattr(self, name, MPI(0))

    @classmethod
    @abc.abstractmethod
    def encrypt(cls, encfn, *args):
        """create and populate a concrete CipherText class instance"""

    @abc.abstractmethod
    def decrypt(self, decfn, *args):
        """decrypt the ciphertext contained in this CipherText instance"""

    def __bytearray__(self):
        out = bytearray()
        for mpi in self:
            out += mpi.to_mpibytes()
        return out
1654
1655
class RSACipherText(CipherText):
    """RSA ciphertext: the single MPI ``me_mod_n``."""
    __mpis__ = ('me_mod_n', )

    @classmethod
    def encrypt(cls, encfn, *args):
        """Build an instance from the raw ciphertext octets returned by ``encfn(*args)``."""
        instance = cls()
        raw = encfn(*args)
        instance.me_mod_n = MPI(cls.bytes_to_int(raw))
        return instance

    def decrypt(self, decfn, *args):
        # the caller-supplied function does all of the work
        result = decfn(*args)
        return result

    def parse(self, packet):
        """Read ``me_mod_n`` from the front of ``packet``."""
        self.me_mod_n = MPI(packet)
1670
1671
class ElGCipherText(CipherText):
    """ElGamal ciphertext fields; encryption and decryption are not implemented."""
    __mpis__ = ('gk_mod_p', 'myk_mod_p')

    @classmethod
    def encrypt(cls, encfn, *args):
        raise NotImplementedError()

    def decrypt(self, decfn, *args):
        raise NotImplementedError()

    def parse(self, packet):
        """Read ``gk_mod_p`` and then ``myk_mod_p`` from the front of ``packet``."""
        for name in ('gk_mod_p', 'myk_mod_p'):
            setattr(self, name, MPI(packet))
1685
1686
class ECDHCipherText(CipherText):
    """
    ECDH ciphertext (RFC 6637).

    Holds the ephemeral public point ``p`` and the AES-key-wrapped payload ``c``.
    """
    __mpis__ = ('p',)

    @classmethod
    def encrypt(cls, pk, *args):
        """
        For convenience, the synopsis of the encoding method is given below;
        however, this section, [NIST-SP800-56A], and [RFC3394] are the
        normative sources of the definition.

            Obtain the authenticated recipient public key R
            Generate an ephemeral key pair {v, V=vG}
            Compute the shared point S = vR;
            m = symm_alg_ID || session key || checksum || pkcs5_padding;
            curve_OID_len = (byte)len(curve_OID);
            Param = curve_OID_len || curve_OID || public_key_alg_ID || 03
            || 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap || "Anonymous
            Sender    " || recipient_fingerprint;
            Z_len = the key size for the KEK_alg_ID used with AESKeyWrap
            Compute Z = KDF( S, Z_len, Param );
            Compute C = AESKeyWrap( Z, m ) as per [RFC3394]
            VB = convert point V to the octet string
            Output (MPI(VB) || len(C) || C).

        The decryption is the inverse of the method given.  Note that the
        recipient obtains the shared secret by calculating S = rV = rvG,
        where (r, R) is the recipient's key pair.

        :param pk: the recipient key; its ``keymaterial`` and ``fingerprint`` are used
        :param args: exactly one value, the message ``m`` to wrap (before padding)
        :returns: a populated ECDHCipherText
        """
        # *args should be:
        # - m
        #
        _m, = args

        # m may need to be PKCS5-padded (to a multiple of 8 octets)
        padder = PKCS7(64).padder()
        m = padder.update(_m) + padder.finalize()

        km = pk.keymaterial
        ct = cls()

        # generate ephemeral key pair and keep public key in ct
        # use private key to compute the shared point "s"
        if km.oid == EllipticCurveOID.Curve25519:
            v = x25519.X25519PrivateKey.generate()
            x = v.public_key().public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
            ct.p = ECPoint.from_values(km.oid.key_size, ECPointFormat.Native, x)
            s = v.exchange(km.__pubkey__())
        else:
            v = ec.generate_private_key(km.oid.curve(), default_backend())
            x = MPI(v.public_key().public_numbers().x)
            y = MPI(v.public_key().public_numbers().y)
            ct.p = ECPoint.from_values(km.oid.key_size, ECPointFormat.Standard, x, y)
            s = v.exchange(ec.ECDH(), km.__pubkey__())

        # derive the wrapping key
        z = km.kdf.derive_key(s, km.oid, PubKeyAlgorithm.ECDH, pk.fingerprint)

        # compute C
        ct.c = aes_key_wrap(z, m, default_backend())

        return ct

    def decrypt(self, pk, *args):
        """
        Recover the padded-then-unpadded message ``m`` using the recipient key ``pk``.

        Any extra positional arguments are ignored.
        """
        km = pk.keymaterial
        if km.oid == EllipticCurveOID.Curve25519:
            v = x25519.X25519PublicKey.from_public_bytes(self.p.x)
            s = km.__privkey__().exchange(v)
        else:
            # assemble the public component of ephemeral key v
            v = ec.EllipticCurvePublicNumbers(self.p.x, self.p.y, km.oid.curve()).public_key(default_backend())
            # compute s using the inverse of how it was derived during encryption
            s = km.__privkey__().exchange(ec.ECDH(), v)

        # derive the wrapping key
        z = km.kdf.derive_key(s, km.oid, PubKeyAlgorithm.ECDH, pk.fingerprint)

        # unwrap and unpad m
        _m = aes_key_unwrap(z, self.c, default_backend())

        padder = PKCS7(64).unpadder()
        return padder.update(_m) + padder.finalize()

    def __init__(self):
        super(ECDHCipherText, self).__init__()
        self.c = bytearray(0)

    def __bytearray__(self):
        # MPI(VB) || len(C) || C
        _bytes = bytearray()
        _bytes += self.p.to_mpibytes()
        _bytes.append(len(self.c))
        _bytes += self.c
        return _bytes

    def parse(self, packet):
        # read ephemeral public key
        self.p = ECPoint(packet)
        # read the one-octet length and the AES-key-wrapped payload C
        clen = packet[0]
        del packet[0]
        self.c += packet[:clen]
        del packet[:clen]
1787