1# Copyright (c) 2016-2020 by Ron Frederick <ronf@timeheart.net> and others.
2#
3# This program and the accompanying materials are made available under
4# the terms of the Eclipse Public License v2.0 which accompanies this
5# distribution and is available at:
6#
7#     http://www.eclipse.org/legal/epl-2.0/
8#
9# This program may also be made available under the following secondary
10# licenses when the conditions for such availability set forth in the
11# Eclipse Public License v2.0 are satisfied:
12#
13#    GNU General Public License, Version 2.0, or any later versions of
14#    that license
15#
16# SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-or-later
17#
18# Contributors:
19#     Ron Frederick - initial implementation, API, and documentation
20
21"""Unit tests for AsyncSSH process API"""
22
23import asyncio
24import io
25import os
26from pathlib import Path
27from signal import SIGINT
28import socket
29import sys
30import unittest
31
32import asyncssh
33
34from .server import ServerTestCase
35from .util import asynctest, echo
36
37try:
38    import aiofiles
39    _aiofiles_available = True
40except ImportError: # pragma: no cover
41    _aiofiles_available = False
42
async def _handle_client(process):
    """Serve a single test session on the server side

       The action to perform is named by the session's command, or by
       its subsystem when no command was given, and defaults to 'echo'
       when neither is set. An action which isn't recognized makes the
       process exit with status 255. Whatever action ran, the process
       is closed afterwards and this coroutine waits for that close to
       complete.

    """

    request = process.command or process.subsystem or 'echo'

    if request in ('echo', 'delay'):
        # 'delay' is an echo which starts one second late
        if request == 'delay':
            await asyncio.sleep(1)

        await echo(process.stdin, process.stdout, process.stderr)
    elif request == 'break':
        # Report the length of a received break through the exit signal
        try:
            await process.stdin.readline()
        except asyncssh.BreakReceived as break_exc:
            process.exit_with_signal('ABRT', False, str(break_exc.msec))
    elif request == 'term_size':
        # Report a changed terminal size through the exit signal
        try:
            await process.stdin.readline()
        except asyncssh.TerminalSizeChanged as size_exc:
            dimensions = str(size_exc.width) + 'x' + str(size_exc.height)
            process.exit_with_signal('ABRT', False, dimensions)
    elif request == 'redirect_stdin':
        await process.redirect_stdin(process.stdout)
        await process.stdout.drain()
    elif request == 'redirect_stdout':
        await process.redirect_stdout(process.stdin)
        await process.stdout.drain()
    elif request == 'redirect_stderr':
        await process.redirect_stderr(process.stdin)
        await process.stderr.drain()
    elif request == 'exit_status':
        process.channel.set_encoding('utf-8')
        process.stderr.write('Exiting with status 1')
        process.exit(1)
    elif request == 'env':
        process.channel.set_encoding('utf-8')
        process.stdout.write(process.env.get('TEST', ''))
    elif request == 'old_term':
        # Gather terminal details through the getter methods
        term_type = process.get_terminal_type()
        term_size = process.get_terminal_size()
        ospeed = process.get_terminal_mode(asyncssh.PTY_OP_OSPEED)
        report = str((term_type, term_size, ospeed))

        process.channel.set_encoding('utf-8')
        process.stdout.write(report)
    elif request == 'term':
        # Gather terminal details through the process attributes
        report = str((process.term_type, process.term_size,
                      process.term_modes.get(asyncssh.PTY_OP_OSPEED),
                      sorted(process.term_modes.items())))

        process.channel.set_encoding('utf-8')
        process.stdout.write(report)
    elif request == 'timeout':
        process.channel.set_encoding('utf-8')
        process.stdout.write('Sleeping')
        await asyncio.sleep(1)
    else:
        process.exit(255)

    process.close()
    await process.wait_closed()
102
103
class _TestProcess(ServerTestCase):
    """Shared base for the AsyncSSH process API test cases"""

    @classmethod
    async def start_server(cls):
        """Create the SSH server used by the tests in this class

           New sessions are handed to _handle_client, and the server
           is created with encoding=None.

        """

        server_options = {'process_factory': _handle_client,
                          'encoding': None}

        return await cls.create_server(**server_options)
113
114
115class _TestProcessBasic(_TestProcess):
116    """Unit tests for AsyncSSH process basic functions"""
117
118    @asynctest
119    async def test_shell(self):
120        """Test starting a remote shell"""
121
122        data = str(id(self))
123
124        async with self.connect() as conn:
125            process = await conn.create_process(env={'TEST': 'test'})
126
127            process.stdin.write(data)
128
129            self.assertFalse(process.is_closing())
130            process.stdin.write_eof()
131            self.assertTrue(process.is_closing())
132
133            result = await process.wait()
134
135        self.assertEqual(result.env, {'TEST': 'test'})
136        self.assertEqual(result.command, None)
137        self.assertEqual(result.subsystem, None)
138        self.assertEqual(result.exit_status, None)
139        self.assertEqual(result.exit_signal, None)
140        self.assertEqual(result.returncode, None)
141        self.assertEqual(result.stdout, data)
142        self.assertEqual(result.stderr, data)
143
144    @asynctest
145    async def test_command(self):
146        """Test executing a remote command"""
147
148        data = str(id(self))
149
150        async with self.connect() as conn:
151            process = await conn.create_process('echo')
152
153            process.stdin.write(data)
154            process.stdin.write_eof()
155
156            result = await process.wait()
157
158        self.assertEqual(result.command, 'echo')
159        self.assertEqual(result.subsystem, None)
160        self.assertEqual(result.stdout, data)
161        self.assertEqual(result.stderr, data)
162
163    @asynctest
164    async def test_subsystem(self):
165        """Test starting a remote subsystem"""
166
167        data = str(id(self))
168
169        async with self.connect() as conn:
170            process = await conn.create_process(subsystem='echo')
171
172            process.stdin.write(data)
173            process.stdin.write_eof()
174
175            result = await process.wait()
176
177        self.assertEqual(result.command, None)
178        self.assertEqual(result.subsystem, 'echo')
179        self.assertEqual(result.stdout, data)
180        self.assertEqual(result.stderr, data)
181
182    @asynctest
183    async def test_communicate(self):
184        """Test communicate"""
185
186        data = str(id(self))
187
188        async with self.connect() as conn:
189            async with conn.create_process() as process:
190                stdout_data, stderr_data = await process.communicate(data)
191
192        self.assertEqual(stdout_data, data)
193        self.assertEqual(stderr_data, data)
194
195    @asynctest
196    async def test_communicate_paused(self):
197        """Test communicate when reading is already paused"""
198
199        data = 4*1024*1024*'*'
200
201        async with self.connect() as conn:
202            async with conn.create_process(input=data) as process:
203                await asyncio.sleep(1)
204                stdout_data, stderr_data = await process.communicate()
205
206        self.assertEqual(stdout_data, data)
207        self.assertEqual(stderr_data, data)
208
209    @asynctest
210    async def test_env(self):
211        """Test sending environment"""
212
213        async with self.connect() as conn:
214            process = await conn.create_process('env', env={'TEST': 'test'})
215            result = await process.wait()
216
217        self.assertEqual(result.stdout, 'test')
218
219    @asynctest
220    async def test_old_terminal_info(self):
221        """Test setting and retrieving terminal information with old API"""
222
223        modes = {asyncssh.PTY_OP_OSPEED: 9600}
224
225        async with self.connect() as conn:
226            process = await conn.create_process('old_term', term_type='ansi',
227                                                term_size=(80, 24),
228                                                term_modes=modes)
229            result = await process.wait()
230
231        self.assertEqual(result.stdout, "('ansi', (80, 24, 0, 0), 9600)")
232
233    @asynctest
234    async def test_terminal_info(self):
235        """Test setting and retrieving terminal information"""
236
237        modes = {asyncssh.PTY_OP_ISPEED: 9600, asyncssh.PTY_OP_OSPEED: 9600}
238
239        async with self.connect() as conn:
240            process = await conn.create_process('term', term_type='ansi',
241                                                term_size=(80, 24),
242                                                term_modes=modes)
243            result = await process.wait()
244
245        self.assertEqual(result.stdout, "('ansi', (80, 24, 0, 0), 9600, "
246                                        "[(128, 9600), (129, 9600)])")
247
248    @asynctest
249    async def test_change_terminal_size(self):
250        """Test changing terminal size"""
251
252        async with self.connect() as conn:
253            process = await conn.create_process('term_size', term_type='ansi')
254            process.change_terminal_size(80, 24)
255            result = await process.wait()
256
257        self.assertEqual(result.exit_signal[2], '80x24')
258
259    @asynctest
260    async def test_break(self):
261        """Test sending a break"""
262
263        async with self.connect() as conn:
264            process = await conn.create_process('break')
265            process.send_break(1000)
266            result = await process.wait()
267
268        self.assertEqual(result.exit_signal[2], '1000')
269
270    @asynctest
271    async def test_signal(self):
272        """Test sending a signal"""
273
274        async with self.connect() as conn:
275            process = await conn.create_process()
276            process.send_signal('INT')
277            result = await process.wait()
278
279        self.assertEqual(result.exit_signal[0], 'INT')
280        self.assertEqual(result.returncode, -SIGINT)
281
282    @asynctest
283    async def test_numeric_signal(self):
284        """Test sending a signal using a numeric value"""
285
286        async with self.connect() as conn:
287            process = await conn.create_process()
288            process.send_signal(SIGINT)
289            result = await process.wait()
290
291        self.assertEqual(result.exit_signal[0], 'INT')
292        self.assertEqual(result.returncode, -SIGINT)
293
294    @asynctest
295    async def test_terminate(self):
296        """Test sending a terminate signal"""
297
298        async with self.connect() as conn:
299            process = await conn.create_process()
300            process.terminate()
301            result = await process.wait()
302
303        self.assertEqual(result.exit_signal[0], 'TERM')
304
305    @asynctest
306    async def test_kill(self):
307        """Test sending a kill signal"""
308
309        async with self.connect() as conn:
310            process = await conn.create_process()
311            process.kill()
312            result = await process.wait()
313
314        self.assertEqual(result.exit_signal[0], 'KILL')
315
316    @asynctest
317    async def test_exit_status(self):
318        """Test checking exit status"""
319
320        async with self.connect() as conn:
321            result = await conn.run('exit_status')
322
323        self.assertEqual(result.exit_status, 1)
324        self.assertEqual(result.returncode, 1)
325        self.assertEqual(result.stdout, '')
326        self.assertEqual(result.stderr, 'Exiting with status 1')
327
328    @asynctest
329    async def test_raise_on_exit_status(self):
330        """Test raising an exception on non-zero exit status"""
331
332        async with self.connect() as conn:
333            with self.assertRaises(asyncssh.ProcessError) as exc:
334                await conn.run('exit_status', env={'TEST': 'test'}, check=True)
335
336        self.assertEqual(exc.exception.env, {'TEST': 'test'})
337        self.assertEqual(exc.exception.command, 'exit_status')
338        self.assertEqual(exc.exception.subsystem, None)
339        self.assertEqual(exc.exception.exit_status, 1)
340        self.assertEqual(exc.exception.reason,
341                         'Process exited with non-zero exit status 1')
342        self.assertEqual(exc.exception.returncode, 1)
343
344    @asynctest
345    async def test_raise_on_timeout(self):
346        """Test raising an exception on timeout"""
347
348        async with self.connect() as conn:
349            with self.assertRaises(asyncssh.ProcessError) as exc:
350                await conn.run('timeout', timeout=0.1)
351
352        self.assertEqual(exc.exception.command, 'timeout')
353        self.assertEqual(exc.exception.reason, '')
354        self.assertEqual(exc.exception.stdout, 'Sleeping')
355
356    @asynctest
357    async def test_exit_signal(self):
358        """Test checking exit signal"""
359
360        async with self.connect() as conn:
361            process = await conn.create_process()
362            process.send_signal('INT')
363            result = await process.wait()
364
365        self.assertEqual(result.exit_status, -1)
366        self.assertEqual(result.exit_signal[0], 'INT')
367        self.assertEqual(result.returncode, -SIGINT)
368
369    @asynctest
370    async def test_raise_on_exit_signal(self):
371        """Test raising an exception on exit signal"""
372
373        async with self.connect() as conn:
374            process = await conn.create_process()
375
376            with self.assertRaises(asyncssh.ProcessError) as exc:
377                process.send_signal('INT')
378                await process.wait(check=True)
379
380        self.assertEqual(exc.exception.exit_status, -1)
381        self.assertEqual(exc.exception.exit_signal[0], 'INT')
382        self.assertEqual(exc.exception.reason,
383                         'Process exited with signal INT')
384        self.assertEqual(exc.exception.returncode, -SIGINT)
385
386    @asynctest
387    async def test_split_unicode(self):
388        """Test Unicode split across blocks"""
389
390        data = '\u2000test\u2000'
391
392        with open('stdin', 'w', encoding='utf-8') as file:
393            file.write(data)
394
395        async with self.connect() as conn:
396            result = await conn.run('echo', stdin='stdin', bufsize=2)
397
398        self.assertEqual(result.stdout, data)
399
400    @asynctest
401    async def test_invalid_unicode(self):
402        """Test invalid Unicode data"""
403
404        data = b'\xfftest'
405
406        with open('stdin', 'wb') as file:
407            file.write(data)
408
409        async with self.connect() as conn:
410            with self.assertRaises(asyncssh.ProtocolError):
411                await conn.run('echo', stdin='stdin')
412
413    @asynctest
414    async def test_ignoring_invalid_unicode(self):
415        """Test ignoring invalid Unicode data"""
416
417        data = b'\xfftest'
418
419        with open('stdin', 'wb') as file:
420            file.write(data)
421
422        async with self.connect() as conn:
423            await conn.run('echo', stdin='stdin',
424                           encoding='utf-8', errors='ignore')
425
426    @asynctest
427    async def test_incomplete_unicode(self):
428        """Test incomplete Unicode data"""
429
430        data = '\u2000'.encode('utf-8')[:2]
431
432        with open('stdin', 'wb') as file:
433            file.write(data)
434
435        async with self.connect() as conn:
436            with self.assertRaises(asyncssh.ProtocolError):
437                await conn.run('echo', stdin='stdin')
438
439    @asynctest
440    async def test_disconnect(self):
441        """Test collecting output from a disconnected channel"""
442
443        data = str(id(self))
444
445        async with self.connect() as conn:
446            process = await conn.create_process()
447
448            process.stdin.write(data)
449            process.send_signal('ABRT')
450
451            result = await process.wait()
452
453        self.assertEqual(result.stdout, data)
454        self.assertEqual(result.stderr, data)
455
456    @asynctest
457    async def test_get_extra_info(self):
458        """Test get_extra_info on streams"""
459
460        async with self.connect() as conn:
461            process = await conn.create_process()
462            self.assertEqual(process.get_extra_info('connection'), conn)
463            process.stdin.write_eof()
464
465            await process.wait()
466
467    @asynctest
468    async def test_unknown_action(self):
469        """Test unknown action"""
470
471        async with self.connect() as conn:
472            result = await conn.run('unknown')
473
474        self.assertEqual(result.exit_status, 255)
475
476
477class _TestProcessRedirection(_TestProcess):
478    """Unit tests for AsyncSSH process I/O redirection"""
479
480    @asynctest
481    async def test_input(self):
482        """Test with input from a string"""
483
484        data = str(id(self))
485
486        async with self.connect() as conn:
487            result = await conn.run('echo', input=data)
488
489        self.assertEqual(result.stdout, data)
490        self.assertEqual(result.stderr, data)
491
492    @asynctest
493    async def test_stdin_devnull(self):
494        """Test with stdin redirected to DEVNULL"""
495
496        async with self.connect() as conn:
497            result = await conn.run('echo', stdin=asyncssh.DEVNULL)
498
499        self.assertEqual(result.stdout, '')
500        self.assertEqual(result.stderr, '')
501
502    @asynctest
503    async def test_stdin_file(self):
504        """Test with stdin redirected to a file"""
505
506        data = str(id(self))
507
508        with open('stdin', 'w') as file:
509            file.write(data)
510
511        async with self.connect() as conn:
512            result = await conn.run('echo', stdin='stdin')
513
514        self.assertEqual(result.stdout, data)
515        self.assertEqual(result.stderr, data)
516
517    @asynctest
518    async def test_stdin_binary_file(self):
519        """Test with stdin redirected to a file in binary mode"""
520
521        data = str(id(self)).encode() + b'\xff'
522
523        with open('stdin', 'wb') as file:
524            file.write(data)
525
526        async with self.connect() as conn:
527            result = await conn.run('echo', stdin='stdin', encoding=None)
528
529        self.assertEqual(result.stdout, data)
530        self.assertEqual(result.stderr, data)
531
532    @asynctest
533    async def test_stdin_pathlib(self):
534        """Test with stdin redirected to a file name specified by pathlib"""
535
536        data = str(id(self))
537
538        with open('stdin', 'w') as file:
539            file.write(data)
540
541        async with self.connect() as conn:
542            result = await conn.run('echo', stdin=Path('stdin'))
543
544        self.assertEqual(result.stdout, data)
545        self.assertEqual(result.stderr, data)
546
547    @asynctest
548    async def test_stdin_open_file(self):
549        """Test with stdin redirected to an open file"""
550
551        data = str(id(self))
552
553        with open('stdin', 'w') as file:
554            file.write(data)
555
556        file = open('stdin', 'r')
557
558        async with self.connect() as conn:
559            result = await conn.run('echo', stdin=file)
560
561        self.assertEqual(result.stdout, data)
562        self.assertEqual(result.stderr, data)
563
564    @asynctest
565    async def test_stdin_open_binary_file(self):
566        """Test with stdin redirected to an open file in binary mode"""
567
568        data = str(id(self)).encode() + b'\xff'
569
570        with open('stdin', 'wb') as file:
571            file.write(data)
572
573        file = open('stdin', 'rb')
574
575        async with self.connect() as conn:
576            result = await conn.run('echo', stdin=file, encoding=None)
577
578        self.assertEqual(result.stdout, data)
579        self.assertEqual(result.stderr, data)
580
581    @asynctest
582    async def test_stdin_stringio(self):
583        """Test with stdin redirected to a StringIO object"""
584
585        data = str(id(self))
586
587        with open('stdin', 'w') as file:
588            file.write(data)
589
590        file = io.StringIO(data)
591
592        async with self.connect() as conn:
593            result = await conn.run('echo', stdin=file)
594
595        self.assertEqual(result.stdout, data)
596        self.assertEqual(result.stderr, data)
597
598    @asynctest
599    async def test_stdin_bytesio(self):
600        """Test with stdin redirected to a BytesIO object"""
601
602        data = str(id(self))
603
604        with open('stdin', 'w') as file:
605            file.write(data)
606
607        file = io.BytesIO(data.encode('ascii'))
608
609        async with self.connect() as conn:
610            result = await conn.run('echo', stdin=file)
611
612        self.assertEqual(result.stdout, data)
613        self.assertEqual(result.stderr, data)
614
615    @asynctest
616    async def test_stdin_process(self):
617        """Test with stdin redirected to another SSH process"""
618
619        data = str(id(self))
620
621        async with self.connect() as conn:
622            proc1 = await conn.create_process(input=data)
623            proc2 = await conn.create_process(stdin=proc1.stdout)
624            result = await proc2.wait()
625
626        self.assertEqual(result.stdout, data)
627        self.assertEqual(result.stderr, data)
628
629    @unittest.skipIf(sys.platform == 'win32',
630                     'skip asyncio.subprocess tests on Windows')
631    @asynctest
632    async def test_stdin_stream(self):
633        """Test with stdin redirected to an asyncio stream"""
634
635        data = 4*1024*1024*'*'
636
637        async with self.connect() as conn:
638            proc1 = await asyncio.create_subprocess_shell(
639                'cat', stdin=asyncio.subprocess.PIPE,
640                stdout=asyncio.subprocess.PIPE)
641            proc1.stdin.write(data.encode('ascii'))
642            proc1.stdin.write_eof()
643
644            proc2 = await conn.create_process('delay', stdin=proc1.stdout)
645            result = await proc2.wait()
646
647        self.assertEqual(result.stdout, data)
648        self.assertEqual(result.stderr, data)
649
650    @asynctest
651    async def test_stdout_devnull(self):
652        """Test with stdout redirected to DEVNULL"""
653
654        data = str(id(self))
655
656        async with self.connect() as conn:
657            result = await conn.run('echo', input=data,
658                                    stdout=asyncssh.DEVNULL)
659
660        self.assertEqual(result.stdout, '')
661        self.assertEqual(result.stderr, data)
662
663    @asynctest
664    async def test_stdout_file(self):
665        """Test with stdout redirected to a file"""
666
667        data = str(id(self))
668
669        async with self.connect() as conn:
670            result = await conn.run('echo', input=data, stdout='stdout')
671
672        with open('stdout', 'r') as file:
673            stdout_data = file.read()
674
675        self.assertEqual(stdout_data, data)
676        self.assertEqual(result.stdout, '')
677        self.assertEqual(result.stderr, data)
678
679    @asynctest
680    async def test_stdout_binary_file(self):
681        """Test with stdout redirected to a file in binary mode"""
682
683        data = str(id(self)).encode() + b'\xff'
684
685        async with self.connect() as conn:
686            result = await conn.run('echo', input=data, stdout='stdout',
687                                    encoding=None)
688
689        with open('stdout', 'rb') as file:
690            stdout_data = file.read()
691
692        self.assertEqual(stdout_data, data)
693        self.assertEqual(result.stdout, b'')
694        self.assertEqual(result.stderr, data)
695
696    @asynctest
697    async def test_stdout_pathlib(self):
698        """Test with stdout redirected to a file name specified by pathlib"""
699
700        data = str(id(self))
701
702        async with self.connect() as conn:
703            result = await conn.run('echo', input=data, stdout=Path('stdout'))
704
705        with open('stdout', 'r') as file:
706            stdout_data = file.read()
707
708        self.assertEqual(stdout_data, data)
709        self.assertEqual(result.stdout, '')
710        self.assertEqual(result.stderr, data)
711
712    @asynctest
713    async def test_stdout_open_file(self):
714        """Test with stdout redirected to an open file"""
715
716        data = str(id(self))
717
718        file = open('stdout', 'w')
719
720        async with self.connect() as conn:
721            result = await conn.run('echo', input=data, stdout=file)
722
723        with open('stdout', 'r') as file:
724            stdout_data = file.read()
725
726        self.assertEqual(stdout_data, data)
727        self.assertEqual(result.stdout, '')
728        self.assertEqual(result.stderr, data)
729
730    @asynctest
731    async def test_stdout_open_binary_file(self):
732        """Test with stdout redirected to an open binary file"""
733
734        data = str(id(self)).encode() + b'\xff'
735
736        file = open('stdout', 'wb')
737
738        async with self.connect() as conn:
739            result = await conn.run('echo', input=data, stdout=file,
740                                    encoding=None)
741
742        with open('stdout', 'rb') as file:
743            stdout_data = file.read()
744
745        self.assertEqual(stdout_data, data)
746        self.assertEqual(result.stdout, b'')
747        self.assertEqual(result.stderr, data)
748
749    @asynctest
750    async def test_stdout_stringio(self):
751        """Test with stdout redirected to a StringIO"""
752
753        class _StringIOTest(io.StringIO):
754            """Test class for StringIO which preserves output after close"""
755
756            def __init__(self):
757                super().__init__()
758                self.output = None
759
760            def close(self):
761                if self.output is None:
762                    self.output = self.getvalue()
763                    super().close()
764
765        data = str(id(self))
766
767        file = _StringIOTest()
768
769        async with self.connect() as conn:
770            result = await conn.run('echo', input=data, stdout=file)
771
772        self.assertEqual(file.output, data)
773        self.assertEqual(result.stdout, '')
774        self.assertEqual(result.stderr, data)
775
776    @asynctest
777    async def test_stdout_bytesio(self):
778        """Test with stdout redirected to a BytesIO"""
779
780        class _BytesIOTest(io.BytesIO):
781            """Test class for BytesIO which preserves output after close"""
782
783            def __init__(self):
784                super().__init__()
785                self.output = None
786
787            def close(self):
788                if self.output is None:
789                    self.output = self.getvalue()
790                    super().close()
791
792        data = str(id(self))
793
794        file = _BytesIOTest()
795
796        async with self.connect() as conn:
797            result = await conn.run('echo', input=data, stdout=file)
798
799        self.assertEqual(file.output, data.encode('ascii'))
800        self.assertEqual(result.stdout, '')
801        self.assertEqual(result.stderr, data)
802
803    @asynctest
804    async def test_stdout_process(self):
805        """Test with stdout redirected to another SSH process"""
806
807        data = str(id(self))
808
809        async with self.connect() as conn:
810            async with conn.create_process() as proc2:
811                proc1 = await conn.create_process(stdout=proc2.stdin)
812
813                proc1.stdin.write(data)
814                proc1.stdin.write_eof()
815
816                result = await proc2.wait()
817
818        self.assertEqual(result.stdout, data)
819        self.assertEqual(result.stderr, data)
820
821    @unittest.skipIf(sys.platform == 'win32',
822                     'skip asyncio.subprocess tests on Windows')
823    @asynctest
824    async def test_stdout_stream(self):
825        """Test with stdout redirected to an asyncio stream"""
826
827        data = str(id(self))
828
829        async with self.connect() as conn:
830            proc2 = await asyncio.create_subprocess_shell(
831                'cat', stdin=asyncio.subprocess.PIPE,
832                stdout=asyncio.subprocess.PIPE)
833
834            proc1 = await conn.create_process(stdout=proc2.stdin,
835                                              stderr=asyncssh.DEVNULL)
836
837            proc1.stdin.write(data)
838            proc1.stdin.write_eof()
839
840            stdout_data, _ = await proc2.communicate()
841
842        self.assertEqual(stdout_data, data.encode('ascii'))
843
844    @asynctest
845    async def test_change_stdout(self):
846        """Test changing stdout of an open process"""
847
848        async with self.connect() as conn:
849            process = await conn.create_process(stdout='stdout')
850
851            process.stdin.write('xxx')
852
853            await asyncio.sleep(0.1)
854
855            await process.redirect_stdout(asyncssh.PIPE)
856            process.stdin.write('yyy')
857            process.stdin.write_eof()
858
859            result = await process.wait()
860
861        with open('stdout', 'r') as file:
862            stdout_data = file.read()
863
864        self.assertEqual(stdout_data, 'xxx')
865        self.assertEqual(result.stdout, 'yyy')
866        self.assertEqual(result.stderr, 'xxxyyy')
867
868    @asynctest
869    async def test_change_stdin_process(self):
870        """Test changing stdin of an open process reading from another"""
871
872        data = str(id(self))
873
874        async with self.connect() as conn:
875            async with conn.create_process() as proc2:
876                proc1 = await conn.create_process(stdout=proc2.stdin)
877
878                proc1.stdin.write(data)
879                await asyncio.sleep(0.1)
880
881                await proc2.redirect_stdin(asyncssh.PIPE)
882                proc2.stdin.write(data)
883                await asyncio.sleep(0.1)
884
885                await proc2.redirect_stdin(proc1.stdout)
886                proc1.stdin.write_eof()
887
888                result = await proc2.wait()
889
890        self.assertEqual(result.stdout, data+data)
891        self.assertEqual(result.stderr, data+data)
892
893    @asynctest
894    async def test_change_stdout_process(self):
895        """Test changing stdout of an open process sending to another"""
896
897        data = str(id(self))
898
899        async with self.connect() as conn:
900            async with conn.create_process() as proc2:
901                proc1 = await conn.create_process(stdout=proc2.stdin)
902
903                proc1.stdin.write(data)
904                await asyncio.sleep(0.1)
905
906                await proc1.redirect_stdout(asyncssh.DEVNULL)
907                proc1.stdin.write(data)
908                await asyncio.sleep(0.1)
909
910                await proc1.redirect_stdout(proc2.stdin)
911                proc1.stdin.write_eof()
912
913                result = await proc2.wait()
914
915        self.assertEqual(result.stdout, data)
916        self.assertEqual(result.stderr, data)
917
918    @asynctest
919    async def test_stderr_stdout(self):
920        """Test with stderr redirected to stdout"""
921
922        data = str(id(self))
923
924        async with self.connect() as conn:
925            result = await conn.run('echo', input=data,
926                                    stderr=asyncssh.STDOUT)
927
928        self.assertEqual(result.stdout, data+data)
929
930    @asynctest
931    async def test_server_redirect_stdin(self):
932        """Test redirect on server of stdin"""
933
934        data = str(id(self))
935
936        async with self.connect() as conn:
937            result = await conn.run('redirect_stdin', input=data)
938
939        self.assertEqual(result.stdout, data)
940        self.assertEqual(result.stderr, '')
941
942    @asynctest
943    async def test_server_redirect_stdout(self):
944        """Test redirect on server of stdout"""
945
946        data = str(id(self))
947
948        async with self.connect() as conn:
949            result = await conn.run('redirect_stdout', input=data)
950
951        self.assertEqual(result.stdout, data)
952        self.assertEqual(result.stderr, '')
953
954    @asynctest
955    async def test_server_redirect_stderr(self):
956        """Test redirect on server of stderr"""
957
958        data = str(id(self))
959
960        async with self.connect() as conn:
961            result = await conn.run('redirect_stderr', input=data)
962
963        self.assertEqual(result.stdout, '')
964        self.assertEqual(result.stderr, data)
965
966    @asynctest
967    async def test_pause_file_reader(self):
968        """Test pausing and resuming reading from a file"""
969
970        data = 4*1024*1024*'*'
971
972        with open('stdin', 'w') as file:
973            file.write(data)
974
975        async with self.connect() as conn:
976            result = await conn.run('echo', stdin='stdin',
977                                    stderr=asyncssh.DEVNULL)
978
979        self.assertEqual(result.stdout, data)
980
981    @asynctest
982    async def test_pause_process_reader(self):
983        """Test pausing and resuming reading from another SSH process"""
984
985        data = 4*1024*1024*'*'
986
987        async with self.connect() as conn:
988            proc1 = await conn.create_process(input=data)
989
990            proc2 = await conn.create_process('delay', stdin=proc1.stdout,
991                                              stderr=asyncssh.DEVNULL)
992            proc3 = await conn.create_process('delay', stdin=proc1.stderr,
993                                              stderr=asyncssh.DEVNULL)
994
995            result2, result3 = await asyncio.gather(proc2.wait(), proc3.wait())
996
997        self.assertEqual(result2.stdout, data)
998        self.assertEqual(result3.stdout, data)
999
1000    @asynctest
1001    async def test_redirect_stdin_when_paused(self):
1002        """Test redirecting stdin when write is paused"""
1003
1004        data = 4*1024*1024*'*'
1005
1006        with open('stdin', 'w') as file:
1007            file.write(data)
1008
1009        async with self.connect() as conn:
1010            process = await conn.create_process()
1011
1012            process.stdin.write(data)
1013
1014            await process.redirect_stdin('stdin')
1015
1016            result = await process.wait()
1017
1018        self.assertEqual(result.stdout, data+data)
1019        self.assertEqual(result.stderr, data+data)
1020
1021    @asynctest
1022    async def test_redirect_process_when_paused(self):
1023        """Test redirecting away from a process when write is paused"""
1024
1025        data = 4*1024*1024*'*'
1026
1027        async with self.connect() as conn:
1028            proc1 = await conn.create_process(input=data)
1029            proc2 = await conn.create_process('delay', stdin=proc1.stdout)
1030            proc3 = await conn.create_process('delay', stdin=proc1.stderr)
1031
1032            await proc1.redirect_stderr(asyncssh.DEVNULL)
1033
1034            result = await proc2.wait()
1035            proc3.close()
1036
1037        self.assertEqual(result.stdout, data)
1038        self.assertEqual(result.stderr, data)
1039
1040    @asynctest
1041    async def test_consecutive_redirect(self):
1042        """Test consecutive redirects using drain"""
1043
1044        data = 4*1024*1024*'*'
1045
1046        with open('stdin', 'w') as file:
1047            file.write(data)
1048
1049        async with self.connect() as conn:
1050            process = await conn.create_process()
1051
1052            await process.redirect_stdin('stdin', send_eof=False)
1053            await process.stdin.drain()
1054
1055            await process.redirect_stdin('stdin')
1056
1057            result = await process.wait()
1058
1059        self.assertEqual(result.stdout, data+data)
1060        self.assertEqual(result.stderr, data+data)
1061
1062
@unittest.skipUnless(_aiofiles_available, 'Async file I/O not available')
class _TestAsyncFileRedirection(_TestProcess):
    """Unit tests for AsyncSSH async file redirection

       Every test here hands an aiofiles file object to the process
       redirection arguments. The class is skipped when the optional
       aiofiles package could not be imported.

    """

    @asynctest
    async def test_stdin_aiofile(self):
        """Test with stdin redirected to an aiofile"""

        data = str(id(self))

        # Prepare the input with regular blocking I/O
        with open('stdin', 'w') as file:
            file.write(data)

        # Hand an aiofiles handle, rather than a path, to the redirect
        file = await aiofiles.open('stdin', 'r')

        async with self.connect() as conn:
            result = await conn.run('echo', stdin=file)

        self.assertEqual(result.stdout, data)
        self.assertEqual(result.stderr, data)

    @asynctest
    async def test_stdin_binary_aiofile(self):
        """Test with stdin redirected to an aiofile in binary mode"""

        # The trailing 0xff byte is not valid UTF-8 on its own
        data = str(id(self)).encode() + b'\xff'

        with open('stdin', 'wb') as file:
            file.write(data)

        file = await aiofiles.open('stdin', 'rb')

        async with self.connect() as conn:
            result = await conn.run('echo', stdin=file, encoding=None)

        self.assertEqual(result.stdout, data)
        self.assertEqual(result.stderr, data)

    @asynctest
    async def test_stdout_aiofile(self):
        """Test with stdout redirected to an aiofile"""

        data = str(id(self))

        # Open the target through aiofiles so that the async file writer
        # is what gets exercised, matching the binary variant of this
        # test. A builtin open() here would only repeat the regular file
        # redirection test.
        file = await aiofiles.open('stdout', 'w')

        async with self.connect() as conn:
            result = await conn.run('echo', input=data, stdout=file)

        # Read the output back with regular blocking I/O
        with open('stdout', 'r') as file:
            stdout_data = file.read()

        self.assertEqual(stdout_data, data)
        self.assertEqual(result.stdout, '')
        self.assertEqual(result.stderr, data)

    @asynctest
    async def test_stdout_binary_aiofile(self):
        """Test with stdout redirected to an aiofile in binary mode"""

        data = str(id(self)).encode() + b'\xff'

        file = await aiofiles.open('stdout', 'wb')

        async with self.connect() as conn:
            result = await conn.run('echo', input=data, stdout=file,
                                    encoding=None)

        with open('stdout', 'rb') as file:
            stdout_data = file.read()

        self.assertEqual(stdout_data, data)
        self.assertEqual(result.stdout, b'')
        self.assertEqual(result.stderr, data)

    @asynctest
    async def test_pause_async_file_reader(self):
        """Test pausing and resuming reading from an aiofile"""

        # 4 MiB of input, read while the 'delay' action is still sleeping
        data = 4*1024*1024*'*'

        with open('stdin', 'w') as file:
            file.write(data)

        file = await aiofiles.open('stdin', 'r')

        async with self.connect() as conn:
            result = await conn.run('delay', stdin=file,
                                    stderr=asyncssh.DEVNULL)

        # stderr is sent to DEVNULL, so only the stdout copy is checked
        self.assertEqual(result.stdout, data)
1154
1155
@unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows')
class _TestProcessPipes(_TestProcess):
    """Unit tests for AsyncSSH process redirection to and from OS pipes"""

    @asynctest
    async def test_stdin_pipe(self):
        """Test feeding stdin from the read end of a pipe"""

        payload = str(id(self))

        read_fd, write_fd = os.pipe()

        # Pre-load the pipe, then close the write end
        os.write(write_fd, payload.encode())
        os.close(write_fd)

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdin=read_fd)

        self.assertEqual(outcome.stdout, payload)
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_stdin_text_pipe(self):
        """Test feeding stdin from a pipe wrapped in a text-mode file"""

        payload = str(id(self))

        read_fd, write_fd = os.pipe()

        reader = os.fdopen(read_fd, 'r')

        # Leaving the block closes the write end of the pipe
        with os.fdopen(write_fd, 'w') as writer:
            writer.write(payload)

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdin=reader)

        self.assertEqual(outcome.stdout, payload)
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_stdin_binary_pipe(self):
        """Test feeding stdin from a pipe with encoding disabled"""

        payload = str(id(self)).encode() + b'\xff'

        read_fd, write_fd = os.pipe()

        os.write(write_fd, payload)
        os.close(write_fd)

        async with self.connect() as conn:
            outcome = await conn.run('echo', encoding=None, stdin=read_fd)

        self.assertEqual(outcome.stdout, payload)
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_stdout_pipe(self):
        """Test sending stdout to the write end of a pipe"""

        payload = str(id(self))

        read_fd, write_fd = os.pipe()

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdout=write_fd, input=payload)

        received = os.read(read_fd, 1024)
        os.close(read_fd)

        # stdout went to the pipe, so the result carries only stderr
        self.assertEqual(received.decode(), payload)
        self.assertEqual(outcome.stdout, '')
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_stdout_text_pipe(self):
        """Test sending stdout to a pipe wrapped in a text-mode file"""

        payload = str(id(self))

        read_fd, write_fd = os.pipe()

        reader = os.fdopen(read_fd, 'r')
        writer = os.fdopen(write_fd, 'w')

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdout=writer, input=payload)

        with reader:
            received = reader.read(1024)

        self.assertEqual(received, payload)
        self.assertEqual(outcome.stdout, '')
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_stdout_binary_pipe(self):
        """Test sending stdout to a pipe with encoding disabled"""

        payload = str(id(self)).encode() + b'\xff'

        read_fd, write_fd = os.pipe()

        async with self.connect() as conn:
            outcome = await conn.run('echo', encoding=None,
                                     stdout=write_fd, input=payload)

        received = os.read(read_fd, 1024)
        os.close(read_fd)

        self.assertEqual(received, payload)
        self.assertEqual(outcome.stdout, b'')
        self.assertEqual(outcome.stderr, payload)
1271
1272
@unittest.skipIf(sys.platform == 'win32', 'skip socketpair tests on Windows')
class _TestProcessSocketPair(_TestProcess):
    """Unit tests for AsyncSSH process redirection over socket pairs"""

    @asynctest
    async def test_stdin_socketpair(self):
        """Test feeding stdin from one end of a socketpair"""

        payload = str(id(self))

        feeder, stdin_sock = socket.socketpair()

        # Pre-load the data, then close the sending side
        feeder.send(payload.encode())
        feeder.close()

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdin=stdin_sock)

        self.assertEqual(outcome.stdout, payload)
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_change_stdin(self):
        """Test swapping the stdin source of a running process"""

        first_feeder, first_stdin = socket.socketpair()
        second_feeder, second_stdin = socket.socketpair()

        first_feeder.send(b'xxx')
        second_feeder.send(b'yyy')

        async with self.connect() as conn:
            proc = await conn.create_process(stdin=first_stdin)

            # Brief pause before switching over to the second socket
            await asyncio.sleep(0.1)
            await proc.redirect_stdin(second_stdin)

            first_feeder.close()
            second_feeder.close()

            outcome = await proc.wait()

        # Data from both sockets is expected, in the order they were used
        self.assertEqual(outcome.stdout, 'xxxyyy')
        self.assertEqual(outcome.stderr, 'xxxyyy')

    @asynctest
    async def test_stdout_socketpair(self):
        """Test sending stdout to one end of a socketpair"""

        payload = str(id(self))

        stdout_sock, receiver = socket.socketpair()

        async with self.connect() as conn:
            outcome = await conn.run('echo', stdout=stdout_sock,
                                     input=payload)

        received = receiver.recv(1024)
        receiver.close()

        self.assertEqual(received.decode(), payload)
        self.assertEqual(outcome.stderr, payload)

    @asynctest
    async def test_pause_socketpair_reader(self):
        """Test pausing and resuming reads from a socketpair"""

        payload = '*' * (4*1024*1024)

        feeder, stdin_sock = socket.socketpair()

        # Queue the whole payload through an asyncio stream writer
        _, stream_writer = await asyncio.open_unix_connection(sock=feeder)
        stream_writer.write(payload.encode())
        stream_writer.close()

        async with self.connect() as conn:
            outcome = await conn.run('delay', stderr=asyncssh.DEVNULL,
                                     stdin=stdin_sock)

        self.assertEqual(outcome.stdout, payload)

    @asynctest
    async def test_pause_socketpair_writer(self):
        """Test pausing and resuming writes to a socketpair"""

        payload = '*' * (4*1024*1024)

        out_rsock, out_wsock = socket.socketpair()
        err_rsock, err_wsock = socket.socketpair()

        out_reader, out_writer = \
            await asyncio.open_unix_connection(sock=out_rsock)
        err_reader, err_writer = \
            await asyncio.open_unix_connection(sock=err_rsock)

        async with self.connect() as conn:
            proc = await conn.create_process(input=payload)

            # Wait a second before attaching the output sockets
            await asyncio.sleep(1)

            await proc.redirect_stdout(out_wsock)
            await proc.redirect_stderr(err_wsock)

            # Read both streams to EOF concurrently
            out_bytes, err_bytes = \
                await asyncio.gather(out_reader.read(), err_reader.read())

            out_writer.close()
            err_writer.close()

            await proc.wait()

        self.assertEqual(out_bytes.decode(), payload)
        self.assertEqual(err_bytes.decode(), payload)