1# Copyright (c) 2016-2020 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""Unit tests for AsyncSSH ssh-agent client"""
22
23import asyncio
24import functools
25import os
26from pathlib import Path
27import signal
28import subprocess
29import unittest
30
31import asyncssh
32
33from asyncssh.agent import SSH_AGENT_SUCCESS, SSH_AGENT_FAILURE
34from asyncssh.agent import SSH_AGENT_IDENTITIES_ANSWER
35from asyncssh.crypto import ed25519_available
36from asyncssh.packet import Byte, String, UInt32
37
38from .sk_stub import sk_available, patch_sk
39from .util import AsyncTestCase, asynctest, run
40
41
def agent_test(func):
    """Decorator for running SSH agent tests

       The decorated test method is turned into an asynchronous test
       which opens a connection to the default SSH agent, removes every
       key currently held by that agent, and then calls the original
       method with the agent connection as its only extra argument.

    """

    @functools.wraps(func)
    async def _run_with_agent(self):
        """Connect to the SSH agent, empty it, and run the wrapped test"""

        async with asyncssh.connect_agent() as connection:
            # Start every test from an empty agent
            await connection.remove_all()
            await func(self, connection)

    return asynctest(_run_with_agent)
55
56
class _Agent:
    """Mock SSH agent for testing error cases

       The mock listens on a UNIX domain socket and answers every request
       it receives with the same canned response, then closes the
       connection.

    """

    def __init__(self, response):
        # None means "send no reply bytes at all"; anything else is sent
        # wrapped as an SSH string (length-prefixed)
        self._response = b'' if response is None else String(response)
        self._path = None
        self._server = None

    async def start(self, path):
        """Start a new mock SSH agent listening on the UNIX socket at path"""

        self._path = path

        # pylint doesn't think start_unix_server exists
        # pylint: disable=no-member
        self._server = \
            await asyncio.start_unix_server(self.process_request, path)

    async def process_request(self, reader, writer):
        """Process a request sent to the mock SSH agent

           Only the first 4 bytes of the request (presumably its length
           prefix) are read before the canned response is written back
           and the connection is closed.

        """

        await reader.readexactly(4)
        writer.write(self._response)
        writer.close()

    async def stop(self):
        """Shut down the mock SSH agent and remove its socket file

           This is safe to call if start() was never called or failed,
           and safe to call more than once, so it can be used
           unconditionally in test cleanup code.

        """

        if self._server is None:
            return

        server, self._server = self._server, None
        server.close()
        await server.wait_closed()

        # Since Python 3.13, closing a UNIX server already removes its
        # socket file, so it may no longer exist here
        try:
            os.remove(self._path)
        except FileNotFoundError:
            pass
89
90
class _TestAgent(AsyncTestCase):
    """Unit tests for the AsyncSSH ssh-agent client API"""

    _agent_pid = None
    _public_keys = {}

    @staticmethod
    def set_askpass(status):
        """Set return status for ssh-askpass

           This writes an executable ``ssh-askpass`` shell script into the
           current directory which exits with the given status. SSH_ASKPASS
           is pointed at this script during class setup.

        """

        with open('ssh-askpass', 'w') as f:
            f.write('#!/bin/sh\nexit %d\n' % status)

        # Mark the script executable only once it has been fully written
        os.chmod('ssh-askpass', 0o755)

    # Pylint doesn't like mixed case method names, but this was chosen to
    # match the convention used in the unittest module.

    # pylint: disable=invalid-name

    @classmethod
    async def asyncSetUpClass(cls):
        """Start an ssh-agent and set up the environment for the tests

           If ssh-agent can't be started, _agent_pid is left unset and
           setUp() skips every test in this class.

        """

        # DISPLAY and SSH_ASKPASS let the agent run the ssh-askpass script
        # written by set_askpass() for key confirmation; HOME makes default
        # keys be looked up relative to the current directory
        os.environ['DISPLAY'] = ' '
        os.environ['HOME'] = '.'
        os.environ['SSH_ASKPASS'] = os.path.join(os.getcwd(), 'ssh-askpass')

        try:
            output = run('ssh-agent -a agent 2>/dev/null')
        except subprocess.CalledProcessError: # pragma: no cover
            return

        # The third output line is expected to be "echo Agent pid NNN;"
        cls._agent_pid = int(output.splitlines()[2].split()[3][:-1])
        os.environ['SSH_AUTH_SOCK'] = 'agent'

    @classmethod
    async def asyncTearDownClass(cls):
        """Shut down agents"""

        if cls._agent_pid: # pragma: no branch
            os.kill(cls._agent_pid, signal.SIGTERM)

    def setUp(self):
        """Skip unit tests if we couldn't start an agent"""

        if not self._agent_pid: # pragma: no cover
            self.skipTest('ssh-agent not available')

    # pylint: enable=invalid-name

    @agent_test
    async def test_connection(self, agent):
        """Test opening a connection to the agent"""

        self.assertIsNotNone(agent)

    @asynctest
    async def test_connection_failed(self):
        """Test failure in opening a connection to the agent"""

        with self.assertRaises(OSError):
            await asyncssh.connect_agent('xxx')

    @asynctest
    async def test_no_auth_sock(self):
        """Test failure when no auth sock is set"""

        auth_sock = os.environ.pop('SSH_AUTH_SOCK')

        # Always restore the variable, even if the assertion fails, so
        # that later tests can still reach the agent
        try:
            with self.assertRaises(OSError):
                await asyncssh.connect_agent()
        finally:
            os.environ['SSH_AUTH_SOCK'] = auth_sock

    @agent_test
    async def test_get_keys(self, agent):
        """Test getting keys from the agent"""

        # agent_test empties the agent first and _public_keys is empty,
        # so no keys are expected here
        keys = await agent.get_keys()
        self.assertEqual(len(keys), len(self._public_keys))

    @agent_test
    async def test_sign(self, agent):
        """Test signing a block of data using the agent"""

        algs = ['ssh-dss', 'ssh-rsa', 'ecdsa-sha2-nistp256']

        if ed25519_available: # pragma: no branch
            algs.append('ssh-ed25519')

        for alg_name in algs:
            key = asyncssh.generate_private_key(alg_name)
            pubkey = key.convert_to_public()
            cert = key.generate_user_certificate(key, 'name')

            await agent.add_keys([(key, cert)])
            agent_keys = await agent.get_keys()

            for agent_key in agent_keys:
                sig = await agent_key.sign(b'test')
                self.assertTrue(pubkey.verify(b'test', sig))

            await agent.remove_keys(agent_keys)

    @agent_test
    async def test_set_certificate(self, agent):
        """Test setting certificate on an existing keypair"""

        key = asyncssh.generate_private_key('ssh-rsa')
        cert = key.generate_user_certificate(key, 'name')

        key2 = asyncssh.generate_private_key('ssh-rsa')
        cert2 = key.generate_user_certificate(key2, 'name')

        await agent.add_keys([key])
        agent_key = (await agent.get_keys())[0]

        agent_key.set_certificate(cert)
        self.assertEqual(agent_key.public_data, cert.public_data)

        with self.assertRaises(ValueError):
            asyncssh.load_keypairs([(agent_key, cert2)])

        agent_key = (await agent.get_keys())[0]
        agent_key = asyncssh.load_keypairs([(agent_key, cert)])[0]
        self.assertEqual(agent_key.public_data, cert.public_data)

        with self.assertRaises(ValueError):
            asyncssh.load_keypairs([(agent_key, cert2)])

    @agent_test
    async def test_reconnect(self, agent):
        """Test reconnecting to the agent after closing it"""

        key = asyncssh.generate_private_key('ssh-rsa')
        pubkey = key.convert_to_public()

        async with agent:
            await agent.add_keys([key])
            agent_keys = await agent.get_keys()

        for agent_key in agent_keys:
            sig = await agent_key.sign(b'test')
            self.assertTrue(pubkey.verify(b'test', sig))

    @agent_test
    async def test_add_remove_keys(self, agent):
        """Test adding and removing keys"""

        await agent.add_keys()
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 0)

        key = asyncssh.generate_private_key('ssh-rsa')
        await agent.add_keys([key])
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 1)

        await agent.remove_keys(agent_keys)
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 0)

        await agent.add_keys([key])
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 1)

        await agent_keys[0].remove()
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 0)

        await agent.add_keys([key], lifetime=1)
        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 1)
        await asyncio.sleep(2)

        agent_keys = await agent.get_keys()
        self.assertEqual(len(agent_keys), 0)

    @agent_test
    async def test_add_keys_failure(self, agent):
        """Test adding keys to an agent which returns failure responses"""

        os.mkdir('.ssh', 0o700)
        key_path = os.path.join('.ssh', 'id_rsa')

        # Clean up .ssh even if key generation or writing fails, so a
        # later run's mkdir above doesn't fail on a leftover directory
        try:
            key = asyncssh.generate_private_key('ssh-rsa')
            key.write_private_key(Path('.ssh', 'id_rsa'))

            mock_agent = _Agent(Byte(SSH_AGENT_FAILURE))
            await mock_agent.start('mock_agent')

            # This shadows the real agent connection passed in; only the
            # mock agent is exercised by this test
            try:
                async with asyncssh.connect_agent('mock_agent') as agent:
                    async with agent:
                        await agent.add_keys()

                    async with agent:
                        with self.assertRaises(ValueError):
                            await agent.add_keys([key])
            finally:
                await mock_agent.stop()
        finally:
            if os.path.exists(key_path):
                os.remove(key_path)

            os.rmdir('.ssh')

    @unittest.skipUnless(sk_available, 'security key support not available')
    @patch_sk([2])
    @asynctest
    async def test_add_sk_keys(self):
        """Test adding U2F security keys"""

        key = asyncssh.generate_private_key(
            'sk-ecdsa-sha2-nistp256@openssh.com')
        cert = key.generate_user_certificate(key, 'test')

        mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                for keypair in asyncssh.load_keypairs([key, (key, cert)]):
                    async with agent:
                        self.assertIsNone(await agent.add_keys([keypair]))

                async with agent:
                    with self.assertRaises(asyncssh.KeyExportError):
                        await agent.add_keys([key.convert_to_public()])
        finally:
            await mock_agent.stop()

    @unittest.skipUnless(sk_available, 'security key support not available')
    @patch_sk([2])
    @asynctest
    async def test_get_sk_keys(self):
        """Test getting U2F security keys"""

        key = asyncssh.generate_private_key(
            'sk-ecdsa-sha2-nistp256@openssh.com')
        cert = key.generate_user_certificate(key, 'test')

        # Identities answer listing the key and its certificate, each with
        # an empty comment
        mock_agent = _Agent(Byte(SSH_AGENT_IDENTITIES_ANSWER) + UInt32(2) +
                            String(key.public_data) + String('') +
                            String(cert.public_data) + String(''))

        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                await agent.get_keys()
        finally:
            await mock_agent.stop()

    @asynctest
    async def test_add_remove_smartcard_keys(self):
        """Test adding and removing smart card keys"""

        mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                result = await agent.add_smartcard_keys('provider')
                self.assertIsNone(result)
        finally:
            await mock_agent.stop()

        mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                result = await agent.remove_smartcard_keys('provider')
                self.assertIsNone(result)
        finally:
            await mock_agent.stop()

    @agent_test
    async def test_confirm(self, agent):
        """Test confirmation of key"""

        key = asyncssh.generate_private_key('ssh-rsa')
        pubkey = key.convert_to_public()

        await agent.add_keys([key], confirm=True)
        agent_keys = await agent.get_keys()

        # A failing ssh-askpass denies confirmation, so signing fails
        self.set_askpass(1)

        for agent_key in agent_keys:
            with self.assertRaises(ValueError):
                await agent_key.sign(b'test')

        self.set_askpass(0)

        for agent_key in agent_keys:
            sig = await agent_key.sign(b'test')
            self.assertTrue(pubkey.verify(b'test', sig))

    @agent_test
    async def test_lock(self, agent):
        """Test lock and unlock"""

        key = asyncssh.generate_private_key('ssh-rsa')
        pubkey = key.convert_to_public()

        await agent.add_keys([key])
        agent_keys = await agent.get_keys()

        await agent.lock('passphrase')

        for agent_key in agent_keys:
            with self.assertRaises(ValueError):
                await agent_key.sign(b'test')

        await agent.unlock('passphrase')

        for agent_key in agent_keys:
            sig = await agent_key.sign(b'test')
            self.assertTrue(pubkey.verify(b'test', sig))

    @asynctest
    async def test_query_extensions(self):
        """Test query of supported extensions"""

        mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS) + String('xxx'))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                extensions = await agent.query_extensions()
                self.assertEqual(extensions, ['xxx'])
        finally:
            await mock_agent.stop()

        mock_agent = _Agent(Byte(SSH_AGENT_SUCCESS) + String(b'\xff'))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                with self.assertRaises(ValueError):
                    await agent.query_extensions()
        finally:
            await mock_agent.stop()

        mock_agent = _Agent(Byte(SSH_AGENT_FAILURE))
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                extensions = await agent.query_extensions()
                self.assertEqual(extensions, [])
        finally:
            await mock_agent.stop()

        mock_agent = _Agent(b'\xff')
        await mock_agent.start('mock_agent')

        try:
            async with asyncssh.connect_agent('mock_agent') as agent:
                with self.assertRaises(ValueError):
                    await agent.query_extensions()
        finally:
            await mock_agent.stop()

    @agent_test
    async def test_unknown_key(self, agent):
        """Test failure when signing with an unknown key"""

        key = asyncssh.generate_private_key('ssh-rsa')

        with self.assertRaises(ValueError):
            await agent.sign(key.public_data, b'test')

    @agent_test
    async def test_double_close(self, agent):
        """Test calling close more than once on the agent"""

        # agent_test closes the agent again when this test returns
        self.assertIsNotNone(agent)
        agent.close()

    @asynctest
    async def test_errors(self):
        """Test getting error responses from SSH agent

           Every request is expected to raise ValueError for each of the
           mock responses: no reply, an empty reply, a failure reply, and
           an unexpected 0xff message.

        """

        key = asyncssh.generate_private_key('ssh-rsa')
        keypair = asyncssh.load_keypairs(key)[0]

        for response in (None, b'', Byte(SSH_AGENT_FAILURE), b'\xff'):
            mock_agent = _Agent(response)
            await mock_agent.start('mock_agent')

            try:
                async with asyncssh.connect_agent('mock_agent') as agent:
                    # Create each request coroutine only when it is about
                    # to be awaited, so a failure part way through doesn't
                    # leave never-awaited coroutines behind
                    requests = (
                        agent.get_keys,
                        functools.partial(agent.sign, b'xxx', b'test'),
                        functools.partial(agent.add_keys, [key]),
                        functools.partial(agent.add_smartcard_keys, 'xxx'),
                        functools.partial(agent.remove_keys, [keypair]),
                        functools.partial(agent.remove_smartcard_keys,
                                          'xxx'),
                        agent.remove_all,
                        functools.partial(agent.lock, 'passphrase'),
                        functools.partial(agent.unlock, 'passphrase'))

                    for request in requests:
                        async with agent:
                            with self.assertRaises(ValueError):
                                await request()
            finally:
                await mock_agent.stop()