1# Common utilities and Python wrappers for qemu-iotests
2#
3# Copyright (C) 2012 IBM Corp.
4#
5# This program is free software; you can redistribute it and/or modify
6# it under the terms of the GNU General Public License as published by
7# the Free Software Foundation; either version 2 of the License, or
8# (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU General Public License for more details.
14#
15# You should have received a copy of the GNU General Public License
16# along with this program.  If not, see <http://www.gnu.org/licenses/>.
17#
18
19import atexit
20import bz2
21from collections import OrderedDict
22import faulthandler
23import json
24import logging
25import os
26import re
27import shutil
28import signal
29import struct
30import subprocess
31import sys
32import time
33from typing import (Any, Callable, Dict, Iterable, Iterator,
34                    List, Optional, Sequence, TextIO, Tuple, Type, TypeVar)
35import unittest
36
37from contextlib import contextmanager
38
39from qemu.machine import qtest
40from qemu.qmp import QMPMessage
41
# Use this logger for logging messages directly from the iotests module
logger = logging.getLogger('qemu.iotests')
logger.addHandler(logging.NullHandler())

# Use this logger for messages that ought to be used for diff output.
# (Everything log() writes goes here and is compared against the test's
# reference output.)
test_logger = logging.getLogger('qemu.iotests.diff_io')


# Dump the Python traceback if the interpreter dies on a fatal signal
# (e.g. SIGSEGV or SIGABRT), to make crashes in a test diagnosable.
faulthandler.enable()

# Command lines of the tools under test: *_PROG overrides the binary and
# *_OPTIONS appends extra arguments.
# This will not work if arguments contain spaces but is necessary if we
# want to support the override options that ./check supports.
qemu_img_args = [os.environ.get('QEMU_IMG_PROG', 'qemu-img')]
if os.environ.get('QEMU_IMG_OPTIONS'):
    qemu_img_args += os.environ['QEMU_IMG_OPTIONS'].strip().split(' ')

qemu_io_args = [os.environ.get('QEMU_IO_PROG', 'qemu-io')]
if os.environ.get('QEMU_IO_OPTIONS'):
    qemu_io_args += os.environ['QEMU_IO_OPTIONS'].strip().split(' ')

# Variant used when the caller supplies its own -f / --image-opts
# (see qemu_io_silent and QemuIoInteractive).
qemu_io_args_no_fmt = [os.environ.get('QEMU_IO_PROG', 'qemu-io')]
if os.environ.get('QEMU_IO_OPTIONS_NO_FMT'):
    qemu_io_args_no_fmt += \
        os.environ['QEMU_IO_OPTIONS_NO_FMT'].strip().split(' ')

# qemu_nbd_prog alone (without QEMU_NBD_OPTIONS) is used for "qemu-nbd -L".
qemu_nbd_prog = os.environ.get('QEMU_NBD_PROG', 'qemu-nbd')
qemu_nbd_args = [qemu_nbd_prog]
if os.environ.get('QEMU_NBD_OPTIONS'):
    qemu_nbd_args += os.environ['QEMU_NBD_OPTIONS'].strip().split(' ')

qemu_prog = os.environ.get('QEMU_PROG', 'qemu')
# NOTE(review): an unset or empty QEMU_OPTIONS yields [''] (one empty
# argument) here, not [] -- confirm the VM launcher tolerates that.
qemu_opts = os.environ.get('QEMU_OPTIONS', '').strip().split(' ')

# When GDB_OPTIONS is set, QEMU is launched under gdbserver with these
# arguments (used as the VM wrapper in VM.__init__).
gdb_qemu_env = os.environ.get('GDB_OPTIONS')
qemu_gdb = []
if gdb_qemu_env:
    qemu_gdb = ['gdbserver'] + gdb_qemu_env.strip().split(' ')

# Any non-empty PRINT_QEMU value is truthy (even "n"); VM._pre_launch then
# sends QEMU's output to stdout instead of a log file.
qemu_print = os.environ.get('PRINT_QEMU', False)

imgfmt = os.environ.get('IMGFMT', 'raw')
imgproto = os.environ.get('IMGPROTO', 'file')
output_dir = os.environ.get('OUTPUT_DIR', '.')

try:
    test_dir = os.environ['TEST_DIR']
    sock_dir = os.environ['SOCK_DIR']
    cachemode = os.environ['CACHEMODE']
    aiomode = os.environ['AIOMODE']
    qemu_default_machine = os.environ['QEMU_DEFAULT_MACHINE']
except KeyError:
    # We are using these variables as proxies to indicate that we're
    # not being run via "check". There may be other things set up by
    # "check" that individual test cases rely on.
    sys.stderr.write('Please run this test via the "check" script\n')
    sys.exit(os.EX_USAGE)

# Run QEMU under valgrind when VALGRIND_QEMU=y (unless NO_VALGRIND=y).
# Exit code 99 signals valgrind errors: VM._post_shutdown prints the log
# in that case and deletes it otherwise.
qemu_valgrind = []
if os.environ.get('VALGRIND_QEMU') == "y" and \
    os.environ.get('NO_VALGRIND') != "y":
    valgrind_logfile = "--log-file=" + test_dir
    # %p allows to put the valgrind process PID, since
    # we don't know it a priori (subprocess.Popen is
    # not yet invoked)
    valgrind_logfile += "/%p.valgrind"

    qemu_valgrind = ['valgrind', valgrind_logfile, '--error-exitcode=99']

# Default secret object / option added for LUKS images when the caller does
# not specify its own key-secret (see qemu_img_create and VM.add_drive).
luks_default_secret_object = 'secret,id=keysec0,data=' + \
                             os.environ.get('IMGKEYSECRET', '')
luks_default_key_secret_opt = 'key-secret=keysec0'

# Unlike the variables in the try block above, a missing SAMPLE_IMG_DIR
# raises KeyError instead of printing the "check" hint.
sample_img_dir = os.environ['SAMPLE_IMG_DIR']
115
116
@contextmanager
def change_log_level(
        logger_name: str, level: int = logging.CRITICAL) -> Iterator[None]:
    """
    Temporarily set the level of the logger called *logger_name*.

    The previous level is restored when the ``with`` block is left, even
    if it raises.  Useful for silencing expected or uninteresting errors.
    """
    target = logging.getLogger(logger_name)
    saved_level = target.level
    target.setLevel(level)
    try:
        yield
    finally:
        # Restore unconditionally so one failing test cannot leave the
        # logger muted for later ones.
        target.setLevel(saved_level)
133
134
def unarchive_sample_image(sample, fname):
    """Decompress ``<sample_img_dir>/<sample>.bz2`` into the file *fname*."""
    compressed = os.path.join(sample_img_dir, sample + '.bz2')
    with bz2.open(compressed) as src:
        with open(fname, 'wb') as dst:
            shutil.copyfileobj(src, dst)
139
140
def qemu_tool_pipe_and_status(tool: str, args: Sequence[str],
                              connect_stderr: bool = True) -> Tuple[str, int]:
    """
    Run a tool and return both its output and its exit code.

    :param tool: Name of the tool; only used in the diagnostic printed when
                 the process is killed by a signal.
    :param args: Full command line (program followed by its arguments).
    :param connect_stderr: If True, stderr is merged into the returned
                           output; otherwise it is inherited from this
                           process.
    :return: ``(output, returncode)``.  A negative return code -N means the
             process was terminated by signal N.
    """
    stderr = subprocess.STDOUT if connect_stderr else None
    with subprocess.Popen(args, stdout=subprocess.PIPE,
                          stderr=stderr, universal_newlines=True) as subp:
        output = subp.communicate()[0]
        if subp.returncode < 0:
            cmd = ' '.join(args)
            # Use implicit literal concatenation: a backslash-newline inside
            # the literal would embed the next line's indentation in the
            # message.
            sys.stderr.write(f'{tool} received signal '
                             f'{-subp.returncode}: {cmd}\n')
        return (output, subp.returncode)
155
def qemu_img_pipe_and_status(*args: str) -> Tuple[str, int]:
    """
    Run qemu-img (configured program and options, followed by *args*) and
    return ``(output, exit_code)``; stderr is merged into the output.
    """
    return qemu_tool_pipe_and_status('qemu-img', [*qemu_img_args, *args])
162
def qemu_img(*args: str) -> int:
    '''Run qemu-img with *args*, discard its output and return the exit code'''
    _output, exitcode = qemu_img_pipe_and_status(*args)
    return exitcode
166
def ordered_qmp(qmsg, conv_keys=True):
    """
    Return a copy of a QMP value in which every dict is an OrderedDict with
    sorted keys, so that logged messages are deterministic.

    :param qmsg: A QMP value: dict, list, or scalar (returned unchanged).
    :param conv_keys: If True, underscores in the keys of the outermost
                      dict(s) are replaced by dashes (Python keyword
                      arguments cannot contain dashes).  Keys of nested
                      dicts are never converted.  A list does not start a
                      new level: its elements inherit *conv_keys*.
    """
    # Dictionaries are not ordered prior to 3.6, therefore:
    if isinstance(qmsg, list):
        # Pass conv_keys through so a dict nested in a list is treated like
        # a dict nested directly in a dict (neither gets its keys renamed).
        return [ordered_qmp(atom, conv_keys=conv_keys) for atom in qmsg]
    if isinstance(qmsg, dict):
        od = OrderedDict()
        # Sort on the original keys, then convert.
        for k, v in sorted(qmsg.items()):
            if conv_keys:
                k = k.replace('_', '-')
            od[k] = ordered_qmp(v, conv_keys=False)
        return od
    return qmsg
179
def qemu_img_create(*args):
    """
    Run ``qemu-img create`` with *args* and return its exit code.

    For LUKS images ('-f luks') without an explicit key-secret, the default
    secret is added automatically:
    - if there is an '-o' option string, ``key-secret=keysec0`` is appended
      to it and a matching '--object' is inserted right after it;
    - otherwise a new '-o'/'--object' pair is prepended.
    """
    args = list(args)

    # default luks support
    if '-f' in args and args[args.index('-f') + 1] == 'luks':
        if '-o' in args:
            i = args.index('-o')
            if 'key-secret' not in args[i + 1]:
                # args[i + 1] is an option *string* (str has no append()):
                # extend it with another comma-separated entry.
                sep = ',' if args[i + 1] else ''
                args[i + 1] += sep + luks_default_key_secret_opt
                args.insert(i + 2, '--object')
                args.insert(i + 3, luks_default_secret_object)
        else:
            args = ['-o', luks_default_key_secret_opt,
                    '--object', luks_default_secret_object] + args

    args.insert(0, 'create')

    return qemu_img(*args)
198
def qemu_img_measure(*args):
    """Run ``qemu-img measure --output json`` with *args*; return parsed JSON."""
    report = qemu_img_pipe("measure", "--output", "json", *args)
    return json.loads(report)
201
def qemu_img_check(*args):
    """Run ``qemu-img check --output json`` with *args*; return parsed JSON."""
    report = qemu_img_pipe("check", "--output", "json", *args)
    return json.loads(report)
204
def qemu_img_verbose(*args):
    '''Run qemu-img without suppressing its output and return the exit code.

    A note is written to stderr if qemu-img was killed by a signal.'''
    cmdline = qemu_img_args + list(args)
    exitcode = subprocess.call(cmdline)
    if exitcode < 0:
        sys.stderr.write('qemu-img received signal %i: %s\n'
                         % (-exitcode, ' '.join(cmdline)))
    return exitcode
212
def qemu_img_pipe(*args: str) -> str:
    '''Run qemu-img with *args*; return its output (stderr merged in)'''
    output, _exitcode = qemu_img_pipe_and_status(*args)
    return output
216
def qemu_img_log(*args):
    """Run qemu-img, log its output with test file paths filtered, and return
    the unfiltered output."""
    output = qemu_img_pipe(*args)
    log(output, filters=[filter_testfiles])
    return output
221
def img_info_log(filename, filter_path=None, imgopts=False, extra_args=()):
    """
    Log the filtered output of ``qemu-img info`` for *filename*.

    :param filter_path: String replaced by TEST_IMG in the output; defaults
                        to *filename* when empty/None.
    :param imgopts: Pass ``--image-opts`` instead of ``-f <imgfmt>``.
    :param extra_args: Further arguments inserted before *filename*.
    """
    fmt_args = ['--image-opts'] if imgopts else ['-f', imgfmt]
    command = ['info', *fmt_args, *extra_args, filename]
    output = qemu_img_pipe(*command)
    log(filter_img_info(output, filter_path or filename))
235
def qemu_io(*args):
    '''Run qemu-io (default options plus *args*) and return its output,
    with stderr merged in'''
    cmdline = qemu_io_args + list(args)
    output, _exitcode = qemu_tool_pipe_and_status('qemu-io', cmdline)
    return output
240
def qemu_io_log(*args):
    """Run qemu-io, log its output with file paths and timing statistics
    filtered, and return the unfiltered output."""
    output = qemu_io(*args)
    log(output, filters=[filter_testfiles, filter_qemu_io])
    return output
245
def qemu_io_silent(*args):
    '''Run qemu-io with *args*, discarding stdout; return the exit code.

    If the caller passes '-f' or '--image-opts' itself, the base command
    line is qemu_io_args_no_fmt instead of qemu_io_args (presumably so that
    no conflicting default format is added).
    '''
    caller_sets_format = '-f' in args or '--image-opts' in args
    base = qemu_io_args_no_fmt if caller_sets_format else qemu_io_args
    cmdline = base + list(args)
    completed = subprocess.run(cmdline, stdout=subprocess.DEVNULL,
                               check=False)
    if completed.returncode < 0:
        sys.stderr.write('qemu-io received signal %i: %s\n' %
                         (-completed.returncode, ' '.join(cmdline)))
    return completed.returncode
259
def qemu_io_silent_check(*args):
    '''Run qemu-io with all output discarded; return True if it exited 0'''
    cmdline = qemu_io_args + list(args)
    completed = subprocess.run(cmdline, stdout=subprocess.DEVNULL,
                               stderr=subprocess.STDOUT, check=False)
    return completed.returncode == 0
266
class QemuIoInteractive:
    """
    Drive an interactive qemu-io process through its 'qemu-io> ' prompt.

    The process is started by the constructor (ValueError with its output
    if the first prompt does not appear); run commands with cmd() and
    terminate it with close().
    """

    _PROMPT = 'qemu-io> '

    def __init__(self, *args):
        self.args = qemu_io_args_no_fmt + list(args)
        # We need to keep the Popen object around, and not
        # close it immediately. Therefore, disable the pylint check:
        # pylint: disable=consider-using-with
        self._p = subprocess.Popen(self.args, stdin=subprocess.PIPE,
                                   stdout=subprocess.PIPE,
                                   stderr=subprocess.STDOUT,
                                   universal_newlines=True)
        out = self._p.stdout.read(len(self._PROMPT))
        if out != self._PROMPT:
            # Most probably qemu-io just failed to start.
            # Let's collect the whole output and exit.
            out += self._p.stdout.read()
            self._p.wait(timeout=1)
            raise ValueError(out)

    def close(self):
        """Send the quit command and wait for qemu-io to exit."""
        self._p.communicate('q\n')

    def _read_output(self):
        """
        Read qemu-io's output up to the next prompt; return it without the
        prompt.

        :raises AssertionError: on EOF before a prompt was seen.
        """
        prompt = self._PROMPT
        n = len(prompt)
        buf = []
        # Compare the tail of what was read with the full prompt.  A matcher
        # that merely resets to 0 on a mismatch would miss a prompt directly
        # preceded by a stray 'q' (e.g. "...qqemu-io> ").
        while len(buf) < n or ''.join(buf[-n:]) != prompt:
            c = self._p.stdout.read(1)
            if c == '':
                # Explicit raise (not assert) so this still fails, instead of
                # looping forever, when Python runs with -O.
                raise AssertionError('unexpected EOF from qemu-io while '
                                     'waiting for the prompt')
            buf.append(c)

        return ''.join(buf[:-n])

    def cmd(self, cmd):
        """Run one qemu-io command and return its output (without prompt)."""
        # quit command is in close(), '\n' is added automatically
        assert '\n' not in cmd
        cmd = cmd.strip()
        assert cmd not in ('q', 'quit')
        self._p.stdin.write(cmd + '\n')
        self._p.stdin.flush()
        return self._read_output()
313
314
def qemu_nbd(*args):
    '''Start qemu-nbd with --fork (daemon mode); return the exit code of the
    parent process'''
    cmdline = [*qemu_nbd_args, '--fork', *args]
    return subprocess.call(cmdline)
318
def qemu_nbd_early_pipe(*args: str) -> Tuple[int, str]:
    '''Start qemu-nbd with --fork (daemon mode).

    Returns ``(exit_code, output)`` of the parent process, where output is
    only reported when the exit code is non-zero ('' otherwise).  stderr is
    not captured.'''
    cmdline = [*qemu_nbd_args, '--fork', *args]
    output, returncode = qemu_tool_pipe_and_status('qemu-nbd', cmdline,
                                                   connect_stderr=False)
    if returncode:
        return returncode, output
    return returncode, ''
326
def qemu_nbd_list_log(*args: str) -> str:
    '''Run "qemu-nbd -L" to list remote exports; log the filtered listing
    and return the raw output'''
    cmdline = [qemu_nbd_prog, '-L', *args]
    output = qemu_tool_pipe_and_status('qemu-nbd', cmdline)[0]
    log(output, filters=[filter_testfiles, filter_nbd_exports])
    return output
333
@contextmanager
def qemu_nbd_popen(*args):
    '''Context manager running qemu-nbd within the context.

    qemu-nbd is started with --persistent and a PID file, and the context
    is entered once that PID file exists.  RuntimeError is raised if
    qemu-nbd exits before writing it.  On exit the PID file is removed and
    the server is killed and reaped.
    '''
    pid_file = file_path("qemu_nbd_popen-nbd-pid-file")

    assert not os.path.exists(pid_file)

    cmd = [*qemu_nbd_args, '--persistent', '--pid-file', pid_file, *args]

    log('Start NBD server')
    with subprocess.Popen(cmd) as server:
        try:
            # The PID file appearing is the readiness signal.
            while not os.path.exists(pid_file):
                if server.poll() is not None:
                    raise RuntimeError(
                        f"qemu-nbd terminated with exit code "
                        f"{server.returncode}: {' '.join(cmd)}")
                time.sleep(0.01)
            yield
        finally:
            if os.path.exists(pid_file):
                os.remove(pid_file)
            log('Kill NBD server')
            server.kill()
            server.wait()
362
def compare_images(img1, img2, fmt1=imgfmt, fmt2=imgfmt):
    '''Return True if "qemu-img compare" reports the two images as identical'''
    exitcode = qemu_img('compare', '-f', fmt1, '-F', fmt2, img1, img2)
    return exitcode == 0
367
def create_image(name, size):
    '''Create a fully-allocated raw image with sector markers.

    Each 512-byte sector holds its own index as a big-endian 32-bit integer
    in its first and last four bytes.  The file is rounded up to whole
    sectors.'''
    sector = 0
    with open(name, 'wb') as img:
        while sector * 512 < size:
            img.write(struct.pack('>l504xl', sector, sector))
            sector += 1
376
def image_size(img):
    '''Return the virtual size of *img* as reported by "qemu-img info"'''
    info_json = qemu_img_pipe('info', '--output=json', '-f', imgfmt, img)
    info = json.loads(info_json)
    return info['virtual-size']
381
def is_str(val):
    """Return True if *val* is a ``str`` (or a subclass of it)."""
    result = isinstance(val, str)
    return result
384
# Matches the test directory path literally: re.escape() keeps characters
# such as '.', '+' or '(' in the path from being treated as regex syntax.
test_dir_re = re.compile(re.escape(test_dir))
def filter_test_dir(msg):
    """Replace every occurrence of the test directory path with 'TEST_DIR'."""
    return test_dir_re.sub("TEST_DIR", msg)
388
# Kept as a module-level name for any external users of the pattern.
win32_re = re.compile(r"\r")
def filter_win32(msg):
    """Strip all carriage returns (Windows line endings) from *msg*."""
    return msg.replace("\r", "")
392
qemu_io_re = re.compile(r"[0-9]* ops; [0-9\/:. sec]* "
                        r"\([0-9\/.inf]* [EPTGMKiBbytes]*\/sec "
                        r"and [0-9\/.inf]* ops\/sec\)")
def filter_qemu_io(msg):
    """Remove carriage returns, then replace qemu-io's timing/throughput
    summaries with fixed placeholders."""
    placeholder = "X ops; XX:XX:XX.X (XXX YYY/sec and XXX ops/sec)"
    return qemu_io_re.sub(placeholder, filter_win32(msg))
400
chown_re = re.compile(r"chown [0-9]+:[0-9]+")
def filter_chown(msg):
    """Replace numeric IDs in 'chown <uid>:<gid>' with 'chown UID:GID'."""
    placeholder = "chown UID:GID"
    return chown_re.sub(placeholder, msg)
404
def filter_qmp_event(event):
    '''Filter a QMP event dict.

    Return a copy of *event* whose timestamp seconds and microseconds are
    replaced by 'SECS' and 'USECS'.  The caller's dict, including its nested
    timestamp dict, is left unmodified.
    '''
    event = dict(event)
    if 'timestamp' in event:
        # dict() above is only a shallow copy; copy the nested timestamp too
        # so the caller's event is not rewritten behind its back.
        event['timestamp'] = dict(event['timestamp'],
                                  seconds='SECS', microseconds='USECS')
    return event
412
def filter_qmp(qmsg, filter_fn):
    '''Recursively filter the leaf values of a QMP object (dict or list).

    filter_fn(key, value) is called for every non-container value (for list
    elements the key is the index) and its result replaces the value.  The
    object is modified in place and also returned.'''
    pairs = enumerate(qmsg) if isinstance(qmsg, list) else qmsg.items()
    for key, value in pairs:
        if not isinstance(value, (dict, list)):
            qmsg[key] = filter_fn(key, value)
        else:
            qmsg[key] = filter_qmp(value, filter_fn)
    return qmsg
428
def filter_testfiles(msg):
    """Replace this process's '<TEST_DIR>/<pid>-' and '<SOCK_DIR>/<pid>-'
    file prefixes with 'TEST_DIR/PID-' and 'SOCK_DIR/PID-'."""
    pid = os.getpid()
    replacements = (
        (os.path.join(test_dir, f"{pid}-"), 'TEST_DIR/PID-'),
        (os.path.join(sock_dir, f"{pid}-"), 'SOCK_DIR/PID-'),
    )
    for prefix, placeholder in replacements:
        msg = msg.replace(prefix, placeholder)
    return msg
433
def filter_qmp_testfiles(qmsg):
    """Apply filter_testfiles() to every string value in a QMP object."""
    def _filter(_key, value):
        return filter_testfiles(value) if is_str(value) else value
    return filter_qmp(qmsg, _filter)
440
def filter_virtio_scsi(output: str) -> str:
    """Drop the bus suffix from device names: 'virtio-scsi-pci' and
    'virtio-scsi-ccw' both become 'virtio-scsi'."""
    return re.sub(r'virtio-scsi-(?:ccw|pci)', 'virtio-scsi', output)
443
def filter_qmp_virtio_scsi(qmsg):
    """Apply filter_virtio_scsi() to every string value in a QMP object."""
    def _filter(_key, value):
        return filter_virtio_scsi(value) if is_str(value) else value
    return filter_qmp(qmsg, _filter)
450
def filter_generated_node_ids(msg):
    """Replace auto-generated node names ('#block' + digits) with
    'NODE_NAME'."""
    generated_name = re.compile("#block[0-9]+")
    return generated_name.sub("NODE_NAME", msg)
453
def filter_img_info(output, filename):
    """
    Normalize ``qemu-img info`` output for comparison with reference output.

    Lines mentioning 'disk size' or 'actual-size' are dropped.  In the other
    lines, *filename* becomes TEST_IMG, test file prefixes and the image
    format are replaced, and LUKS iteration counts, UUIDs and CIDs are
    masked.
    """
    masks = (
        ('iters: [0-9]+', 'iters: XXX'),
        ('uuid: [-a-f0-9]+', 'uuid: XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX'),
        ('cid: [0-9]+', 'cid: XXXXXXXXXX'),
    )

    def scrub(line):
        line = filter_testfiles(line.replace(filename, 'TEST_IMG'))
        line = line.replace(imgfmt, 'IMGFMT')
        for pattern, placeholder in masks:
            line = re.sub(pattern, placeholder, line)
        return line

    return '\n'.join(scrub(line) for line in output.split('\n')
                     if 'disk size' not in line and 'actual-size' not in line)
469
def filter_imgfmt(msg):
    """Replace every occurrence of the image format name with 'IMGFMT'."""
    placeholder = 'IMGFMT'
    return msg.replace(imgfmt, placeholder)
472
def filter_qmp_imgfmt(qmsg):
    """Apply filter_imgfmt() to every string value in a QMP object."""
    def _filter(_key, value):
        return filter_imgfmt(value) if is_str(value) else value
    return filter_qmp(qmsg, _filter)
479
def filter_nbd_exports(output: str) -> str:
    """Mask the numeric min/opt/max block sizes in 'qemu-nbd -L' output."""
    return re.sub(r'(?P<kind>(?:min|opt|max) block): [0-9]+',
                  r'\g<kind>: XXX', output)
482
483
Msg = TypeVar('Msg', Dict[str, Any], List[Any], str)

def log(msg: Msg,
        filters: Iterable[Callable[[Msg], Msg]] = (),
        indent: Optional[int] = None) -> None:
    """
    Write *msg* to the diff-output logger after passing it through each of
    *filters* in order.

    Dicts and lists are serialized as JSON, with keys sorted unless *msg*
    is an OrderedDict (whose order is kept).  *indent*, if given,
    pretty-prints that JSON.  Anything else is logged as-is.
    """
    for apply_filter in filters:
        msg = apply_filter(msg)
    if not isinstance(msg, (dict, list)):
        test_logger.info(msg)
        return
    # An OrderedDict has already been put into the desired order.
    sort_keys = not isinstance(msg, OrderedDict)
    test_logger.info(json.dumps(msg, sort_keys=sort_keys, indent=indent))
501
class Timeout:
    """
    Context manager that raises TimeoutError(errmsg) if its body runs for
    longer than *seconds*, using SIGALRM and ITIMER_REAL.

    When QEMU runs under gdb or valgrind, no alarm is armed at all.  The
    SIGALRM handler that was installed before entering is restored on exit.
    """
    def __init__(self, seconds, errmsg="Timeout"):
        self.seconds = seconds
        self.errmsg = errmsg
        self._old_handler = None

    def __enter__(self):
        if qemu_gdb or qemu_valgrind:
            return self
        self._old_handler = signal.signal(signal.SIGALRM, self.timeout)
        signal.setitimer(signal.ITIMER_REAL, self.seconds)
        return self

    def __exit__(self, exc_type, value, traceback):
        if qemu_gdb or qemu_valgrind:
            return False
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; in that case there is nothing to restore.
        if self._old_handler is not None:
            signal.signal(signal.SIGALRM, self._old_handler)
            self._old_handler = None
        return False

    def timeout(self, signum, frame):
        # TimeoutError derives from Exception, so existing
        # "except Exception" handlers keep catching it.
        raise TimeoutError(self.errmsg)
519
def file_pattern(name):
    """Return *name* prefixed with the current process ID: '<pid>-<name>'."""
    return f"{os.getpid()}-{name}"
522
class FilePath:
    """
    Context manager yielding one or more generated file names; the files
    (if they exist) are removed when the context is left.

    Example usage:

        with FilePath('a.img', 'b.img') as (img_a, img_b):
            # Use img_a and img_b here...

        # a.img and b.img are automatically removed here.

    Files live in iotests.test_dir by default; pass base_dir for another
    directory, e.g. iotests.sock_dir for sockets:

       with FilePath('a.sock', base_dir=iotests.sock_dir) as sock:

    With a single name the context yields that path itself rather than a
    one-element list.
    """
    def __init__(self, *names, base_dir=test_dir):
        self.paths = [os.path.join(base_dir, file_pattern(n)) for n in names]

    def __enter__(self):
        single = len(self.paths) == 1
        return self.paths[0] if single else self.paths

    def __exit__(self, exc_type, exc_val, exc_tb):
        for stale in self.paths:
            try:
                os.remove(stale)
            except OSError:
                # Best effort: the file may never have been created.
                continue
        # Never suppress an exception raised inside the with-block.
        return False
561
562
def try_remove(img):
    """Delete the file *img*, ignoring any OSError (e.g. it does not exist)."""
    try:
        os.remove(img)
    except OSError:
        return
568
def file_path_remover():
    """Remove every path registered by file_path(), newest first."""
    for path in file_path_remover.paths[::-1]:
        try_remove(path)
572
573
def file_path(*names, base_dir=test_dir):
    ''' Another way to get auto-generated filename that cleans itself up.

    Usage is as simple as:

    img_a, img_b = file_path('a.img', 'b.img')
    sock = file_path('socket')

    The files are removed (if present) by an atexit handler, in reverse
    order of creation.  One name yields a string; several names, a list.
    '''
    if not hasattr(file_path_remover, 'paths'):
        # First call: create the registry and schedule its cleanup.
        file_path_remover.paths = []
        atexit.register(file_path_remover)

    registry = file_path_remover.paths
    generated = []
    for name in names:
        generated.append(os.path.join(base_dir, file_pattern(name)))
        registry.append(generated[-1])

    return generated[0] if len(generated) == 1 else generated
595
def remote_filename(path):
    """Return *path* in the form expected for the current IMGPROTO.

    'file' returns it unchanged, 'ssh' turns it into an ssh:// URL for the
    current $USER on 127.0.0.1, and any other protocol raises Exception."""
    if imgproto == 'ssh':
        return "ssh://%s@127.0.0.1:22%s" % (os.environ.get('USER'), path)
    if imgproto != 'file':
        raise Exception("Protocol %s not supported" % (imgproto))
    return path
603
604class VM(qtest.QEMUQtestMachine):
605    '''A QEMU VM'''
606
607    def __init__(self, path_suffix=''):
608        name = "qemu%s-%d" % (path_suffix, os.getpid())
609        timer = 15.0 if not (qemu_gdb or qemu_valgrind) else None
610        if qemu_gdb and qemu_valgrind:
611            sys.stderr.write('gdb and valgrind are mutually exclusive\n')
612            sys.exit(1)
613        wrapper = qemu_gdb if qemu_gdb else qemu_valgrind
614        super().__init__(qemu_prog, qemu_opts, wrapper=wrapper,
615                         name=name,
616                         base_temp_dir=test_dir,
617                         sock_dir=sock_dir, qmp_timer=timer)
618        self._num_drives = 0
619
620    def _post_shutdown(self) -> None:
621        super()._post_shutdown()
622        if not qemu_valgrind or not self._popen:
623            return
624        valgrind_filename = f"{test_dir}/{self._popen.pid}.valgrind"
625        if self.exitcode() == 99:
626            with open(valgrind_filename, encoding='utf-8') as f:
627                print(f.read())
628        else:
629            os.remove(valgrind_filename)
630
631    def _pre_launch(self) -> None:
632        super()._pre_launch()
633        if qemu_print:
634            # set QEMU binary output to stdout
635            self._close_qemu_log_file()
636
637    def add_object(self, opts):
638        self._args.append('-object')
639        self._args.append(opts)
640        return self
641
642    def add_device(self, opts):
643        self._args.append('-device')
644        self._args.append(opts)
645        return self
646
647    def add_drive_raw(self, opts):
648        self._args.append('-drive')
649        self._args.append(opts)
650        return self
651
652    def add_drive(self, path, opts='', interface='virtio', img_format=imgfmt):
653        '''Add a virtio-blk drive to the VM'''
654        options = ['if=%s' % interface,
655                   'id=drive%d' % self._num_drives]
656
657        if path is not None:
658            options.append('file=%s' % path)
659            options.append('format=%s' % img_format)
660            options.append('cache=%s' % cachemode)
661            options.append('aio=%s' % aiomode)
662
663        if opts:
664            options.append(opts)
665
666        if img_format == 'luks' and 'key-secret' not in opts:
667            # default luks support
668            if luks_default_secret_object not in self._args:
669                self.add_object(luks_default_secret_object)
670
671            options.append(luks_default_key_secret_opt)
672
673        self._args.append('-drive')
674        self._args.append(','.join(options))
675        self._num_drives += 1
676        return self
677
678    def add_blockdev(self, opts):
679        self._args.append('-blockdev')
680        if isinstance(opts, str):
681            self._args.append(opts)
682        else:
683            self._args.append(','.join(opts))
684        return self
685
686    def add_incoming(self, addr):
687        self._args.append('-incoming')
688        self._args.append(addr)
689        return self
690
691    def hmp(self, command_line: str, use_log: bool = False) -> QMPMessage:
692        cmd = 'human-monitor-command'
693        kwargs: Dict[str, Any] = {'command-line': command_line}
694        if use_log:
695            return self.qmp_log(cmd, **kwargs)
696        else:
697            return self.qmp(cmd, **kwargs)
698
699    def pause_drive(self, drive: str, event: Optional[str] = None) -> None:
700        """Pause drive r/w operations"""
701        if not event:
702            self.pause_drive(drive, "read_aio")
703            self.pause_drive(drive, "write_aio")
704            return
705        self.hmp(f'qemu-io {drive} "break {event} bp_{drive}"')
706
707    def resume_drive(self, drive: str) -> None:
708        """Resume drive r/w operations"""
709        self.hmp(f'qemu-io {drive} "remove_break bp_{drive}"')
710
711    def hmp_qemu_io(self, drive: str, cmd: str,
712                    use_log: bool = False, qdev: bool = False) -> QMPMessage:
713        """Write to a given drive using an HMP command"""
714        d = '-d ' if qdev else ''
715        return self.hmp(f'qemu-io {d}{drive} "{cmd}"', use_log=use_log)
716
717    def flatten_qmp_object(self, obj, output=None, basestr=''):
718        if output is None:
719            output = {}
720        if isinstance(obj, list):
721            for i, item in enumerate(obj):
722                self.flatten_qmp_object(item, output, basestr + str(i) + '.')
723        elif isinstance(obj, dict):
724            for key in obj:
725                self.flatten_qmp_object(obj[key], output, basestr + key + '.')
726        else:
727            output[basestr[:-1]] = obj # Strip trailing '.'
728        return output
729
730    def qmp_to_opts(self, obj):
731        obj = self.flatten_qmp_object(obj)
732        output_list = []
733        for key in obj:
734            output_list += [key + '=' + obj[key]]
735        return ','.join(output_list)
736
737    def get_qmp_events_filtered(self, wait=60.0):
738        result = []
739        for ev in self.get_qmp_events(wait=wait):
740            result.append(filter_qmp_event(ev))
741        return result
742
743    def qmp_log(self, cmd, filters=(), indent=None, **kwargs):
744        full_cmd = OrderedDict((
745            ("execute", cmd),
746            ("arguments", ordered_qmp(kwargs))
747        ))
748        log(full_cmd, filters, indent=indent)
749        result = self.qmp(cmd, **kwargs)
750        log(result, filters, indent=indent)
751        return result
752
753    # Returns None on success, and an error string on failure
    def run_job(self, job, auto_finalize=True, auto_dismiss=False,
                pre_finalize=None, cancel=False, wait=60.0):
        """
        run_job moves a job from creation through to dismissal.

        It waits for events concerning *job*, logs every non-status event
        and, on JOB_STATUS_CHANGE, issues the QMP command that the new state
        requires (job-complete, job-finalize/job-cancel, job-dismiss).  It
        returns once the job reaches the 'null' state.

        :param job: String. ID of recently-launched job
        :param auto_finalize: Bool. True if the job was launched with
                              auto_finalize. Defaults to True.
        :param auto_dismiss: Bool. True if the job was launched with
                             auto_dismiss=True. Defaults to False.
        :param pre_finalize: Callback. A callable that takes no arguments to be
                             invoked prior to issuing job-finalize, if any.
        :param cancel: Bool. When true, cancels the job after the pre_finalize
                       callback.  Only acted on in the 'pending' state, i.e.
                       when auto_finalize is False.
        :param wait: Float. Timeout value specifying how long to wait for any
                     event, in seconds. Defaults to 60.0.
        :return: None unless the job passed through 'aborting'; in that case
                 the 'error' field that query-jobs reported for it.
        """
        # Legacy BLOCK_JOB_* events identify the job by 'device', the
        # generic job events by 'id'.
        match_device = {'data': {'device': job}}
        match_id = {'data': {'id': job}}
        events = [
            ('BLOCK_JOB_COMPLETED', match_device),
            ('BLOCK_JOB_CANCELLED', match_device),
            ('BLOCK_JOB_ERROR', match_device),
            ('BLOCK_JOB_READY', match_device),
            ('BLOCK_JOB_PENDING', match_id),
            ('JOB_STATUS_CHANGE', match_id)
        ]
        error = None
        while True:
            # NOTE(review): presumably events_wait() raises on timeout rather
            # than returning None -- filter_qmp_event(None) would fail.
            ev = filter_qmp_event(self.events_wait(events, timeout=wait))
            if ev['event'] != 'JOB_STATUS_CHANGE':
                # Only status changes drive the state machine; log the rest.
                log(ev)
                continue
            status = ev['data']['status']
            if status == 'aborting':
                # Remember the job's error; it is returned on reaching 'null'.
                result = self.qmp('query-jobs')
                for j in result['return']:
                    if j['id'] == job:
                        error = j['error']
                        log('Job failed: %s' % (j['error']))
            elif status == 'ready':
                self.qmp_log('job-complete', id=job)
            elif status == 'pending' and not auto_finalize:
                if pre_finalize:
                    pre_finalize()
                if cancel:
                    self.qmp_log('job-cancel', id=job)
                else:
                    self.qmp_log('job-finalize', id=job)
            elif status == 'concluded' and not auto_dismiss:
                self.qmp_log('job-dismiss', id=job)
            elif status == 'null':
                # The job is gone.
                return error
807
808    # Returns None on success, and an error string on failure
809    def blockdev_create(self, options, job_id='job0', filters=None):
810        if filters is None:
811            filters = [filter_qmp_testfiles]
812        result = self.qmp_log('blockdev-create', filters=filters,
813                              job_id=job_id, options=options)
814
815        if 'return' in result:
816            assert result['return'] == {}
817            job_result = self.run_job(job_id)
818        else:
819            job_result = result['error']
820
821        log("")
822        return job_result
823
824    def enable_migration_events(self, name):
825        log('Enabling migration QMP events on %s...' % name)
826        log(self.qmp('migrate-set-capabilities', capabilities=[
827            {
828                'capability': 'events',
829                'state': True
830            }
831        ]))
832
833    def wait_migration(self, expect_runstate: Optional[str]) -> bool:
834        while True:
835            event = self.event_wait('MIGRATION')
836            # We use the default timeout, and with a timeout, event_wait()
837            # never returns None
838            assert event
839
840            log(event, filters=[filter_qmp_event])
841            if event['data']['status'] in ('completed', 'failed'):
842                break
843
844        if event['data']['status'] == 'completed':
845            # The event may occur in finish-migrate, so wait for the expected
846            # post-migration runstate
847            runstate = None
848            while runstate != expect_runstate:
849                runstate = self.qmp('query-status')['return']['status']
850            return True
851        else:
852            return False
853
854    def node_info(self, node_name):
855        nodes = self.qmp('query-named-block-nodes')
856        for x in nodes['return']:
857            if x['node-name'] == node_name:
858                return x
859        return None
860
861    def query_bitmaps(self):
862        res = self.qmp("query-named-block-nodes")
863        return {device['node-name']: device['dirty-bitmaps']
864                for device in res['return'] if 'dirty-bitmaps' in device}
865
866    def get_bitmap(self, node_name, bitmap_name, recording=None, bitmaps=None):
867        """
868        get a specific bitmap from the object returned by query_bitmaps.
869        :param recording: If specified, filter results by the specified value.
870        :param bitmaps: If specified, use it instead of call query_bitmaps()
871        """
872        if bitmaps is None:
873            bitmaps = self.query_bitmaps()
874
875        for bitmap in bitmaps[node_name]:
876            if bitmap.get('name', '') == bitmap_name:
877                if recording is None or bitmap.get('recording') == recording:
878                    return bitmap
879        return None
880
881    def check_bitmap_status(self, node_name, bitmap_name, fields):
882        ret = self.get_bitmap(node_name, bitmap_name)
883
884        return fields.items() <= ret.items()
885
886    def assert_block_path(self, root, path, expected_node, graph=None):
887        """
888        Check whether the node under the given path in the block graph
889        is @expected_node.
890
891        @root is the node name of the node where the @path is rooted.
892
893        @path is a string that consists of child names separated by
894        slashes.  It must begin with a slash.
895
896        Examples for @root + @path:
897          - root="qcow2-node", path="/backing/file"
898          - root="quorum-node", path="/children.2/file"
899
900        Hypothetically, @path could be empty, in which case it would
901        point to @root.  However, in practice this case is not useful
902        and hence not allowed.
903
904        @expected_node may be None.  (All elements of the path but the
905        leaf must still exist.)
906
907        @graph may be None or the result of an x-debug-query-block-graph
908        call that has already been performed.
909        """
910        if graph is None:
911            graph = self.qmp('x-debug-query-block-graph')['return']
912
913        iter_path = iter(path.split('/'))
914
915        # Must start with a /
916        assert next(iter_path) == ''
917
918        node = next((node for node in graph['nodes'] if node['name'] == root),
919                    None)
920
921        # An empty @path is not allowed, so the root node must be present
922        assert node is not None, 'Root node %s not found' % root
923
924        for child_name in iter_path:
925            assert node is not None, 'Cannot follow path %s%s' % (root, path)
926
927            try:
928                node_id = next(edge['child'] for edge in graph['edges']
929                               if (edge['parent'] == node['id'] and
930                                   edge['name'] == child_name))
931
932                node = next(node for node in graph['nodes']
933                            if node['id'] == node_id)
934
935            except StopIteration:
936                node = None
937
938        if node is None:
939            assert expected_node is None, \
940                   'No node found under %s (but expected %s)' % \
941                   (path, expected_node)
942        else:
943            assert node['name'] == expected_node, \
944                   'Found node %s under %s (but expected %s)' % \
945                   (node['name'], path, expected_node)
946
# Matches a dictpath() path component of the form "key[index]"
index_re = re.compile(r'([^\[]+)\[([^\]]+)\]')

class QMPTestCase(unittest.TestCase):
    '''Abstract base class for QMP test cases'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Many users of this class set a VM property we rely on heavily
        # in the methods below.
        self.vm = None

    def dictpath(self, d, path):
        '''Traverse a path in a nested dict

           @path is a list of keys separated by '/'.  A component of the
           form "key[N]" looks up "key" and then takes element N of the list
           stored there.  The test fails if a key is missing, if a value is
           not a dict (or, for an indexed component, not a list), or if an
           index is out of range.'''
        for component in path.split('/'):
            m = index_re.match(component)
            if m:
                component, idx = m.groups()
                idx = int(idx)

            if not isinstance(d, dict) or component not in d:
                self.fail(f'failed path traversal for "{path}" in "{d}"')
            d = d[component]

            if m:
                if not isinstance(d, list):
                    self.fail(f'path component "{component}" in "{path}" '
                              f'is not a list in "{d}"')
                try:
                    d = d[idx]
                except IndexError:
                    self.fail(f'invalid index "{idx}" in path "{path}" '
                              f'in "{d}"')
        return d

    def assert_qmp_absent(self, d, path):
        '''Assert that @path cannot be traversed in the QMP dict @d'''
        try:
            result = self.dictpath(d, path)
        except AssertionError:
            # dictpath() failed the traversal, so the path is absent
            return
        self.fail('path "%s" has value "%s"' % (path, str(result)))

    def assert_qmp(self, d, path, value):
        '''Assert that the value for a specific path in a QMP dict
           matches.  When given a list of values, assert that any of
           them matches.'''

        result = self.dictpath(d, path)

        # [] makes no sense as a list of valid values, so treat it as
        # an actual single value.
        if isinstance(value, list) and value != []:
            for v in value:
                if result == v:
                    return
            self.fail('no match for "%s" in %s' % (str(result), str(value)))
        else:
            self.assertEqual(result, value,
                             '"%s" is "%s", expected "%s"'
                             % (path, str(result), str(value)))

    def assert_no_active_block_jobs(self):
        '''Assert that query-block-jobs returns an empty list'''
        result = self.vm.qmp('query-block-jobs')
        self.assert_qmp(result, 'return', [])

    def assert_has_block_node(self, node_name=None, file_name=None):
        """Issue a query-named-block-nodes and assert node_name and/or
        file_name is present in the result"""
        def check_equal_or_none(a, b):
            # A None on either side counts as a match
            return a is None or b is None or a == b
        assert node_name or file_name
        result = self.vm.qmp('query-named-block-nodes')
        for x in result["return"]:
            if check_equal_or_none(x.get("node-name"), node_name) and \
                    check_equal_or_none(x.get("file"), file_name):
                return
        self.fail("Cannot find %s %s in result:\n%s" %
                  (node_name, file_name, result))

    def assert_json_filename_equal(self, json_filename, reference):
        '''Asserts that the given filename is a json: filename and that its
           content is equal to the given reference object'''
        self.assertEqual(json_filename[:5], 'json:')
        self.assertEqual(
            self.vm.flatten_qmp_object(json.loads(json_filename[5:])),
            self.vm.flatten_qmp_object(reference)
        )

    def cancel_and_wait(self, drive='drive0', force=False,
                        resume=False, wait=60.0):
        '''Cancel a block job and wait for it to finish, returning the event

           Waits for either BLOCK_JOB_COMPLETED or BLOCK_JOB_CANCELLED on
           @drive.  If @resume is set, the drive is resumed after the cancel
           request.  Finally asserts that no block jobs remain active.'''
        result = self.vm.qmp('block-job-cancel', device=drive, force=force)
        self.assert_qmp(result, 'return', {})

        if resume:
            self.vm.resume_drive(drive)

        cancelled = False
        result = None
        while not cancelled:
            for event in self.vm.get_qmp_events(wait=wait):
                if event['event'] == 'BLOCK_JOB_COMPLETED' or \
                   event['event'] == 'BLOCK_JOB_CANCELLED':
                    self.assert_qmp(event, 'data/device', drive)
                    result = event
                    cancelled = True
                elif event['event'] == 'JOB_STATUS_CHANGE':
                    self.assert_qmp(event, 'data/id', drive)

        self.assert_no_active_block_jobs()
        return result

    def wait_until_completed(self, drive='drive0', check_offset=True,
                             wait=60.0, error=None):
        '''Wait for a block job to finish, returning the event

           If @error is None, the BLOCK_JOB_COMPLETED event must carry no
           error (and, with @check_offset, offset must equal len); otherwise
           its data/error must equal @error.'''
        while True:
            for event in self.vm.get_qmp_events(wait=wait):
                if event['event'] == 'BLOCK_JOB_COMPLETED':
                    self.assert_qmp(event, 'data/device', drive)
                    if error is None:
                        self.assert_qmp_absent(event, 'data/error')
                        if check_offset:
                            self.assert_qmp(event, 'data/offset',
                                            event['data']['len'])
                    else:
                        self.assert_qmp(event, 'data/error', error)
                    self.assert_no_active_block_jobs()
                    return event
                if event['event'] == 'JOB_STATUS_CHANGE':
                    self.assert_qmp(event, 'data/id', drive)

    def wait_ready(self, drive='drive0'):
        """Wait until a BLOCK_JOB_READY event, and return the event.

        A READY event from either a mirror or a commit job on @drive is
        accepted."""
        return self.vm.events_wait([
            ('BLOCK_JOB_READY',
             {'data': {'type': 'mirror', 'device': drive}}),
            ('BLOCK_JOB_READY',
             {'data': {'type': 'commit', 'device': drive}})
        ])

    def wait_ready_and_cancel(self, drive='drive0'):
        '''Wait for a mirror job to be ready, then cancel it

           Cancelling a ready mirror job is expected to yield
           BLOCK_JOB_COMPLETED with offset == len.'''
        self.wait_ready(drive=drive)
        event = self.cancel_and_wait(drive=drive)
        self.assertEqual(event['event'], 'BLOCK_JOB_COMPLETED')
        self.assert_qmp(event, 'data/type', 'mirror')
        self.assert_qmp(event, 'data/offset', event['data']['len'])

    def complete_and_wait(self, drive='drive0', wait_ready=True,
                          completion_error=None):
        '''Complete a block job and wait for it to finish

           The job must be a mirror or commit job; @completion_error is the
           error expected in the completion event (None for success).'''
        if wait_ready:
            self.wait_ready(drive=drive)

        result = self.vm.qmp('block-job-complete', device=drive)
        self.assert_qmp(result, 'return', {})

        event = self.wait_until_completed(drive=drive, error=completion_error)
        # assertIn (unlike assertTrue(... in ...)) reports the actual type
        self.assertIn(event['data']['type'], ['mirror', 'commit'])

    def pause_wait(self, job_id='job0'):
        '''Poll query-block-jobs until job @job_id is paused and not busy

           Returns the job's entry.  The job must stay listed; polling is
           bounded by Timeout(3, ...) (presumably seconds).'''
        with Timeout(3, "Timeout waiting for job to pause"):
            while True:
                result = self.vm.qmp('query-block-jobs')
                found = False
                for job in result['return']:
                    if job['device'] == job_id:
                        found = True
                        if job['paused'] and not job['busy']:
                            return job
                        break
                assert found

    def pause_job(self, job_id='job0', wait=True):
        '''Pause a block job; with @wait, return its entry once paused,
           otherwise return the block-job-pause response'''
        result = self.vm.qmp('block-job-pause', device=job_id)
        self.assert_qmp(result, 'return', {})
        if wait:
            return self.pause_wait(job_id)
        return result

    def case_skip(self, reason):
        '''Skip this test case'''
        case_notrun(reason)
        self.skipTest(reason)
1130
1131
def notrun(reason):
    '''Skip this test suite

    Writes @reason to <output_dir>/<seq>.notrun, logs a warning and then
    exits the process with status 0.'''
    # Each test in qemu-iotests has a number ("seq"): the script's basename
    seq = os.path.basename(sys.argv[0])
    notrun_file = '%s/%s.notrun' % (output_dir, seq)

    with open(notrun_file, 'w', encoding='utf-8') as outfile:
        outfile.write(reason + '\n')

    logger.warning("%s not run: %s", seq, reason)
    sys.exit(0)
1142
def case_notrun(reason):
    '''Mark this test case as not having been run (without actually
    skipping it, that is left to the caller).  See
    QMPTestCase.case_skip() for a variant that actually skips the
    current test case.

    The reason is appended to <output_dir>/<seq>.casenotrun.'''

    # Each test in qemu-iotests has a number ("seq"): the script's basename
    seq = os.path.basename(sys.argv[0])
    casenotrun_file = '%s/%s.casenotrun' % (output_dir, seq)
    entry = '    [case not run] ' + reason + '\n'

    with open(casenotrun_file, 'a', encoding='utf-8') as outfile:
        outfile.write(entry)
1155
def _verify_image_format(supported_fmts: Sequence[str] = (),
                         unsupported_fmts: Sequence[str] = ()) -> None:
    """
    Skip the test suite (via notrun()) if imgfmt is not suitable.

    An empty @supported_fmts accepts every format not in @unsupported_fmts.
    'generic' in @supported_fmts also accepts every format while
    $IMGFMT_GENERIC is 'true' (the default); otherwise the list is used
    literally.  For the 'luks' format, additionally check that LUKS works.
    """
    if 'generic' in supported_fmts:
        generic_enabled = os.environ.get('IMGFMT_GENERIC', 'true') == 'true'
        if generic_enabled:
            # Equivalent of "_supported_fmt generic" in bash tests
            supported_fmts = ()

    rejected = (bool(supported_fmts) and imgfmt not in supported_fmts) or \
        imgfmt in unsupported_fmts
    if rejected:
        notrun('not suitable for this image format: %s' % imgfmt)

    if imgfmt == 'luks':
        verify_working_luks()
1171
1172def _verify_protocol(supported: Sequence[str] = (),
1173                     unsupported: Sequence[str] = ()) -> None:
1174    assert not (supported and unsupported)
1175
1176    if 'generic' in supported:
1177        return
1178
1179    not_sup = supported and (imgproto not in supported)
1180    if not_sup or (imgproto in unsupported):
1181        notrun('not suitable for this protocol: %s' % imgproto)
1182
1183def _verify_platform(supported: Sequence[str] = (),
1184                     unsupported: Sequence[str] = ()) -> None:
1185    if any((sys.platform.startswith(x) for x in unsupported)):
1186        notrun('not suitable for this OS: %s' % sys.platform)
1187
1188    if supported:
1189        if not any((sys.platform.startswith(x) for x in supported)):
1190            notrun('not suitable for this OS: %s' % sys.platform)
1191
1192def _verify_cache_mode(supported_cache_modes: Sequence[str] = ()) -> None:
1193    if supported_cache_modes and (cachemode not in supported_cache_modes):
1194        notrun('not suitable for this cache mode: %s' % cachemode)
1195
1196def _verify_aio_mode(supported_aio_modes: Sequence[str] = ()) -> None:
1197    if supported_aio_modes and (aiomode not in supported_aio_modes):
1198        notrun('not suitable for this aio mode: %s' % aiomode)
1199
1200def _verify_formats(required_formats: Sequence[str] = ()) -> None:
1201    usf_list = list(set(required_formats) - set(supported_formats()))
1202    if usf_list:
1203        notrun(f'formats {usf_list} are not whitelisted')
1204
1205
def _verify_virtio_blk() -> None:
    """Skip the test suite (via notrun()) unless 'virtio-blk' appears in
    the QEMU binary's '-device help' listing."""
    device_listing = qemu_pipe('-M', 'none', '-device', 'help')
    if 'virtio-blk' in device_listing:
        return
    notrun('Missing virtio-blk in QEMU binary')
1210
def _verify_virtio_scsi_pci_or_ccw() -> None:
    """Skip the test suite (via notrun()) unless either 'virtio-scsi-pci' or
    'virtio-scsi-ccw' appears in the QEMU binary's '-device help' listing."""
    device_listing = qemu_pipe('-M', 'none', '-device', 'help')
    candidates = ('virtio-scsi-pci', 'virtio-scsi-ccw')
    if not any(dev in device_listing for dev in candidates):
        notrun('Missing virtio-scsi-pci or virtio-scsi-ccw in QEMU binary')
1215
1216
def supports_quorum():
    """Return whether 'quorum' appears in the output of `qemu-img --help`."""
    help_text = qemu_img_pipe('--help')
    return 'quorum' in help_text
1219
def verify_quorum():
    '''Skip test suite if quorum support is not available'''
    if supports_quorum():
        return
    notrun('quorum support missing')
1224
def has_working_luks() -> Tuple[bool, str]:
    """
    Check whether our LUKS driver can actually create images
    (this extends to LUKS encryption for qcow2).

    A small LUKS image is created in test_dir with qemu-img and removed
    again right away.

    :return: (True, '') on success; otherwise (False, reason), where reason
        is the text after "<image path>:" in the first qemu-img output line
        containing it, or the whole output if no line does.
    """
    img_file = f'{test_dir}/luks-test.luks'
    output, status = qemu_img_pipe_and_status(
        'create', '-f', 'luks',
        '--object', luks_default_secret_object,
        '-o', luks_default_key_secret_opt,
        '-o', 'iter-time=10',
        img_file, '1G')

    # Best-effort cleanup: the image may not exist if creation failed
    try:
        os.remove(img_file)
    except OSError:
        pass

    if status == 0:
        return (True, '')

    marker = img_file + ':'
    reason = next((line.split(marker, 1)[1].strip()
                   for line in output.splitlines() if marker in line),
                  output)
    return (False, reason)
1255
def verify_working_luks():
    """
    Skip test suite if LUKS does not work, using the reason reported by
    has_working_luks() as the notrun message.
    """
    working, reason = has_working_luks()
    if working:
        return
    notrun(reason)
1263
def qemu_pipe(*args: str) -> str:
    """
    Run qemu with an option to print something and exit (e.g. a help option).

    The command line is qemu_prog, then qemu_opts, then @args; the exit
    status is discarded.

    :return: QEMU's stdout output.
    """
    command = [qemu_prog]
    command += qemu_opts
    command += args
    output, _status = qemu_tool_pipe_and_status('qemu', command)
    return output
1273
def supported_formats(read_only=False):
    '''Set 'read_only' to True to check ro-whitelist
       Otherwise, rw-whitelist is checked

       Results are cached per read_only value in the function attribute
       supported_formats.formats.'''
    try:
        cache = supported_formats.formats
    except AttributeError:
        cache = supported_formats.formats = {}

    if read_only not in cache:
        # "-drive format=help" output: line 0 is taken as the rw list and
        # line 1 as the ro list; the format names follow the first ':'.
        help_lines = qemu_pipe("-drive", "format=help").splitlines()
        listing = help_lines[1 if read_only else 0]
        cache[read_only] = listing.split(":")[1].split()

    return cache[read_only]
1288
def skip_if_unsupported(required_formats=(), read_only=False):
    '''Skip Test Decorator
       Runs the test if all the required formats are whitelisted

       @required_formats is either a sequence of format names or a callable
       that takes the test case and returns one.  @read_only selects the
       ro-whitelist instead of the rw-whitelist.'''
    def skip_test_decorator(func):
        def func_wrapper(test_case: QMPTestCase, *args: List[Any],
                         **kwargs: Dict[str, Any]) -> None:
            if callable(required_formats):
                fmts = required_formats(test_case)
            else:
                fmts = required_formats

            # Sorted so the skip message is reproducible: iteration order of
            # a set of strings varies between runs (hash randomization).
            usf_list = sorted(set(fmts) - set(supported_formats(read_only)))
            if usf_list:
                msg = f'{test_case}: formats {usf_list} are not whitelisted'
                test_case.case_skip(msg)
            else:
                func(test_case, *args, **kwargs)
        return func_wrapper
    return skip_test_decorator
1308
def skip_for_formats(formats: Sequence[str] = ()) \
    -> Callable[[Callable[[QMPTestCase, List[Any], Dict[str, Any]], None]],
                Callable[[QMPTestCase, List[Any], Dict[str, Any]], None]]:
    '''Skip Test Decorator
       Skips the test for the given formats

       The wrapped test runs normally unless imgfmt is one of @formats, in
       which case the test case is skipped via case_skip().'''
    def skip_test_decorator(func):
        def func_wrapper(test_case: QMPTestCase, *args: List[Any],
                         **kwargs: Dict[str, Any]) -> None:
            if imgfmt not in formats:
                func(test_case, *args, **kwargs)
                return
            test_case.case_skip(f'{test_case}: Skipped for format {imgfmt}')
        return func_wrapper
    return skip_test_decorator
1324
def skip_if_user_is_root(func):
    '''Skip Test Decorator
       Runs the test only without root permissions

       When running as uid 0, the case is recorded via case_notrun() and
       the wrapper returns None instead of calling the test.'''
    def func_wrapper(*args, **kwargs):
        if os.getuid() != 0:
            return func(*args, **kwargs)
        case_notrun('{}: cannot be run as root'.format(args[0]))
        return None
    return func_wrapper
1335
# We need to filter out the elapsed time from the output so that
# qemu-iotests can reliably diff the results against the reference output,
# and hide skipped tests from that reference output.
1339
class ReproducibleTestResult(unittest.TextTestResult):
    '''TextTestResult that reports skips without the usual "s" marker.

    In dots mode a skipped test prints "." like a passing one, so skips do
    not show up in the reference output; in verbose mode the skip reason
    is still printed.'''

    def addSkip(self, test, reason):
        # Record the skip via the base TestResult, bypassing the output
        # TextTestResult.addSkip would produce.
        unittest.TestResult.addSkip(self, test, reason)
        if self.showAll:
            self.stream.writeln(f"skipped {reason!r}")
            return
        if self.dots:
            self.stream.write(".")
            self.stream.flush()
1349
class ReproducibleStreamWrapper:
    '''Proxy for a text stream that scrubs run-dependent details on write.

    The elapsed time in "Ran N tests in X.XXXs" and any " (skipped=N)"
    suffix are removed, so the unittest summary is identical between runs.
    All other attribute lookups are forwarded to the wrapped stream.'''

    _ELAPSED_RE = re.compile(r'Ran (\d+) tests? in [\d.]+s')
    _SKIPPED_RE = re.compile(r' \(skipped=\d+\)')

    def __init__(self, stream: TextIO):
        self.stream = stream

    def __getattr__(self, attr):
        # 'stream' is never forwarded: if it is not set yet, looking it up
        # below would re-enter __getattr__ endlessly.  '__getstate__' is
        # presumably excluded so pickling/copying does not use the wrapped
        # stream's state hook.
        if attr not in ('stream', '__getstate__'):
            return getattr(self.stream, attr)
        raise AttributeError(attr)

    def write(self, arg=None):
        scrubbed = self._ELAPSED_RE.sub(r'Ran \1 tests', arg)
        scrubbed = self._SKIPPED_RE.sub('', scrubbed)
        self.stream.write(scrubbed)
1363
class ReproducibleTestRunner(unittest.TextTestRunner):
    '''TextTestRunner whose output is stable between runs.

    Output is routed through ReproducibleStreamWrapper (around @stream, or
    sys.stdout when no stream is given), test descriptions are enabled,
    and results are collected with ReproducibleTestResult by default.'''

    def __init__(self, stream: Optional[TextIO] = None,
                 resultclass: Type[unittest.TestResult] =
                 ReproducibleTestResult,
                 **kwargs: Any) -> None:
        wrapped_stream = ReproducibleStreamWrapper(stream or sys.stdout)
        super().__init__(resultclass=resultclass,
                         descriptions=True,
                         stream=wrapped_stream,    # type: ignore
                         **kwargs)
1374
def execute_unittest(argv: List[str], debug: bool = False) -> None:
    """Executes unittests within the calling module.

    Uses ReproducibleTestRunner; @debug raises the verbosity from 1 to 2.
    """
    verbosity = 2 if debug else 1
    # Some tests have warnings, especially ResourceWarnings for unclosed
    # files and sockets.  Ignore them for now to ensure reproducibility of
    # the test output -- unless warning options were given explicitly.
    warning_action = None if sys.warnoptions else 'ignore'

    unittest.main(argv=argv,
                  testRunner=ReproducibleTestRunner,
                  verbosity=verbosity,
                  warnings=warning_action)
1385
def execute_setup_common(supported_fmts: Sequence[str] = (),
                         supported_platforms: Sequence[str] = (),
                         supported_cache_modes: Sequence[str] = (),
                         supported_aio_modes: Sequence[str] = (),
                         unsupported_fmts: Sequence[str] = (),
                         supported_protocols: Sequence[str] = (),
                         unsupported_protocols: Sequence[str] = (),
                         required_fmts: Sequence[str] = ()) -> bool:
    """
    Perform necessary setup for either script-style or unittest-style tests.

    Consumes a '-d' flag from sys.argv, configures logging accordingly and
    runs the environment checks, any of which may end the suite via
    notrun().

    :return: Bool; Whether or not debug mode has been requested via the CLI.
    """
    # Note: Python 3.6 and pylint do not like 'Collection' so use 'Sequence'.

    debug = '-d' in sys.argv
    if debug:
        sys.argv.remove('-d')
    logging.basicConfig(level=logging.DEBUG if debug else logging.WARNING)

    # Environment checks, in the same order as always
    _verify_image_format(supported_fmts, unsupported_fmts)
    _verify_protocol(supported_protocols, unsupported_protocols)
    _verify_platform(supported=supported_platforms)
    _verify_cache_mode(supported_cache_modes)
    _verify_aio_mode(supported_aio_modes)
    _verify_formats(required_fmts)
    _verify_virtio_blk()

    return debug
1415
def execute_test(*args, test_function=None, **kwargs):
    """Run either unittest or script-style tests.

    Positional and keyword arguments go to execute_setup_common().  With a
    @test_function it is called directly; otherwise the calling module's
    unittests are run.
    """
    debug = execute_setup_common(*args, **kwargs)
    if test_function:
        test_function()
        return
    execute_unittest(sys.argv, debug)
1424
def activate_logging():
    """Activate iotests.log() output to stdout for script-style tests.

    Attaches a bare-message stdout handler to test_logger at INFO level.
    """
    stdout_handler = logging.StreamHandler(stream=sys.stdout)
    stdout_handler.setFormatter(logging.Formatter('%(message)s'))

    test_logger.addHandler(stdout_handler)
    test_logger.setLevel(logging.INFO)
    # Don't also pass these records on to ancestor loggers' handlers
    test_logger.propagate = False
1433
# This is called from script-style iotests without a single point of entry
def script_initialize(*args, **kwargs):
    """Initialize script-style tests without running any tests.

    Enables stdout logging, then forwards all arguments to
    execute_setup_common().
    """
    activate_logging()
    execute_setup_common(*args, **kwargs)
1439
# This is called from script-style iotests with a single point of entry
def script_main(test_function, *args, **kwargs):
    """Run script-style tests outside of the unittest framework

    Enables stdout logging, then runs @test_function through execute_test()
    with the remaining arguments as setup options.
    """
    activate_logging()
    execute_test(*args, test_function=test_function, **kwargs)
1445
# This is called from unittest style iotests
def main(*args, **kwargs):
    """Run tests using the unittest framework

    All arguments are forwarded unchanged to execute_test().
    """
    execute_test(*args, **kwargs)
1450