1#
2# Wireshark tests
3# By Gerald Combs <gerald@wireshark.org>
4#
5# Ported from a set of Bash scripts which were copyright 2005 Ulf Lamping
6#
7# SPDX-License-Identifier: GPL-2.0-or-later
8#
9'''Subprocess test case superclass'''
10
11import difflib
12import io
13import os
14import os.path
15import re
16import subprocess
17import sys
18import unittest
19
20# To do:
21# - Add a subprocesstest.SkipUnlessCapture decorator?
22# - Try to catch crashes? See the comments below in waitProcess.
23
process_timeout = 5 * 60 # Seconds (five minutes) allowed for each communicate() call
25
def cat_dhcp_command(mode):
    '''Create a command string for dumping dhcp.pcap to stdout.

    The command runs util_dump_dhcp_pcap.py, found in the same directory as
    this file, with mode as its only argument. It uses the current Python
    interpreter when sys.executable is known. The result is a shell command
    string meant to be run with shell=True.

    Both the interpreter path and the script path are double-quoted so that
    directories containing spaces work. The same quoting style is used for
    capture file paths elsewhere in this module.
    '''
    # XXX Do this in Python in a thread?
    script_path = os.path.join(os.path.dirname(__file__), 'util_dump_dhcp_pcap.py')
    sd_cmd = '"{}" {}'.format(script_path, mode)
    if sys.executable:
        sd_cmd = '"{}" {}'.format(sys.executable, sd_cmd)
    return sd_cmd
35
def cat_cap_file_command(cap_files):
    '''Create a command string for dumping one or more capture files to stdout.

    cap_files is either a single path string or a sequence of paths. The
    files are emitted in the given order. Each path is double-quoted, and
    the result is a shell command string meant to be run with shell=True.
    '''
    # XXX Do this in Python in a thread?
    if isinstance(cap_files, str):
        cap_files = [ cap_files ]
    quoted_paths = ['"{}"'.format(cap_file) for cap_file in cap_files]
    if sys.platform.startswith('win32'):
        # https://docs.microsoft.com/en-us/previous-versions/windows/it-pro/windows-xp/bb491026(v=technet.10)
        # says that the `type` command "displays the contents of a text
        # file." Copy to the console instead.
        # `copy` concatenates several sources only when they are joined
        # with "+" (space-separated sources are a syntax error). It also
        # treats concatenated sources as text, stopping at Ctrl-Z, unless
        # /b is given. A leading /b applies to every listed file.
        return 'copy /b {} CON'.format('+'.join(quoted_paths))
    return 'cat {}'.format(' '.join(quoted_paths))
48
class LoggingPopen(subprocess.Popen):
    '''Run a process using subprocess.Popen. Capture and log its output.

    Stdout and stderr are captured to memory and decoded as UTF-8. CRLF
    line endings are normalized to LF on every platform. The program
    command and output are written to log_fd when one is provided.

    Keyword arguments consumed here and not passed on to Popen:
      log_fd    -- text stream that receives the command and its output.
                   None (the default) disables logging.
      max_lines -- when positive, only the first and last max_lines lines
                   of stdout are written to the log. stdout_str is never
                   trimmed.
    stdout and stderr are always pipes, whatever the caller passes.
    '''
    def __init__(self, proc_args, *args, **kwargs):
        self.log_fd = kwargs.pop('log_fd', None)
        self.max_lines = kwargs.pop('max_lines', None)
        kwargs['stdout'] = subprocess.PIPE
        kwargs['stderr'] = subprocess.PIPE
        # Make sure communicate() gives us bytes.
        kwargs['universal_newlines'] = False
        self.cmd_str = 'command ' + repr(proc_args)
        super().__init__(proc_args, *args, **kwargs)
        # Filled in by wait_and_log().
        self.stdout_str = ''
        self.stderr_str = ''

    @staticmethod
    def trim_output(out_log, max_lines):
        '''Keep only the first and last max_lines lines of out_log.

        The lines in between are replaced by a single marker line that
        reports how many lines were dropped. Text with at most
        2 * max_lines + 1 lines is returned unchanged.'''
        lines = out_log.splitlines(True)
        if not len(lines) > max_lines * 2 + 1:
            return out_log
        header = lines[:max_lines]
        body = lines[max_lines:-max_lines]
        body = "<<< trimmed {} lines of output >>>\n".format(len(body))
        footer = lines[-max_lines:]
        return ''.join(header) + body + ''.join(footer)

    def _write_log(self, out_data, err_data):
        '''Write the captured stdout/stderr bytes to log_fd, if one is set.

        Decoding uses 'replace' so that logging never fails on invalid
        UTF-8.'''
        if self.log_fd is None:
            return
        out_log = out_data.decode('UTF-8', 'replace')
        if self.max_lines and self.max_lines > 0:
            out_log = self.trim_output(out_log, self.max_lines)
        err_log = err_data.decode('UTF-8', 'replace')
        self.log_fd.flush()
        self.log_fd.write('-- Begin stdout for {} --\n'.format(self.cmd_str))
        self.log_fd.write(out_log)
        self.log_fd.write('-- End stdout for {} --\n'.format(self.cmd_str))
        self.log_fd.write('-- Begin stderr for {} --\n'.format(self.cmd_str))
        self.log_fd.write(err_log)
        self.log_fd.write('-- End stderr for {} --\n'.format(self.cmd_str))
        self.log_fd.flush()

    def wait_and_log(self):
        '''Wait for the process to finish and log its output.

        Sets stdout_str and stderr_str to the strictly decoded output, with
        CRLF replaced by LF.

        Raises subprocess.TimeoutExpired if the process does not finish
        within process_timeout seconds; the process is not killed here.
        Raises UnicodeDecodeError if the output is not valid UTF-8, after
        the output has been logged.'''
        out_data, err_data = self.communicate(timeout=process_timeout)
        self._write_log(out_data, err_data)
        # Make sure our output is the same everywhere.
        # Throwing a UnicodeDecodeError exception here is arguably a good thing.
        self.stdout_str = out_data.decode('UTF-8', 'strict').replace('\r\n', '\n')
        self.stderr_str = err_data.decode('UTF-8', 'strict').replace('\r\n', '\n')

    def stop_process(self, kill=False):
        '''Stop the process immediately: kill it if kill is true, otherwise terminate it.'''
        if kill:
            super().kill()
        else:
            super().terminate()

    def terminate(self):
        '''Terminate the process. Do not log its output.'''
        # XXX Currently unused.
        self.stop_process(kill=False)

    def kill(self):
        '''Kill the process. Do not log its output.'''
        self.stop_process(kill=True)
114
class SubprocessTestCase(unittest.TestCase):
    '''Run a program and gather its stdout and stderr.

    Each test gets a UTF-8 log file named after its test ID. Processes
    started via startProcess() log their output there and are killed in
    tearDown(). If the test failed, the log is printed and files are left
    behind. Otherwise the log and any files registered through
    filename_from_id() are removed.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exit_ok = 0
        self.exit_command_line = 1
        self.exit_error = 2
        self.exit_code = None
        self.log_fname = None
        self.log_fd = None
        self.processes = []
        self.cleanup_files = []
        self.dump_files = []

    def log_fd_write_bytes(self, log_data):
        '''Write log_data to the test log.

        The log file is opened in text mode, so bytes are decoded as UTF-8
        first, with invalid sequences replaced. str is written as-is.'''
        if isinstance(log_data, (bytes, bytearray)):
            log_data = log_data.decode('UTF-8', 'replace')
        self.log_fd.write(log_data)

    def filename_from_id(self, filename):
        '''Generate a filename prefixed with our test ID.

        The name is also registered (once) for removal in tearDown().'''
        id_filename = self.id() + '.' + filename
        if id_filename not in self.cleanup_files:
            self.cleanup_files.append(id_filename)
        return id_filename

    def kill_processes(self):
        '''Kill any processes we've opened so far.'''
        for proc in self.processes:
            try:
                proc.kill()
            except Exception:
                # Best effort: a process may already be gone, and the
                # remaining processes must still be killed.
                pass

    def setUp(self):
        """
        Set up a single test. Opens a log file and adds it to the cleanup list.
        """
        self.processes = []
        # filename_from_id() also registers the log file for cleanup.
        self.log_fname = self.filename_from_id('log')
        # Our command line utilities generate UTF-8. The log file encoding
        # needs to match that.
        # XXX newline='\n' works for now, but we might have to do more work
        # to handle line endings in the future.
        self.log_fd = io.open(self.log_fname, 'w', encoding='UTF-8', newline='\n')

    def _last_test_failed(self):
        """Check for non-skipped tests that resulted in errors."""
        # The test outcome is not available via the public unittest API, so
        # check a private property, "_outcome", set by unittest.TestCase.run.
        # It remains None when running in debug mode (`pytest --pdb`).
        outcome = getattr(self, '_outcome', None)
        if not outcome:
            return False
        errors = getattr(outcome, 'errors', None)
        if errors is not None:
            # Python 3.4 - 3.10: (test_case, exc_info) pairs. exc_info is
            # None for successful subtests.
            return any(exc_info for _test_case, exc_info in errors)
        # Python 3.11+ dropped _Outcome.errors and reports straight to the
        # result object. Standard unittest.TestResult objects keep
        # (test, traceback) lists. Result objects without them (e.g. a
        # pytest item) are treated as "no failure" rather than crashing.
        result = getattr(outcome, 'result', None)
        for attr in ('errors', 'failures'):
            for test_case, _traceback in getattr(result, attr, None) or ():
                # Subtest entries reference the parent via .test_case.
                if test_case is self or getattr(test_case, 'test_case', None) is self:
                    return True
        return False

    def tearDown(self):
        """
        Tears down a single test. Kills stray processes and closes the log file.
        On errors, display the log contents. On success, remove temporary files.
        """
        self.kill_processes()
        self.log_fd.close()
        if self._last_test_failed():
            self.dump_files.append(self.log_fname)
            # Leave some evidence behind.
            self.cleanup_files = []
            print('\nProcess output for {}:'.format(self.id()))
            with io.open(self.log_fname, 'r', encoding='UTF-8', errors='backslashreplace') as log_fd:
                for line in log_fd:
                    sys.stdout.write(line)
        for filename in self.cleanup_files:
            try:
                os.unlink(filename)
            except OSError:
                # Already removed, or never created.
                pass
        self.cleanup_files = []

    def getCaptureInfo(self, capinfos_args=None, cap_file=None):
        '''Run capinfos on a capture file and log its output.

        capinfos_args must be a sequence.
        Default cap_file is <test id>.testout.pcap.
        Returns capinfos' stdout decoded as UTF-8 (invalid bytes replaced).
        Raises subprocess.CalledProcessError if capinfos exits non-zero.'''
        # XXX convert users to use a new fixture instead of this function.
        cmd_capinfos = self._fixture_request.getfixturevalue('cmd_capinfos')
        if not cap_file:
            cap_file = self.filename_from_id('testout.pcap')
        self.log_fd.write('\nOutput of {0} {1}:\n'.format(cmd_capinfos, cap_file))
        capinfos_cmd = [cmd_capinfos]
        if capinfos_args is not None:
            capinfos_cmd += capinfos_args
        capinfos_cmd.append(cap_file)
        capinfos_data = subprocess.check_output(capinfos_cmd)
        capinfos_stdout = capinfos_data.decode('UTF-8', 'replace')
        self.log_fd.write(capinfos_stdout)
        return capinfos_stdout

    def checkPacketCount(self, num_packets, cap_file=None):
        '''Make sure a capture file contains a specific number of packets.'''
        capinfos_testout = self.getCaptureInfo(cap_file=cap_file)
        # (?!\S) stops the count from matching a prefix of a longer number,
        # e.g. "1" must not match "10".
        count_pat = r'Number of packets:\s+{}(?!\S)'.format(num_packets)
        got_num_packets = re.search(count_pat, capinfos_testout) is not None
        self.assertTrue(got_num_packets, 'Failed to capture exactly {} packets'.format(num_packets))

    def countOutput(self, search_pat=None, count_stdout=True, count_stderr=False, proc=None):
        '''Returns the number of output lines (search_pat=None), otherwise returns a match count.

        proc defaults to the most recently started process. search_pat is a
        regular expression searched for in each line.'''
        self.assertTrue(count_stdout or count_stderr, 'No output to count.')

        if proc is None:
            proc = self.processes[-1]

        # Split each stream on its own so that a final stdout line without
        # a trailing newline is not merged with the first stderr line.
        out_lines = []
        if count_stdout:
            out_lines += proc.stdout_str.splitlines()
        if count_stderr:
            out_lines += proc.stderr_str.splitlines()

        if search_pat is None:
            return len(out_lines)

        search_re = re.compile(search_pat)
        return sum(1 for line in out_lines if search_re.search(line))

    def grepOutput(self, search_pat, proc=None):
        '''Return True if search_pat matches any stdout or stderr line of proc.'''
        return self.countOutput(search_pat, count_stderr=True, proc=proc) > 0

    def diffOutput(self, blob_a, blob_b, *args, **kwargs):
        '''Check for differences between blob_a and blob_b. Return False and log a unified diff if they differ.

        blob_a and blob_b must be UTF-8 strings. Extra arguments are passed
        to difflib.unified_diff.'''
        lines_a = blob_a.splitlines()
        lines_b = blob_b.splitlines()
        diff_lines = list(difflib.unified_diff(lines_a, lines_b, *args, **kwargs))
        if not diff_lines:
            return True
        # Header lines (---/+++/@@) end with lineterm ('\n' by default),
        # while content lines from splitlines() do not. Normalize so every
        # diff line is logged on exactly one line.
        diff = '\n'.join(line.rstrip('\n') for line in diff_lines)
        self.log_fd.flush()
        self.log_fd.write('-- Begin diff output --\n')
        self.log_fd.write(diff + '\n')
        self.log_fd.write('-- End diff output --\n')
        return False

    def startProcess(self, proc_args, stdin=None, env=None, shell=False, cwd=None, max_lines=None):
        '''Start a process in the background. Returns a subprocess.Popen object.

        You typically wait for it using waitProcess() or assertWaitProcess().'''
        if env is None:
            # Apply default test environment if no override is provided.
            env = getattr(self, 'injected_test_env', None)
            # Not all tests need test_env, but those that use runProcess or
            # startProcess must either pass an explicit environment or load the
            # fixture (via a test method parameter or class decorator).
            assert not (env is None and hasattr(self, '_fixture_request')), \
                "Decorate class with @fixtures.mark_usefixtures('test_env')"
        proc = LoggingPopen(proc_args, stdin=stdin, env=env, shell=shell, log_fd=self.log_fd, cwd=cwd, max_lines=max_lines)
        # Tracked so tearDown() can kill it and countOutput() can default to it.
        self.processes.append(proc)
        return proc

    def waitProcess(self, process):
        '''Wait for a process to finish.'''
        process.wait_and_log()
        # XXX The shell version ran processes using a script called run_and_catch_crashes
        # which looked for core dumps and printed stack traces if found. We might want
        # to do something similar here. This may not be easy on modern Ubuntu systems,
        # which default to using Apport: https://wiki.ubuntu.com/Apport

    def assertWaitProcess(self, process, expected_return=0):
        '''Wait for a process to finish and check its exit code.'''
        process.wait_and_log()
        self.assertEqual(process.returncode, expected_return)

    def runProcess(self, args, env=None, shell=False, cwd=None, max_lines=None):
        '''Start a process and wait for it to finish.'''
        process = self.startProcess(args, env=env, shell=shell, cwd=cwd, max_lines=max_lines)
        process.wait_and_log()
        return process

    def assertRun(self, args, env=None, shell=False, expected_return=0, cwd=None, max_lines=None):
        '''Start a process and wait for it to finish. Check its return code.'''
        process = self.runProcess(args, env=env, shell=shell, cwd=cwd, max_lines=max_lines)
        self.assertEqual(process.returncode, expected_return)
        return process
306