1# Copyright (c) 2015-2020 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""Unit tests for key exchange"""
22
23import asyncio
24import unittest
25
26from hashlib import sha1
27
28import asyncssh
29
30from asyncssh.crypto import curve25519_available, curve448_available
31from asyncssh.crypto import Curve25519DH, Curve448DH, ECDH
32from asyncssh.kex_dh import MSG_KEXDH_INIT, MSG_KEXDH_REPLY
33from asyncssh.kex_dh import MSG_KEX_DH_GEX_REQUEST, MSG_KEX_DH_GEX_GROUP
34from asyncssh.kex_dh import MSG_KEX_DH_GEX_INIT, MSG_KEX_DH_GEX_REPLY, _KexDHGex
35from asyncssh.kex_dh import MSG_KEX_ECDH_INIT, MSG_KEX_ECDH_REPLY
36from asyncssh.kex_dh import MSG_KEXGSS_INIT, MSG_KEXGSS_COMPLETE
37from asyncssh.kex_dh import MSG_KEXGSS_ERROR
38from asyncssh.kex_rsa import MSG_KEXRSA_PUBKEY, MSG_KEXRSA_SECRET
39from asyncssh.kex_rsa import MSG_KEXRSA_DONE
40from asyncssh.gss import GSSClient, GSSServer
41from asyncssh.kex import register_kex_alg, get_kex_algs, get_kex
42from asyncssh.packet import SSHPacket, Boolean, Byte, MPInt, String
43from asyncssh.public_key import decode_ssh_public_key
44
45from .util import asynctest, gss_available, patch_gss
46from .util import AsyncTestCase, ConnectionStub
47
48
class _KexConnectionStub(ConnectionStub):
    """Connection stub class to test key exchange

       Each stub wraps a key exchange object created by get_kex() and
       lets tests inject individual key exchange messages directly via
       the simulate_* helpers below.

    """

    def __init__(self, alg, gss, peer, server=False):
        super().__init__(peer, server)

        self._gss = gss

        # Resolved by send_newkeys() with the derived key material, so
        # tests can await it and compare client and server results
        self._key_waiter = asyncio.Future()

        self._kex = get_kex(self, alg)

    def start(self):
        """Start key exchange"""

        self._kex.start()

    def connection_lost(self, exc):
        """Handle the closing of a connection

           Subclasses must override this; the base stub has no
           behavior for it.

        """

        raise NotImplementedError

    def enable_gss_kex_auth(self):
        """Ignore request to enable GSS key exchange authentication"""

    def process_packet(self, data):
        """Process an incoming packet

           The first byte of data is the message type; the remainder of
           the packet is passed to the key exchange handler.

        """

        packet = SSHPacket(data)
        pkttype = packet.get_byte()
        self._kex.process_packet(pkttype, None, packet)

    def get_hash_prefix(self):
        """Return the bytes used in calculating unique connection hashes"""

        # pylint: disable=no-self-use

        return b'prefix'

    def send_newkeys(self, k, h):
        """Handle a request to send a new keys message

           Instead of sending anything, derive 128 bytes of key material
           from k and h and publish it through the key waiter.

        """

        self._key_waiter.set_result(self._kex.compute_key(k, h, b'A', h, 128))

    async def get_key(self):
        """Return generated key data"""

        return await self._key_waiter

    def get_gss_context(self):
        """Return the GSS context associated with this connection"""

        return self._gss

    def simulate_dh_init(self, e):
        """Simulate receiving a DH init packet"""

        self.process_packet(Byte(MSG_KEXDH_INIT) + MPInt(e))

    def simulate_dh_reply(self, host_key_data, f, sig):
        """Simulate receiving a DH reply packet"""

        self.process_packet(b''.join((Byte(MSG_KEXDH_REPLY),
                                      String(host_key_data),
                                      MPInt(f), String(sig))))

    def simulate_dh_gex_group(self, p, g):
        """Simulate receiving a DH GEX group packet"""

        self.process_packet(Byte(MSG_KEX_DH_GEX_GROUP) + MPInt(p) + MPInt(g))

    def simulate_dh_gex_init(self, e):
        """Simulate receiving a DH GEX init packet"""

        self.process_packet(Byte(MSG_KEX_DH_GEX_INIT) + MPInt(e))

    def simulate_dh_gex_reply(self, host_key_data, f, sig):
        """Simulate receiving a DH GEX reply packet"""

        self.process_packet(b''.join((Byte(MSG_KEX_DH_GEX_REPLY),
                                      String(host_key_data),
                                      MPInt(f), String(sig))))

    def simulate_gss_complete(self, f, sig):
        """Simulate receiving a GSS complete packet"""

        self.process_packet(b''.join((Byte(MSG_KEXGSS_COMPLETE), MPInt(f),
                                      String(sig), Boolean(False))))

    def simulate_ecdh_init(self, client_pub):
        """Simulate receiving an ECDH init packet"""

        self.process_packet(Byte(MSG_KEX_ECDH_INIT) + String(client_pub))

    def simulate_ecdh_reply(self, host_key_data, server_pub, sig):
        """Simulate receiving an ECDH reply packet"""

        self.process_packet(b''.join((Byte(MSG_KEX_ECDH_REPLY),
                                      String(host_key_data),
                                      String(server_pub), String(sig))))

    def simulate_rsa_pubkey(self, host_key_data, trans_key_data):
        """Simulate receiving an RSA pubkey packet"""

        self.process_packet(Byte(MSG_KEXRSA_PUBKEY) + String(host_key_data) +
                            String(trans_key_data))

    def simulate_rsa_secret(self, encrypted_k):
        """Simulate receiving an RSA secret packet"""

        self.process_packet(Byte(MSG_KEXRSA_SECRET) + String(encrypted_k))

    def simulate_rsa_done(self, sig):
        """Simulate receiving an RSA done packet"""

        self.process_packet(Byte(MSG_KEXRSA_DONE) + String(sig))
164
165
class _KexClientStub(_KexConnectionStub):
    """Client-side connection stub used by the key exchange tests"""

    @classmethod
    def make_pair(cls, alg, gss_host=None):
        """Build a linked client/server stub pair for a key exchange test

           The client creates its own server peer, so the pair returned
           is the new client followed by that client's peer.

        """

        client = cls(alg, gss_host)
        server = client.get_peer()

        return client, server

    def __init__(self, alg, gss_host):
        # The server stub is built first and given this (not yet
        # initialized) client object as its peer
        peer = _KexServerStub(alg, gss_host, self)

        client_gss = (GSSClient(gss_host, 'delegate' in gss_host)
                      if gss_host else None)

        super().__init__(alg, client_gss, peer)

    def connection_lost(self, exc):
        """Fail any still-pending key wait with exc, then close the stub"""

        waiter = self._key_waiter

        if exc and not waiter.done():
            waiter.set_exception(exc)

        self.close()

    def validate_server_host_key(self, host_key_data):
        """Decode the server host key blob and return it with no trust check"""

        # pylint: disable=no-self-use

        server_key = decode_ssh_public_key(host_key_data)
        return server_key
200
201
class _KexServerStub(_KexConnectionStub):
    """Server-side connection stub used by the key exchange tests"""

    def __init__(self, alg, gss_host, peer):
        if gss_host:
            server_gss = GSSServer(gss_host)
        else:
            server_gss = None

        super().__init__(alg, server_gss, peer, server=True)

        # A GSS host string containing 'no_host_key' runs the exchange
        # without a server host key; otherwise a fresh RSA key is made
        if not gss_host or 'no_host_key' not in gss_host:
            new_key = asyncssh.generate_private_key('ssh-rsa')
            self._server_host_key = asyncssh.load_keypairs(new_key)[0]
        else:
            self._server_host_key = None

    def connection_lost(self, exc):
        """Forward the connection loss to the peer, then close the stub"""

        client = self._peer

        if client:
            client.connection_lost(exc)

        self.close()

    def get_server_host_key(self):
        """Return this stub's host key pair, or None when it has none"""

        return self._server_host_key
227
228
@patch_gss
class _TestKex(AsyncTestCase):
    """Unit tests for kex module"""

    async def _check_kex(self, alg, gss_host=None):
        """Unit test key exchange

           Run a complete exchange between a client and server stub and
           check that both sides derive the same key.

        """

        client_conn, server_conn = _KexClientStub.make_pair(alg, gss_host)

        try:
            client_conn.start()
            server_conn.start()

            self.assertEqual((await client_conn.get_key()),
                             (await server_conn.get_key()))
        finally:
            client_conn.close()
            server_conn.close()

    @asynctest
    async def test_key_exchange_algs(self):
        """Unit test key exchange algorithms"""

        for alg in get_kex_algs():
            with self.subTest(alg=alg):
                if alg.startswith(b'gss-'):
                    if gss_available:  # pragma: no branch
                        await self._check_kex(alg + b'-mech', '1')
                else:
                    await self._check_kex(alg)

        if gss_available:  # pragma: no branch
            for steps in range(4):
                with self.subTest('GSS key exchange', steps=steps):
                    await self._check_kex(b'gss-group1-sha1-mech', str(steps))

            with self.subTest('GSS with credential delegation'):
                await self._check_kex(b'gss-group1-sha1-mech', '1,delegate')

            with self.subTest('GSS with no host key'):
                await self._check_kex(b'gss-group1-sha1-mech', '1,no_host_key')

            with self.subTest('GSS with full host principal'):
                await self._check_kex(b'gss-group1-sha1-mech', 'host/1@TEST')

    @asynctest
    async def test_dh_gex_old(self):
        """Unit test old DH group exchange request"""

        # A single size value presumably selects the old-style request
        register_kex_alg(b'dh-gex-sha1-1024', _KexDHGex, sha1, (1024,), True)
        register_kex_alg(b'dh-gex-sha1-2048', _KexDHGex, sha1, (2048,), True)

        for size in (b'1024', b'2048'):
            with self.subTest('Old DH group exchange', size=size):
                await self._check_kex(b'dh-gex-sha1-' + size)

    @asynctest
    async def test_dh_gex(self):
        """Unit test new DH group exchange request"""

        # Size tuples presumably give the requested (min, max) range
        register_kex_alg(b'dh-gex-sha1-1024-1536', _KexDHGex, sha1,
                         (1024, 1536), True)
        register_kex_alg(b'dh-gex-sha1-1536-3072', _KexDHGex, sha1,
                         (1536, 3072), True)
        register_kex_alg(b'dh-gex-sha1-2560-2560', _KexDHGex, sha1,
                         (2560, 2560), True)
        register_kex_alg(b'dh-gex-sha1-2560-4096', _KexDHGex, sha1,
                         (2560, 4096), True)
        register_kex_alg(b'dh-gex-sha1-9216-9216', _KexDHGex, sha1,
                         (9216, 9216), True)

        for size in (b'1024-1536', b'1536-3072', b'2560-2560',
                     b'2560-4096', b'9216-9216'):
            with self.subTest('DH group exchange', size=size):
                await self._check_kex(b'dh-gex-sha1-' + size)

    @asynctest
    async def test_dh_errors(self):
        """Unit test error conditions in DH key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'diffie-hellman-group14-sha1')

        # Close both stubs even if a failure escapes a subTest
        try:
            host_key = server_conn.get_server_host_key()

            with self.subTest('Init sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.process_packet(Byte(MSG_KEXDH_INIT))

            with self.subTest('Reply sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.process_packet(Byte(MSG_KEXDH_REPLY))

            with self.subTest('Invalid e value'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_dh_init(0)

            with self.subTest('Invalid f value'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.start()
                    client_conn.simulate_dh_reply(host_key.public_data,
                                                  0, b'')

            with self.subTest('Invalid signature'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    client_conn.start()
                    client_conn.simulate_dh_reply(host_key.public_data,
                                                  1, b'')
        finally:
            client_conn.close()
            server_conn.close()

    @asynctest
    async def test_dh_gex_errors(self):
        """Unit test error conditions in DH group exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'diffie-hellman-group-exchange-sha1')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Request sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.process_packet(Byte(MSG_KEX_DH_GEX_REQUEST))

            with self.subTest('Group sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_dh_gex_group(1, 2)

            with self.subTest('Init sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_dh_gex_init(1)

            with self.subTest('Init sent before group'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_dh_gex_init(1)

            with self.subTest('Reply sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_dh_gex_reply(b'', 1, b'')

            with self.subTest('Reply sent before group'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_dh_gex_reply(b'', 1, b'')
        finally:
            client_conn.close()
            server_conn.close()

    @unittest.skipUnless(gss_available, 'GSS not available')
    @asynctest
    async def test_gss_errors(self):
        """Unit test error conditions in GSS key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'gss-group1-sha1-mech', '3')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Init sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.process_packet(Byte(MSG_KEXGSS_INIT))

            with self.subTest('Complete sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.process_packet(Byte(MSG_KEXGSS_COMPLETE))

            with self.subTest('Exchange failed to complete'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_gss_complete(1, b'succeed')

            with self.subTest('Error sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.process_packet(Byte(MSG_KEXGSS_ERROR))
        finally:
            client_conn.close()
            server_conn.close()

        # The remaining cases run full exchanges; the options after the
        # step count in the GSS host string presumably select failure
        # behavior in the patched GSS stubs -- see patch_gss in .util
        with self.subTest('Signature verification failure'):
            with self.assertRaises(asyncssh.KeyExchangeFailed):
                await self._check_kex(b'gss-group1-sha1-mech', '0,fail')

        with self.subTest('Empty token in init'):
            with self.assertRaises(asyncssh.ProtocolError):
                await self._check_kex(b'gss-group1-sha1-mech', '0,empty_init')

        with self.subTest('Empty token in continue'):
            with self.assertRaises(asyncssh.ProtocolError):
                await self._check_kex(b'gss-group1-sha1-mech',
                                      '1,empty_continue')

        with self.subTest('Token after complete'):
            with self.assertRaises(asyncssh.ProtocolError):
                await self._check_kex(b'gss-group1-sha1-mech',
                                      '0,continue_token')

        for steps in range(2):
            with self.subTest('Token after complete', steps=steps):
                with self.assertRaises(asyncssh.ProtocolError):
                    await self._check_kex(b'gss-group1-sha1-mech',
                                          str(steps) + ',extra_token')

        with self.subTest('Context not secure'):
            with self.assertRaises(asyncssh.ProtocolError):
                await self._check_kex(b'gss-group1-sha1-mech',
                                      '1,no_server_integrity')

        with self.subTest('GSS error'):
            with self.assertRaises(asyncssh.KeyExchangeFailed):
                await self._check_kex(b'gss-group1-sha1-mech', '1,step_error')

        with self.subTest('GSS error with error token'):
            with self.assertRaises(asyncssh.KeyExchangeFailed):
                await self._check_kex(b'gss-group1-sha1-mech',
                                      '1,step_error,errtok')

    @asynctest
    async def test_ecdh_errors(self):
        """Unit test error conditions in ECDH key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'ecdh-sha2-nistp256')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Init sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_ecdh_init(b'')

            with self.subTest('Invalid client public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_ecdh_init(b'')

            with self.subTest('Reply sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_ecdh_reply(b'', b'', b'')

            with self.subTest('Invalid server host key'):
                with self.assertRaises(asyncssh.KeyImportError):
                    client_conn.simulate_ecdh_reply(b'', b'', b'')

            with self.subTest('Invalid server public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    host_key = server_conn.get_server_host_key()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    b'', b'')

            with self.subTest('Invalid signature'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    host_key = server_conn.get_server_host_key()
                    server_pub = ECDH(b'nistp256').get_public()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    server_pub, b'')
        finally:
            client_conn.close()
            server_conn.close()

    @unittest.skipUnless(curve25519_available, 'Curve25519 not available')
    @asynctest
    async def test_curve25519dh_errors(self):
        """Unit test error conditions in Curve25519DH key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'curve25519-sha256')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Invalid client public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_ecdh_init(b'')

            with self.subTest('Invalid server public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    host_key = server_conn.get_server_host_key()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    b'', b'')

            with self.subTest('Invalid peer public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    host_key = server_conn.get_server_host_key()
                    server_pub = b'\x01' + 31*b'\x00'
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    server_pub, b'')

            with self.subTest('Invalid signature'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    host_key = server_conn.get_server_host_key()
                    server_pub = Curve25519DH().get_public()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    server_pub, b'')
        finally:
            client_conn.close()
            server_conn.close()

    @unittest.skipUnless(curve448_available, 'Curve448 not available')
    @asynctest
    async def test_curve448dh_errors(self):
        """Unit test error conditions in Curve448DH key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'curve448-sha512')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Invalid client public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_ecdh_init(b'')

            with self.subTest('Invalid server public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    host_key = server_conn.get_server_host_key()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    b'', b'')

            with self.subTest('Invalid peer public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    host_key = server_conn.get_server_host_key()
                    server_pub = b'\x01' + 55*b'\x00'
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    server_pub, b'')

            with self.subTest('Invalid signature'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    host_key = server_conn.get_server_host_key()
                    server_pub = Curve448DH().get_public()
                    client_conn.simulate_ecdh_reply(host_key.public_data,
                                                    server_pub, b'')
        finally:
            client_conn.close()
            server_conn.close()

    @asynctest
    async def test_rsa_errors(self):
        """Unit test error conditions in RSA key exchange"""

        client_conn, server_conn = \
            _KexClientStub.make_pair(b'rsa2048-sha256')

        # Close both stubs even if a failure escapes a subTest
        try:
            with self.subTest('Pubkey sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_rsa_pubkey(b'', b'')

            with self.subTest('Secret sent to client'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_rsa_secret(b'')

            with self.subTest('Done sent to server'):
                with self.assertRaises(asyncssh.ProtocolError):
                    server_conn.simulate_rsa_done(b'')

            with self.subTest('Invalid transient public key'):
                with self.assertRaises(asyncssh.ProtocolError):
                    client_conn.simulate_rsa_pubkey(b'', b'')

            with self.subTest('Invalid encrypted secret'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    server_conn.start()
                    server_conn.simulate_rsa_secret(b'')

            with self.subTest('Invalid signature'):
                with self.assertRaises(asyncssh.KeyExchangeFailed):
                    host_key = server_conn.get_server_host_key()
                    trans_key = asyncssh.generate_private_key('ssh-rsa', 2048)
                    client_conn.simulate_rsa_pubkey(host_key.public_data,
                                                    trans_key.public_data)
                    client_conn.simulate_rsa_done(b'')
        finally:
            client_conn.close()
            server_conn.close()
582