1# Copyright (c) 2019 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""Unit tests for AsyncSSH subprocess API"""
22
23import asyncio
24from signal import SIGINT
25
26import asyncssh
27
28from .server import Server, ServerTestCase
29from .util import asynctest, echo
30
class _SubprocessProtocol(asyncssh.SSHSubprocessProtocol):
    """Subprocess protocol that records pipe activity for test assertions

    Output chunks received on stdout (fd 1) and stderr (fd 2) are kept
    per file descriptor in `recv_buf`. The exception (or None) reported
    when each of those pipes closes is kept in `exc`.

    """

    def __init__(self):
        self._chan = None

        # Per-fd list of received chunks, and per-fd close exception
        self.recv_buf = {fd: [] for fd in (1, 2)}
        self.exc = dict.fromkeys((1, 2))

    def pipe_data_received(self, fd, data):
        """Append a chunk of output received on pipe `fd`"""

        self.recv_buf[fd].append(data)

    def pipe_connection_lost(self, fd, exc):
        """Remember the exception (if any) reported when pipe `fd` closes"""

        self.exc[fd] = exc
49
50
async def _create_subprocess(conn, command=None, **kwargs):
    """Open a subprocess on `conn` using the recording test protocol

    Any extra keyword arguments are forwarded unchanged to
    `conn.create_subprocess`. Its result, the (transport, protocol)
    pair, is returned as-is.

    """

    result = await conn.create_subprocess(_SubprocessProtocol, command,
                                          **kwargs)
    return result
55
56
class _SubprocessServer(Server):
    """Test SSH server whose session behavior depends on the command run"""

    def begin_auth(self, username):
        """Handle a client authentication request; always returns False"""

        return False

    def session_requested(self):
        """Handle a new session request by returning the session handler"""

        return self._begin_session

    async def _begin_session(self, stdin, stdout, stderr):
        """Run one session, choosing its behavior from the requested command

        'exit_status' exits immediately with status 1. 'signal' reads a
        line from stdin; if a signal interrupts that read, the session
        exits reporting that signal. Any other command, or no command at
        all, echoes the input back.

        """

        # pylint: disable=no-self-use

        # A missing or empty command falls back to the echo behavior
        command = stdin.channel.get_command() or 'echo'

        if command == 'exit_status':
            stdout.channel.exit(1)
            return

        if command == 'signal':
            try:
                await stdin.readline()
            except asyncssh.SignalReceived as exc:
                stdout.channel.exit_with_signal(exc.signal)
            return

        await echo(stdin, stdout, stderr)
89
class _TestSubprocess(ServerTestCase):
    """Unit tests for AsyncSSH subprocess API"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server for the tests to use"""

        return (await cls.create_server(
            _SubprocessServer, authorized_client_keys='authorized_keys'))

    async def _check_subprocess(self, conn, command=None, *,
                                encoding=None, **kwargs):
        """Start a subprocess and test if an input line is echoed back

        A test string is written to the subprocess's stdin, which is then
        closed with EOF. After the transport finishes closing, the same
        string is expected to have arrived on both stdout and stderr.

        Args:
            conn: The client connection to start the subprocess on.
            command: The remote command to run, or None to start a shell.
            encoding: Pipe encoding, or None to exchange bytes.
            **kwargs: Extra keyword arguments forwarded to
                create_subprocess.

        """

        # Extra options must be forwarded as keyword arguments. Unpacking
        # them with a single '*' would pass only the option names, as
        # stray positional arguments, and silently drop their values.
        transport, protocol = await _create_subprocess(conn, command,
                                                       encoding=encoding,
                                                       **kwargs)

        data = str(id(self))

        if encoding is None:
            data = data.encode('ascii')

        stdin = transport.get_pipe_transport(0)

        self.assertTrue(stdin.can_write_eof())

        stdin.writelines([data])

        # Sending EOF on stdin is expected to start closing the transport
        self.assertFalse(transport.is_closing())
        stdin.write_eof()
        self.assertTrue(transport.is_closing())

        await transport.wait_closed()

        sep = '' if encoding else b''

        # Both stdout and stderr should carry an exact copy of the input
        for buf in protocol.recv_buf.values():
            self.assertEqual(data, sep.join(buf))

        transport.close()

    @asynctest
    async def test_shell(self):
        """Test starting a shell"""

        async with self.connect() as conn:
            await self._check_subprocess(conn)

    @asynctest
    async def test_exec(self):
        """Test execution of a remote command"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo')

    @asynctest
    async def test_encoding(self):
        """Test setting encoding"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo', encoding='ascii')

    @asynctest
    async def test_input(self):
        """Test providing input when creating a subprocess"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn, input=data)

            await transport.wait_closed()

        for buf in protocol.recv_buf.values():
            self.assertEqual(b''.join(buf), data)

    @asynctest
    async def test_redirect_stderr(self):
        """Test redirecting stderr to file"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn,
                                                           stderr='stderr')

            stdin = transport.get_pipe_transport(0)
            stdin.write(data)
            stdin.write_eof()

            await transport.wait_closed()

        with open('stderr', 'rb') as f:
            stderr_data = f.read()

        # stderr output goes to the file instead of the protocol
        self.assertEqual(b''.join(protocol.recv_buf[1]), data)
        self.assertEqual(b''.join(protocol.recv_buf[2]), b'')
        self.assertEqual(stderr_data, data)

    @asynctest
    async def test_close(self):
        """Test closing transport"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            transport.close()

        for buf in protocol.recv_buf.values():
            self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_exit_status(self):
        """Test reading exit status"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn, 'exit_status')

            await transport.wait_closed()

        for buf in protocol.recv_buf.values():
            self.assertEqual(b''.join(buf), b'')

        self.assertEqual(transport.get_returncode(), 1)

    @asynctest
    async def test_stdin_abort(self):
        """Test abort on stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.abort()

        for buf in protocol.recv_buf.values():
            self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_stdin_close(self):
        """Test closing stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.close()

        for buf in protocol.recv_buf.values():
            self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_read_pause(self):
        """Test read pause"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdout = transport.get_pipe_transport(1)

            stdout.pause_reading()
            stdin.write(b'\n')
            await asyncio.sleep(0.1)

            # Only stdout was paused, but this test expects no data on
            # either pipe while reading is paused
            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

            stdout.resume_reading()

            # Held-back data is expected to be delivered by the time
            # resume_reading() returns, without another event-loop pass
            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'\n')

            stdin.close()

    @asynctest
    async def test_signal(self):
        """Test sending a signal"""

        async with self.connect() as conn:
            transport, _ = await _create_subprocess(conn, 'signal')

            transport.send_signal(SIGINT)

            await transport.wait_closed()

            # The server exits reporting the received signal, which is
            # expected to surface as a negative return code
            self.assertEqual(transport.get_returncode(), -SIGINT)

    @asynctest
    async def test_misc(self):
        """Test other transport and pipe methods"""

        async with self.connect() as conn:
            transport, _ = await _create_subprocess(conn)

            self.assertEqual(transport.get_pid(), None)

            stdin = transport.get_pipe_transport(0)

            self.assertEqual(transport.get_extra_info('socket'),
                             stdin.get_extra_info('socket'))

            stdin.set_write_buffer_limits()
            self.assertEqual(stdin.get_write_buffer_size(), 0)

            stdin.close()