1# Copyright (c) 2013-2020 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""DSA public key encryption handler"""
22
23from .asn1 import ASN1DecodeError, ObjectIdentifier, der_encode, der_decode
24from .crypto import DSAPrivateKey, DSAPublicKey
25from .misc import all_ints
26from .packet import MPInt, String
27from .public_key import SSHKey, SSHOpenSSHCertificateV01, KeyExportError
28from .public_key import register_public_key_alg, register_certificate_alg
29from .public_key import register_x509_certificate_alg
30
31
class _DSAKey(SSHKey):
    """Handler for DSA (ssh-dss) public key encryption

       The wrapped crypto-layer key is reached through self._key and
       exposes the DSA values p, q, g, y and x, where x is falsy when
       only the public half of the key is available.

    """

    # SSH name of the key algorithm and the hash used for its signatures
    algorithm = b'ssh-dss'
    default_hash_alg = 'sha1'

    # Labels used when reading/writing PEM and PKCS#8 encodings
    pem_name = b'DSA'
    pkcs8_oid = ObjectIdentifier('1.2.840.10040.4.1')

    # Signature algorithm names: just ssh-dss, plus its X.509 variant
    sig_algorithms = (algorithm,)
    x509_algorithms = (b'x509v3-' + algorithm,)
    all_sig_algorithms = set(sig_algorithms)

    def __eq__(self, other):
        """Return whether other is the same kind of key with equal values"""

        # Once the type check passes, other is also a _DSAKey, so reading
        # other._key is not a real protected access.
        # pylint: disable=protected-access

        if not isinstance(other, type(self)):
            return False

        mine, theirs = self._key, other._key

        return all(getattr(mine, field) == getattr(theirs, field)
                   for field in ('p', 'q', 'g', 'y', 'x'))

    def __hash__(self):
        """Hash the same DSA values that __eq__ compares"""

        key = self._key
        return hash((key.p, key.q, key.g, key.y, key.x))

    @classmethod
    def generate(cls, _algorithm):
        """Generate a new 1024-bit DSA private key

           The requested algorithm name is not consulted.

        """

        new_key = DSAPrivateKey.generate(key_size=1024)
        return cls(new_key)

    @classmethod
    def make_private(cls, p, q, g, y, x):
        """Build a DSA private key from its p, q, g, y and x values"""

        private_key = DSAPrivateKey.construct(p, q, g, y, x)
        return cls(private_key)

    @classmethod
    def make_public(cls, p, q, g, y):
        """Build a DSA public key from its p, q, g and y values"""

        public_key = DSAPublicKey.construct(p, q, g, y)
        return cls(public_key)

    @classmethod
    def decode_pkcs1_private(cls, key_data):
        """Extract DSA private key values from decoded PKCS#1 data

           key_data must be a 6-tuple of ints whose leading element is 0
           (the same value encode_pkcs1_private writes there). On success
           the remaining (p, q, g, y, x) values are returned, otherwise
           None.

        """

        well_formed = (isinstance(key_data, tuple) and len(key_data) == 6 and
                       all_ints(key_data) and key_data[0] == 0)

        return key_data[1:] if well_formed else None

    @classmethod
    def decode_pkcs1_public(cls, key_data):
        """Extract DSA public key values from decoded PKCS#1 data

           Returns (p, q, g, y), or None if key_data is not a 4-tuple of
           ints.

        """

        if not (isinstance(key_data, tuple) and len(key_data) == 4 and
                all_ints(key_data)):
            return None

        # The encoded order is y, p, q, g (see encode_pkcs1_public);
        # hand the values back in p, q, g, y order.
        y, p, q, g = key_data
        return p, q, g, y

    @classmethod
    def decode_pkcs8_private(cls, alg_params, data):
        """Extract DSA private key values from PKCS#8 components

           alg_params must be a (p, q, g) tuple of ints, and data must
           DER-decode to the integer x. Returns (p, q, g, y, x), or None
           if either part is malformed.

        """

        try:
            x = der_decode(data)
        except ASN1DecodeError:
            return None

        params_ok = (isinstance(alg_params, tuple) and len(alg_params) == 3
                     and all_ints(alg_params))

        if not (params_ok and isinstance(x, int)):
            return None

        p, q, g = alg_params

        # Only x is present in the encoded data, so derive the public
        # value y = g^x mod p here.
        return p, q, g, pow(g, x, p), x

    @classmethod
    def decode_pkcs8_public(cls, alg_params, data):
        """Extract DSA public key values from PKCS#8 components

           alg_params must be a (p, q, g) tuple of ints, and data must
           DER-decode to the integer y. Returns (p, q, g, y), or None if
           either part is malformed.

        """

        try:
            y = der_decode(data)
        except ASN1DecodeError:
            return None

        params_ok = (isinstance(alg_params, tuple) and len(alg_params) == 3
                     and all_ints(alg_params))

        if not (params_ok and isinstance(y, int)):
            return None

        p, q, g = alg_params
        return p, q, g, y

    @classmethod
    def decode_ssh_private(cls, packet):
        """Read DSA private key values from an SSH-format packet

           Five mpints are consumed in order: p, q, g, y and x.

        """

        return tuple(packet.get_mpint() for _ in range(5))

    @classmethod
    def decode_ssh_public(cls, packet):
        """Read DSA public key values from an SSH-format packet

           Four mpints are consumed in order: p, q, g and y.

        """

        return tuple(packet.get_mpint() for _ in range(4))

    def encode_pkcs1_private(self):
        """Return this private key as a PKCS#1 (0, p, q, g, y, x) tuple

           Raises KeyExportError if the key has no private value.

        """

        key = self._key

        if not key.x:
            raise KeyExportError('Key is not private')

        return (0, key.p, key.q, key.g, key.y, key.x)

    def encode_pkcs1_public(self):
        """Return this key's public half as a PKCS#1 (y, p, q, g) tuple"""

        key = self._key
        return key.y, key.p, key.q, key.g

    def encode_pkcs8_private(self):
        """Return ((p, q, g), DER(x)) for PKCS#8 private key encoding

           Raises KeyExportError if the key has no private value.

        """

        key = self._key

        if not key.x:
            raise KeyExportError('Key is not private')

        domain_params = (key.p, key.q, key.g)
        return domain_params, der_encode(key.x)

    def encode_pkcs8_public(self):
        """Return ((p, q, g), DER(y)) for PKCS#8 public key encoding"""

        key = self._key
        domain_params = (key.p, key.q, key.g)
        return domain_params, der_encode(key.y)

    def encode_ssh_private(self):
        """Return p, q, g, y and x as concatenated SSH mpints

           Raises KeyExportError if the key has no private value.

        """

        key = self._key

        if not key.x:
            raise KeyExportError('Key is not private')

        return b''.join(MPInt(value)
                        for value in (key.p, key.q, key.g, key.y, key.x))

    def encode_ssh_public(self):
        """Return p, q, g and y as concatenated SSH mpints"""

        key = self._key
        return b''.join(MPInt(value) for value in (key.p, key.q, key.g, key.y))

    def encode_agent_cert_private(self):
        """Return the private value x as an SSH mpint for agent use

           Raises KeyExportError if the key has no private value.

        """

        private_value = self._key.x

        if not private_value:
            raise KeyExportError('Key is not private')

        return MPInt(private_value)

    def sign_ssh(self, data, sig_algorithm):
        """Compute an SSH-encoded signature of the specified data

           The signature is always computed with SHA-1; sig_algorithm is
           not consulted. Raises ValueError for a public-only key.

        """

        # pylint: disable=unused-argument

        key = self._key

        if not key.x:
            raise ValueError('Private key needed for signing')

        # The backend returns a DER-encoded (r, s) pair; re-pack it as two
        # fixed-width 20-byte big-endian integers back to back.
        # NOTE(review): int.to_bytes raises OverflowError if r or s does
        # not fit in 20 bytes, i.e. for keys whose q exceeds 160 bits.
        r, s = der_decode(key.sign(data, 'sha1'))
        raw_sig = b''.join(value.to_bytes(20, 'big') for value in (r, s))

        return String(raw_sig)

    def verify_ssh(self, data, sig_algorithm, packet):
        """Verify an SSH-encoded signature of the specified data

           Returns False when the signature blob is not exactly 40 bytes,
           and otherwise returns the wrapped key's SHA-1 verification
           result. sig_algorithm is not consulted.

        """

        # pylint: disable=unused-argument

        sig = packet.get_string()
        packet.check_end()

        if len(sig) != 40:
            return False

        # Split the blob into its two 20-byte halves and re-encode them as
        # the DER (r, s) pair the backend verifier takes.
        r, s = (int.from_bytes(half, 'big') for half in (sig[:20], sig[20:]))

        return self._key.verify(data, der_encode((r, s)), 'sha1')
230
231
# Register ssh-dss as a public key algorithm handled by _DSAKey.
# NOTE(review): the trailing False presumably marks the algorithm as not
# enabled by default -- confirm against register_public_key_alg.
register_public_key_alg(b'ssh-dss', _DSAKey, False)

# Register the OpenSSH certificate form of ssh-dss keys.
# NOTE(review): the leading 1 looks like the certificate format version
# (matching the "-cert-v01" name), and False presumably means "not enabled
# by default" -- confirm against register_certificate_alg.
register_certificate_alg(1, b'ssh-dss', b'ssh-dss-cert-v01@openssh.com',
                         _DSAKey, SSHOpenSSHCertificateV01, False)

# Register each X.509 signature algorithm name (here just x509v3-ssh-dss).
# Note: this loop leaves the name alg bound at module level.
for alg in _DSAKey.x509_algorithms:
    register_x509_certificate_alg(alg, False)
239