1from __future__ import print_function
2
3import unittest
4import dbus
5import subprocess
6import os
7import time
8import re
9import sys
10from datetime import datetime
11from enum import Enum
12from systemd import journal
13
14if sys.version_info.major == 3 and sys.version_info.minor >= 3:
15    from time import monotonic
16else:
17    from monotonic import monotonic
18
19import gi
20gi.require_version('GUdev', '1.0')
21from gi.repository import GUdev
22
# Device nodes used by the tests. Stays None here; UdisksTestCase.setUpClass
# reads it and requires more than 3 entries, so it is presumably assigned by
# the test runner before the tests start -- verify against the runner.
test_devs = None
# Log file (relative to the current working directory) to which
# UdisksTestCase.run appends a per-test header and the recorded
# journal/monitor output.
FLIGHT_RECORD_FILE = "flight_record.log"
25
def run_command(command):
    """Run *command* through the shell and collect its output.

    :param str command: shell command line to run
    :returns: tuple of the return code and the stripped standard output; when
              the command fails (non-zero return code), the stripped standard
              error output is appended after an empty line
    :rtype: tuple of (int, str)

    Output that is not valid in the default encoding is decoded with
    replacement characters instead of raising :exc:`UnicodeDecodeError`.
    """
    proc = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE,
                            stderr=subprocess.PIPE)

    out, err = proc.communicate()
    # commands like wipefs/umount may print arbitrary bytes; never crash on them
    output = out.decode(errors="replace").strip()
    if proc.returncode != 0:
        output += "\n\n" + err.decode(errors="replace").strip()
    return (proc.returncode, output)
36
def get_call_long(call):
    """Wrap a D-Bus call method so that it uses a very long timeout by default.

    :param call: the original call method (e.g. ``bus.call_blocking``)
    :returns: a wrapper passing ``timeout=100`` (seconds) to *call* unless the
              caller specified a ``timeout`` keyword argument explicitly
    """
    import functools

    @functools.wraps(call)
    def call_long(*args, **kwargs):
        # only fill in the long timeout, respect an explicitly requested one
        kwargs.setdefault('timeout', 100)  # seconds
        return call(*args, **kwargs)

    return call_long
44
45
def get_version_from_lsb():
    """Return the major version of the distribution as reported by lsb_release.

    :returns: everything before the first '.' of ``lsb_release -rs`` output
    :rtype: str
    :raises RuntimeError: if ``lsb_release -rs`` exits with a non-zero code
    """
    returncode, release = run_command("lsb_release -rs")
    if returncode == 0:
        major, _sep, _minor = release.partition(".")
        return major

    raise RuntimeError("Cannot get distro version from lsb_release output: '%s'" % release)
52
53
def get_version_from_pretty_name(pretty_name):
    """Guess distro and version from the 'OperatingSystemPrettyName' property.

    Typical values of the hostname property are:
     - "Debian GNU/Linux 9 (stretch)"
     - "Fedora 27 (Workstation Edition)"
     - "CentOS Linux 7 (Core)"

    The distro is the first word (lowercased) and the version is the first
    number found anywhere in the string. When there is no number at all, the
    version is taken from lsb_release instead.

    :param str pretty_name: value of the 'OperatingSystemPrettyName' property
    :returns: tuple of (distro, version)
    :rtype: tuple of str
    """
    words = pretty_name.split()
    distro = words[0].lower()

    first_number = re.search(r"\d+", pretty_name)
    if first_number is None:
        return (distro, get_version_from_lsb())
    return (distro, first_number.group(0))
73
74
def get_version():
    """ Try to get distro and (major) version

    The information comes from systemd-hostnamed (``org.freedesktop.hostname1``)
    on the system bus. The ``OperatingSystemCPEName`` property is used when it
    can be parsed; otherwise ``OperatingSystemPrettyName`` is parsed by
    :func:`get_version_from_pretty_name`.

    :returns: tuple of (distro, major version)
    :rtype: tuple of str
    """

    bus = dbus.SystemBus()

    # get information about the distribution from systemd (hostname1)
    sys_info = bus.get_object("org.freedesktop.hostname1", "/org/freedesktop/hostname1")
    cpe = str(sys_info.Get("org.freedesktop.hostname1", "OperatingSystemCPEName",
                           dbus_interface=dbus.PROPERTIES_IFACE))

    # e.g. "cpe:/o:fedoraproject:fedora:25" or "cpe:/o:redhat:enterprise_linux:7.3:GA:server"
    # -> the 4th field is the distro and the 5th field is the version
    fields = cpe.split(":") if cpe else []
    if len(fields) >= 5:
        # we want just the major version, i.e. the leading number ("7" for
        # "7.3"); a version not starting with a number (or a CPE in another
        # format) makes us fall back to the pretty name instead of crashing
        match = re.match(r"\d+", fields[4])
        if match is not None:
            return (fields[3], str(int(match.group(0))))

    pretty_name = str(sys_info.Get("org.freedesktop.hostname1", "OperatingSystemPrettyName",
                                   dbus_interface=dbus.PROPERTIES_IFACE))
    return get_version_from_pretty_name(pretty_name)
95
96
def skip_on(skip_on_distros, skip_on_version="", reason=""):
    """Return a decorator skipping a test on the given distribution(s)/version.

    The running distribution is determined once, when this function is called.

    :param skip_on_distros: distro(s) to skip the test on
    :type skip_on_distros: str or tuple of str
    :param str skip_on_version: version of the distro(s) to skip the test on
                                (only checked when the distro matches; empty
                                means "any version")
    :param str reason: optional explanation appended to the skip message

    """
    distros = (skip_on_distros,) if isinstance(skip_on_distros, str) else skip_on_distros

    current_distro, current_version = get_version()
    version_matches = not skip_on_version or skip_on_version == current_version
    should_skip = current_distro in distros and version_matches

    def decorator(func):
        if not should_skip:
            return func
        message = "not supported on this distribution in this version"
        if reason:
            message += ": %s" % reason
        return unittest.skip(message)(func)

    return decorator
119
def unstable_test(test):
    """Decorator for unstable tests

    Failures of tests decorated with this decorator are silently ignored unless
    the ``UNSTABLE_TESTS_FATAL`` environment variable is defined. Skipped tests
    are still reported as skipped.

    The wrapper keeps the name, the docstring and the attributes of the
    decorated test (e.g. the tag flags set by ``tag_test``), and passes both
    positional and keyword arguments through.
    """
    import functools

    @functools.wraps(test)
    def decorated_test(*args, **kwargs):
        try:
            test(*args, **kwargs)
        except unittest.SkipTest:
            # make sure skipped tests are just skipped as usual
            raise
        except Exception as e:
            # and swallow everything else, just report a failure of an unstable
            # test, unless told otherwise
            if "UNSTABLE_TESTS_FATAL" in os.environ:
                raise
            print("unstable-fail: Ignoring exception '%s'\n" % e, end="", file=sys.stderr)

    return decorated_test
141
142
class DBusProperty(object):
    """A single D-Bus property with polling assertions.

    Every ``assert*`` method re-reads the property (via the standard
    Properties.Get method) every 0.5 seconds until the condition holds. If it
    never holds within *timeout* seconds worth of attempts, AssertionError is
    raised. At least one attempt is always made.
    """

    TIMEOUT = 5
    # delay between two consecutive reads of the property (in seconds)
    _POLL_INTERVAL = 0.5

    def __init__(self, obj, iface, prop):
        """
        :param obj: proxy object owning the property (must provide ``Get``)
        :param str iface: interface the property belongs to
        :param str prop: name of the property
        """
        self.obj = obj
        self.iface = iface
        self.prop = prop

        # cached value of the property; None means "not read yet"
        self._value = None

    @property
    def value(self):
        """Last read value of the property (read lazily on first access)."""
        if self._value is None:
            self._update_value()
        return self._value

    def _update_value(self):
        """Re-read the property from the bus into the cache."""
        self._value = self.obj.Get(self.iface, self.prop, dbus_interface=dbus.PROPERTIES_IFACE)

    def _check(self, timeout, check_fn):
        """Poll the property until *check_fn* returns True for its value.

        The property is checked at least once, even when *timeout* is shorter
        than the poll interval, and there is no sleep after the last attempt.

        :returns: whether the check succeeded
        :rtype: bool
        """
        attempts = max(1, int(timeout / self._POLL_INTERVAL))
        for attempt in range(attempts):
            if attempt > 0:
                time.sleep(self._POLL_INTERVAL)
            try:
                self._update_value()
                if check_fn(self._value):
                    return True
            except Exception:
                # ignore all exceptions -- they might be result of property
                # not having the expected type (e.g. 'None' when checking for len)
                pass

        return False

    @staticmethod
    def _apply(getter, value):
        """Return ``getter(value)`` or *value* itself when there is no getter."""
        return getter(value) if getter is not None else value

    def assertEqual(self, value, timeout=TIMEOUT, getter=None):
        if not self._check(timeout, lambda x: self._apply(getter, x) == value):
            raise AssertionError('%s != %s' % (self._apply(getter, self._value), value))

    def assertNotEqual(self, value, timeout=TIMEOUT, getter=None):
        if not self._check(timeout, lambda x: self._apply(getter, x) != value):
            raise AssertionError('%s == %s' % (self._apply(getter, self._value), value))

    def assertAlmostEqual(self, value, delta, timeout=TIMEOUT, getter=None):
        if not self._check(timeout, lambda x: abs(self._apply(getter, x) - value) <= delta):
            raise AssertionError('%s is not almost equal to %s (delta = %s)' % (self._apply(getter, self._value),
                                                                                value, delta))

    def assertGreater(self, value, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: x > value):
            raise AssertionError('%s is not greater than %s' % (self._value, value))

    def assertLess(self, value, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: x < value):
            raise AssertionError('%s is not less than %s' % (self._value, value))

    def assertIn(self, lst, timeout=TIMEOUT):
        """Assert that the property value is a member of *lst*."""
        if not self._check(timeout, lambda x: x in lst):
            raise AssertionError('%s not found in %s' % (self._value, lst))

    def assertNotIn(self, lst, timeout=TIMEOUT):
        """Assert that the property value is not a member of *lst*."""
        if not self._check(timeout, lambda x: x not in lst):
            raise AssertionError('%s unexpectedly found in %s' % (self._value, lst))

    def assertTrue(self, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: bool(x)):
            raise AssertionError('%s is not true' % self._value)

    def assertFalse(self, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: not bool(x)):
            raise AssertionError('%s is not false' % self._value)

    def assertIsNone(self, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: x is None):
            raise AssertionError('%s is not None' % self._value)

    def assertIsNotNone(self, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: x is not None):
            raise AssertionError('unexpectedly None')

    def assertLen(self, length, timeout=TIMEOUT):
        if not self._check(timeout, lambda x: len(x) == length):
            if not hasattr(self._value, '__len__'):
                raise AssertionError('%s has no length' % type(self._value))
            else:
                raise AssertionError('Expected length %d, but %s has length %d' % (length,
                                                                                   self._value,
                                                                                   len(self._value)))

    def assertContains(self, member, timeout=TIMEOUT):
        """Assert that the property value contains *member*."""
        if not self._check(timeout, lambda x: member in x):
            raise AssertionError('%s does not contain %s' % (self._value, member))
292
class UdisksTestCase(unittest.TestCase):
    """Base class for the UDisks2 D-Bus tests.

    Offers helpers for accessing UDisks2 objects, interfaces and properties on
    the system bus, for running commands and for manipulating files. Every
    test run is recorded (journal, ``udisksctl monitor`` and ``udevadm
    monitor`` output) into FLIGHT_RECORD_FILE.
    """

    iface_prefix = None
    path_prefix = None
    bus = None
    vdevs = None
    distro = (None, None)       # (distro_name, version) as returned by get_version()
    no_options = dbus.Dictionary(signature="sv")

    @classmethod
    def setUpClass(cls):
        """Connect to the system bus and make D-Bus calls use a long timeout.

        :raises AssertionError: if there are not more than 3 test devices
        """
        # check the test devices before patching the bus: when setUpClass
        # fails, tearDownClass is not run and the patched call methods would
        # stay in place
        cls.vdevs = test_devs
        if cls.vdevs is None or len(cls.vdevs) <= 3:
            raise AssertionError("More than 3 test devices are required, got: %s" % (cls.vdevs,))

        cls.iface_prefix = 'org.freedesktop.UDisks2'
        cls.path_prefix = '/org/freedesktop/UDisks2'
        cls.bus = dbus.SystemBus()

        cls.distro = get_version()

        cls._orig_call_async = cls.bus.call_async
        cls._orig_call_blocking = cls.bus.call_blocking
        cls.bus.call_async = get_call_long(cls._orig_call_async)
        cls.bus.call_blocking = get_call_long(cls._orig_call_blocking)

    @classmethod
    def tearDownClass(cls):
        """Restore the bus call methods patched in setUpClass."""
        cls.bus.call_async = cls._orig_call_async
        cls.bus.call_blocking = cls._orig_call_blocking

    def run(self, *args, **kwargs):
        """Run the test while recording the journal and udisks/udev events.

        A header and all recorded data are appended to FLIGHT_RECORD_FILE, and
        udev is allowed to settle once the test has finished.

        :returns: whatever :meth:`unittest.TestCase.run` returns
        """
        record = []
        now = datetime.now()
        now_mono = monotonic()
        with open(FLIGHT_RECORD_FILE, "a") as record_f:
            record_f.write("================%s[%0.8f] %s.%s.%s================\n" % (now.strftime('%Y-%m-%d %H:%M:%S'),
                                                                                     now_mono,
                                                                                     self.__class__.__module__,
                                                                                     self.__class__.__name__,
                                                                                     self._testMethodName))
            with JournalRecorder("journal", record):
                with CmdFlightRecorder("udisksctl monitor", ["udisksctl", "monitor"], record):
                    with CmdFlightRecorder("udevadm monitor", ["udevadm", "monitor"], record):
                        result = super(UdisksTestCase, self).run(*args, **kwargs)
            record_f.write("".join(record))
        self.udev_settle()
        return result

    @classmethod
    def get_object(cls, path_suffix):
        """Get the UDisks2 proxy object for *path_suffix* (or a full object path).

        :returns: the proxy object, or None if it cannot be obtained
        """
        # if given full path, just use it, otherwise prepend the prefix
        if path_suffix.startswith(cls.path_prefix):
            path = path_suffix
        else:
            path = cls.path_prefix + path_suffix
        try:
            # cls.iface_prefix is the same as the DBus name we acquire
            return cls.bus.get_object(cls.iface_prefix, path)
        except Exception:
            # deliberately best-effort, but do not swallow KeyboardInterrupt or
            # SystemExit like a bare except would
            return None

    @classmethod
    def get_interface(cls, obj, iface_suffix):
        """Get interface for the given object either specified by an object path suffix
        (appended to the common UDisks2 prefix) or given as the object
        itself.

        :param obj: object to get the interface for
        :type obj: str or dbus.proxies.ProxyObject
        :param iface_suffix: suffix appended to the common UDisks2 interface prefix
        :type iface_suffix: str

        """
        if isinstance(obj, str):
            obj = cls.get_object(obj)
        return dbus.Interface(obj, cls.iface_prefix + iface_suffix)

    @classmethod
    def get_property(cls, obj, iface_suffix, prop):
        """Get a pollable DBusProperty for *prop* of the UDisks2 interface *iface_suffix*."""
        return DBusProperty(obj, cls.iface_prefix + iface_suffix, prop)

    @classmethod
    def get_property_raw(cls, obj, iface_suffix, prop):
        """Read the current value of *prop* of the UDisks2 interface *iface_suffix*."""
        return obj.Get(cls.iface_prefix + iface_suffix, prop, dbus_interface=dbus.PROPERTIES_IFACE)

    @classmethod
    def get_device(cls, dev_name):
        """Get block device object for a given device (e.g. "sda")"""
        return cls.get_object('/block_devices/' + os.path.basename(dev_name))

    @classmethod
    def get_drive_name(cls, device):
        """Get drive name (last component of the 'Drive' object path) for the given block device object"""
        return cls.get_property_raw(device, '.Block', 'Drive').split('/')[-1]

    @classmethod
    def udev_settle(cls):
        """Wait for udev to process the queued events (``udevadm settle``)."""
        cls.run_command('udevadm settle')

    @classmethod
    def wipe_fs(cls, device):
        """Run ``wipefs -a`` on *device*, retrying up to 10 times (1 s apart).

        :returns: whether wipefs eventually succeeded
        :rtype: bool
        """
        for _ in range(10):
            ret, _out = cls.run_command('wipefs -a %s' % device)
            if ret == 0:
                return True
            time.sleep(1)

        return False

    def try_unmount(self, path):
        """Handle unmount and retry if busy (up to 10 times, 0.5 s apart).

        Fails the test when the unmount fails for another reason or the target
        stays busy.
        """
        for _ in range(10):
            ret, out = self.run_command('umount %s' % path)
            # the mount may either be unmounted already or not exist anymore
            if ret == 0 or "not mounted" in out or "no mount point specified" in out or "mountpoint not found" in out:
                return
            if "target is busy" not in out:
                break
            time.sleep(0.5)
        self.fail('Failed to unmount %s: %s' % (path, out))

    @classmethod
    def read_file(cls, filename):
        """Return the whole content of the text file *filename*."""
        with open(filename, 'r') as f:
            return f.read()

    @classmethod
    def write_file(cls, filename, content, ignore_nonexistent=False, binary=False):
        """Write *content* to *filename* (in binary mode when *binary* is True).

        :param bool ignore_nonexistent: ignore any I/O error raised while opening
                                        or writing the file
        """
        try:
            with open(filename, 'wb' if binary else 'w') as f:
                f.write(content)
        except (IOError, OSError):
            # IOError is an alias of OSError on Python 3, but a separate class
            # (raised by open()) on Python 2
            if not ignore_nonexistent:
                raise

    @classmethod
    def remove_file(cls, filename, ignore_nonexistent=False):
        """Remove *filename*, optionally ignoring any OSError (e.g. a missing file)."""
        try:
            os.remove(filename)
        except OSError:
            if not ignore_nonexistent:
                raise

    @classmethod
    def run_command(cls, command):
        """Run a shell command, see the module-level run_command()."""
        return run_command(command)

    @classmethod
    def module_available(cls, module):
        """Return whether the kernel module *module* can be loaded with modprobe."""
        ret, _out = cls.run_command('modprobe %s' % module)
        return ret == 0

    @classmethod
    def check_module_loaded(cls, module):
        """Tries to load specified module. No checks for extra Manager interface are done.
           Returns False when module is not available, True when the module initialized
           successfully, raises an exception otherwise.
        """
        manager_obj = cls.get_object('/Manager')
        manager = cls.get_interface(manager_obj, '.Manager')
        try:
            manager.EnableModule(module, dbus.Boolean(True))
            return True
        except dbus.exceptions.DBusException as e:
            msg = r"Error initializing module '%s': .*\.so: cannot open shared object file: No such file or directory" % re.escape(module)
            # the error may come without a message
            if re.search(msg, e.get_dbus_message() or ""):
                return False
            raise

    @classmethod
    def ay_to_str(cls, ay):
        """Convert a bytearray (terminated with '\\0') to a string

        The terminating NUL byte is dropped; an array without the terminator
        is converted as a whole instead of losing its last character.
        """
        if len(ay) > 0 and ay[-1] == 0:
            ay = ay[:-1]
        return ''.join(chr(x) for x in ay)

    @classmethod
    def str_to_ay(cls, string, terminate=True):
        """Convert a string to a bytearray (terminated with '\\0' unless *terminate* is False)"""

        if terminate:
            string += '\0'

        return dbus.Array([dbus.Byte(ord(c)) for c in string],
                          signature=dbus.Signature('y'), variant_level=1)

    @classmethod
    def bytes_to_ay(cls, bytes):
        """Convert Python bytes to a DBus bytearray"""

        return dbus.Array([dbus.Byte(b) for b in bytes],
                          signature=dbus.Signature('y'), variant_level=1)

    @classmethod
    def set_udev_properties(cls, device, props):
        """Sets one or more udev properties for the 'device' identified by its serial number.
           Note that this overwrites previously set properties. Pass props=None to remove
           the rules.

        :type props: dict
        """
        UDISKS_UDEV_RULES = "/run/udev/rules.d/99-udisks_test.rules"

        udev = GUdev.Client()
        dev = udev.query_by_device_file(device)
        serial = dev.get_property("ID_SERIAL")

        try:
            os.makedirs("/run/udev/rules.d/")
        except OSError:
            # already exists
            pass

        if props:
            rules = "".join(', ENV{%s}="%s"' % (key, props[key]) for key in props)
            cls.write_file(UDISKS_UDEV_RULES,
                           'ENV{ID_SERIAL}=="%s"%s\n' % (serial, rules))
        else:
            cls.remove_file(UDISKS_UDEV_RULES, ignore_nonexistent=True)
        cls.run_command("udevadm control --reload")
        # trigger a "change" uevent so that the new rules get applied
        uevent_path = os.path.join(dev.get_sysfs_path(), "uevent")
        cls.write_file(uevent_path, "change\n")
        cls.udev_settle()
        # FIXME: need to give udisksd some time to process the uevent
        time.sleep(1)

    @classmethod
    def assertHasIface(cls, obj, iface):
        """Assert that *obj* exports the D-Bus interface *iface*.

        The object is introspected again every 0.5 seconds (20 attempts), so
        the interface may appear after this method was called.
        """
        obj_intro = dbus.Interface(obj, "org.freedesktop.DBus.Introspectable")
        needle = 'interface name="%s"' % iface

        for attempt in range(20):
            if attempt > 0:
                time.sleep(0.5)
            # re-introspect on every attempt -- an introspection result taken
            # once before the loop could never contain a newly added interface
            if needle in obj_intro.Introspect():
                return

        raise AssertionError("Object '%s' has no interface '%s'" % (obj.object_path, iface))

    def assertStartswith(self, val, prefix):
        """Assert that the string *val* starts with *prefix*."""
        if not val.startswith(prefix):
            raise AssertionError("'%s' does not start with '%s'" % (val, prefix))
544
545
class FlightRecorder(object):
    """Context manager for recording data/logs

    This is the abstract implementation that does nothing. Subclasses are
    expected to override the methods below to actually do something useful.

    """

    def __init__(self, desc):
        """
        :param str desc: description of the recorder

        """
        self._desc = desc

    def _start(self):
        """Start recording"""
        pass

    def _stop(self):
        """Stop recording"""
        pass

    def _save(self):
        """Save the record"""
        pass

    def __enter__(self):
        """Start recording and return the recorder itself (for ``with ... as``)."""
        self._start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop recording and save the record; exceptions are never suppressed."""
        self._stop()
        self._save()

        # Returning False means that the exception we have potentially been
        # given as arguments was not handled
        return False
582
class CmdFlightRecorder(FlightRecorder):
    """Flight recorder running a command and gathering its standard and error output"""

    def __init__(self, desc, argv, store):
        """
        :param str desc: description of the recorder
        :param argv: command and arguments to run
        :type argv: list of str
        :param store: a list-like object to append the data/logs to

        """
        super(CmdFlightRecorder, self).__init__(desc)
        self._argv = argv
        self._store = store
        self._proc = None

    def _start(self):
        # stderr is merged into stdout so both end up in the record
        self._proc = subprocess.Popen(self._argv, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    def _stop(self):
        if self._proc is not None:
            self._proc.terminate()

    def _save(self):
        if self._proc is None:
            # never started, nothing to record
            return
        # err is in out (see _start)
        out, _err = self._proc.communicate()
        # the monitored output is not guaranteed to be valid in the default
        # encoding; a decoding error here would mask the recorded test's result
        rec = '<<<<< ' + self._desc + ' >>>>>' + '\n' + out.decode(errors="replace") + '\n\n'
        self._store.append(rec)
610
class JournalRecorder(FlightRecorder):
    """Flight recorder for gathering logs (journal) of the current boot"""

    def __init__(self, desc, store):
        """
        :param str desc: description of the recorder
        :param store: a list-like object to append the data/logs to

        """
        super(JournalRecorder, self).__init__(desc)
        self._store = store
        self._started = None
        self._stopped = None

    def _start(self):
        self._started = monotonic()

    def _stop(self):
        self._stopped = monotonic()

    def _save(self):
        """Append journal entries between the start and stop times to the store."""
        lines = []
        j = journal.Reader(converters={"MESSAGE": lambda x: x.decode(errors="replace")})
        try:
            j.this_boot()
            j.seek_monotonic(self._started)

            # compare with whole-second granularity so that entries from the
            # second in which recording stopped are still included
            stop_second = int(self._stopped)

            entry = j.get_next()
            # entry["__MONOTONIC_TIMESTAMP"] is a tuple of (datetime.timedelta, boot_id);
            # use total_seconds() -- timedelta.seconds ignores the days part and
            # wraps around once the system has been up for more than a day
            while entry and int(entry["__MONOTONIC_TIMESTAMP"][0].total_seconds()) <= stop_second:
                if "_COMM" in entry and "_PID" in entry:
                    source = "%s[%d]" % (entry["_COMM"], entry["_PID"])
                else:
                    source = "kernel"
                lines.append("%s[%0.8f] %s: %s\n" % (entry["__REALTIME_TIMESTAMP"].strftime("%H:%M:%S"),
                                                     entry["__MONOTONIC_TIMESTAMP"][0].total_seconds(),
                                                     source, entry.get("MESSAGE", "")))
                entry = j.get_next()
        finally:
            j.close()

        rec = '<<<<< ' + self._desc + ' >>>>>' + '\n' + "".join(lines) + '\n\n\n'
        self._store.append(rec)
650
651
class TestTags(Enum):
    """Tags used to classify (and select) tests."""

    ALL = "all"               # "default" tag for running all tests
    SLOW = "slow"             # slow tests
    UNSTABLE = "unstable"     # randomly failing tests
    UNSAFE = "unsafe"         # tests that change system configuration
    NOSTORAGE = "nostorage"   # tests that don't work with storage
    EXTRADEPS = "extradeps"   # tests that require special configuration and/or device to run

    @classmethod
    def get_tags(cls):
        """Return the values of all tags, in the order of ``__members__``."""
        values = []
        for member in cls.__members__.values():
            values.append(member.value)
        return values

    @classmethod
    def get_tag_by_value(cls, value):
        """Return the tag whose value equals *value*.

        :raises ValueError: when no tag has the given value
        """
        for member in cls.__members__.values():
            if member.value == value:
                return member

        raise ValueError('Unknown value "%s"' % value)
672
673
def tag_test(*tags):
    """Return a decorator marking a test function with the given tags.

    The decorated function gets the boolean attributes ``slow``, ``unstable``,
    ``unsafe``, ``nostorage`` and ``extradeps``; each is True exactly when the
    corresponding TestTags member is among *tags*.
    """
    def decorator(func):
        flag_tags = (("slow", TestTags.SLOW),
                     ("unstable", TestTags.UNSTABLE),
                     ("unsafe", TestTags.UNSAFE),
                     ("nostorage", TestTags.NOSTORAGE),
                     ("extradeps", TestTags.EXTRADEPS))
        for attr_name, tag in flag_tags:
            setattr(func, attr_name, tag in tags)

        return func

    return decorator
685