# Copyright (c) 2019 by Ron Frederick <ronf@timeheart.net> and others.
#
# This program and the accompanying materials are made available under
# the terms of the Eclipse Public License v2.0 which accompanies this
# distribution and is available at:
#
#     http://www.eclipse.org/legal/epl-2.0/
#
# This program may also be made available under the following secondary
# licenses when the conditions for such availability set forth in the
# Eclipse Public License v2.0 are satisfied:
#
#    GNU General Public License, Version 2.0, or any later versions of
#    that license
#
# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
#
# Contributors:
#     Ron Frederick - initial implementation, API, and documentation

"""Unit tests for AsyncSSH subprocess API"""

import asyncio
from signal import SIGINT

import asyncssh

from .server import Server, ServerTestCase
from .util import asynctest, echo


class _SubprocessProtocol(asyncssh.SSHSubprocessProtocol):
    """Unit test SSH subprocess protocol

       Records everything the remote process sends on stdout (fd 1) and
       stderr (fd 2), plus the exception reported when each pipe closes,
       so that tests can inspect them afterwards.

    """

    def __init__(self):
        self._chan = None

        # Received data chunks per fd: 1 = stdout, 2 = stderr
        self.recv_buf = {1: [], 2: []}

        # Exception last passed to pipe_connection_lost for each fd;
        # stays None until that callback runs
        self.exc = {1: None, 2: None}

    def pipe_connection_lost(self, fd, exc):
        """Handle remote process close"""

        self.exc[fd] = exc

    def pipe_data_received(self, fd, data):
        """Handle data from the remote process"""

        self.recv_buf[fd].append(data)


async def _create_subprocess(conn, command=None, **kwargs):
    """Create a client subprocess

       Returns the ``(transport, protocol)`` pair produced by
       ``conn.create_subprocess`` with :class:`_SubprocessProtocol` as
       the protocol factory. Extra keyword arguments are passed through.

    """

    return await conn.create_subprocess(_SubprocessProtocol, command, **kwargs)


class _SubprocessServer(Server):
    """Server for testing the AsyncSSH subprocess API"""

    def begin_auth(self, username):
        """Handle client authentication request"""

        return False

    def session_requested(self):
        """Handle a request to create a new session"""

        # Every session is handled by the _begin_session coroutine
        return self._begin_session

    async def _begin_session(self, stdin, stdout, stderr):
        """Begin processing a new session

           The command requested by the client selects the behavior:

             exit_status            - exit immediately with status 1
             signal                 - wait for a signal, then exit with it
             anything else or none  - echo the input back

        """

        # pylint: disable=no-self-use

        action = stdin.channel.get_command()

        if not action:
            action = 'echo'

        if action == 'exit_status':
            stdout.channel.exit(1)
        elif action == 'signal':
            try:
                await stdin.readline()
            except asyncssh.SignalReceived as exc:
                stdout.channel.exit_with_signal(exc.signal)
        else:
            await echo(stdin, stdout, stderr)


class _TestSubprocess(ServerTestCase):
    """Unit tests for AsyncSSH subprocess API"""

    @classmethod
    async def start_server(cls):
        """Start an SSH server for the tests to use"""

        return (await cls.create_server(
            _SubprocessServer, authorized_client_keys='authorized_keys'))

    async def _check_subprocess(self, conn, command=None, *,
                                encoding=None, **kwargs):
        """Start a subprocess and test if an input line is echoed back

           Any extra keyword arguments are forwarded to create_subprocess.

        """

        # Extra options must be forwarded by keyword (``**``); spreading
        # them with ``*`` would pass the dict's keys as positional args
        transport, protocol = await _create_subprocess(conn, command,
                                                       encoding=encoding,
                                                       **kwargs)

        data = str(id(self))

        if encoding is None:
            data = data.encode('ascii')

        stdin = transport.get_pipe_transport(0)

        self.assertTrue(stdin.can_write_eof())

        stdin.writelines([data])

        # Sending EOF on stdin should put the transport into closing state
        self.assertFalse(transport.is_closing())
        stdin.write_eof()
        self.assertTrue(transport.is_closing())

        await transport.wait_closed()

        sep = '' if encoding else b''

        # Expect the input echoed back on both stdout and stderr
        for buf in protocol.recv_buf.values():
            self.assertEqual(data, sep.join(buf))

        transport.close()

    @asynctest
    async def test_shell(self):
        """Test starting a shell"""

        async with self.connect() as conn:
            await self._check_subprocess(conn)

    @asynctest
    async def test_exec(self):
        """Test execution of a remote command"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo')

    @asynctest
    async def test_encoding(self):
        """Test setting encoding"""

        async with self.connect() as conn:
            await self._check_subprocess(conn, 'echo', encoding='ascii')

    @asynctest
    async def test_input(self):
        """Test providing input when creating a subprocess"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn, input=data)

            await transport.wait_closed()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), data)

    @asynctest
    async def test_redirect_stderr(self):
        """Test redirecting stderr to file"""

        data = str(id(self)).encode('ascii')

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn,
                                                           stderr='stderr')

            stdin = transport.get_pipe_transport(0)
            stdin.write(data)
            stdin.write_eof()

            await transport.wait_closed()

        # Read the redirect target only after the connection has closed
        with open('stderr', 'rb') as f:
            stderr_data = f.read()

        self.assertEqual(b''.join(protocol.recv_buf[1]), data)
        self.assertEqual(b''.join(protocol.recv_buf[2]), b'')
        self.assertEqual(stderr_data, data)

    @asynctest
    async def test_close(self):
        """Test closing transport"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            transport.close()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_exit_status(self):
        """Test reading exit status"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn, 'exit_status')

            await transport.wait_closed()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

            self.assertEqual(transport.get_returncode(), 1)

    @asynctest
    async def test_stdin_abort(self):
        """Test abort on stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.abort()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_stdin_close(self):
        """Test closing stdin"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdin.close()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

    @asynctest
    async def test_read_pause(self):
        """Test read pause"""

        async with self.connect() as conn:
            transport, protocol = await _create_subprocess(conn)

            stdin = transport.get_pipe_transport(0)
            stdout = transport.get_pipe_transport(1)

            # While stdout is paused, no echoed data should be delivered
            stdout.pause_reading()
            stdin.write(b'\n')
            await asyncio.sleep(0.1)

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'')

            stdout.resume_reading()

            for buf in protocol.recv_buf.values():
                self.assertEqual(b''.join(buf), b'\n')

            stdin.close()

    @asynctest
    async def test_signal(self):
        """Test sending a signal"""

        async with self.connect() as conn:
            transport, _ = await _create_subprocess(conn, 'signal')

            transport.send_signal(SIGINT)

            await transport.wait_closed()

            # A process ended by a signal reports the negated signal number
            self.assertEqual(transport.get_returncode(), -SIGINT)

    @asynctest
    async def test_misc(self):
        """Test other transport and pipe methods"""

        async with self.connect() as conn:
            transport, _ = await _create_subprocess(conn)

            self.assertEqual(transport.get_pid(), None)

            stdin = transport.get_pipe_transport(0)

            self.assertEqual(transport.get_extra_info('socket'),
                             stdin.get_extra_info('socket'))

            stdin.set_write_buffer_limits()
            self.assertEqual(stdin.get_write_buffer_size(), 0)

            stdin.close()