1from __future__ import print_function 2 3import unittest 4import dbus 5import subprocess 6import os 7import time 8import re 9import sys 10from datetime import datetime 11from enum import Enum 12from systemd import journal 13 14if sys.version_info.major == 3 and sys.version_info.minor >= 3: 15 from time import monotonic 16else: 17 from monotonic import monotonic 18 19import gi 20gi.require_version('GUdev', '1.0') 21from gi.repository import GUdev 22 23test_devs = None 24FLIGHT_RECORD_FILE = "flight_record.log" 25 26def run_command(command): 27 res = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, 28 stderr=subprocess.PIPE) 29 30 out, err = res.communicate() 31 if res.returncode != 0: 32 output = out.decode().strip() + "\n\n" + err.decode().strip() 33 else: 34 output = out.decode().strip() 35 return (res.returncode, output) 36 37def get_call_long(call): 38 def call_long(*args, **kwargs): 39 """Do an async call with a very long timeout (unless specified otherwise)""" 40 kwargs['timeout'] = 100 # seconds 41 return call(*args, **kwargs) 42 43 return call_long 44 45 46def get_version_from_lsb(): 47 ret, out = run_command("lsb_release -rs") 48 if ret != 0: 49 raise RuntimeError("Cannot get distro version from lsb_release output: '%s'" % out) 50 51 return out.split(".")[0] 52 53 54def get_version_from_pretty_name(pretty_name): 55 """ Try to get distro and version from 'OperatingSystemPrettyName' 56 hostname property. 57 58 It should look like this: 59 - "Debian GNU/Linux 9 (stretch)" 60 - "Fedora 27 (Workstation Edition)" 61 - "CentOS Linux 7 (Core)" 62 63 So just return first word as distro and first number as version. 64 """ 65 distro = pretty_name.split()[0].lower() 66 match = re.search(r"\d+", pretty_name) 67 if match is not None: 68 version = match.group(0) 69 else: 70 version = get_version_from_lsb() 71 72 return (distro, version) 73 74 75def get_version(): 76 """ Try to get distro and version 77 """ 78 79 bus = dbus.SystemBus() 80 81 # get information about the distribution from systemd (hostname1) 82 sys_info = bus.get_object("org.freedesktop.hostname1", "/org/freedesktop/hostname1") 83 cpe = str(sys_info.Get("org.freedesktop.hostname1", "OperatingSystemCPEName", dbus_interface=dbus.PROPERTIES_IFACE)) 84 85 if cpe: 86 # 2nd to 4th fields from e.g. "cpe:/o:fedoraproject:fedora:25" or "cpe:/o:redhat:enterprise_linux:7.3:GA:server" 87 _project, distro, version = tuple(cpe.split(":")[2:5]) 88 # we want just the major version, so remove all decimal places (if any) 89 version = str(int(float(version))) 90 else: 91 pretty_name = str(sys_info.Get("org.freedesktop.hostname1", "OperatingSystemPrettyName", dbus_interface=dbus.PROPERTIES_IFACE)) 92 distro, version = get_version_from_pretty_name(pretty_name) 93 94 return (distro, version) 95 96 97def skip_on(skip_on_distros, skip_on_version="", reason=""): 98 """A function returning a decorator to skip some test on a given distribution-version combination 99 100 :param skip_on_distros: distro(s) to skip the test on 101 :type skip_on_distros: str or tuple of str 102 :param str skip_on_version: version of distro(s) to skip the tests on (only 103 checked on distribution match) 104 105 """ 106 if isinstance(skip_on_distros, str): 107 skip_on_distros = (skip_on_distros,) 108 109 distro, version = get_version() 110 111 def decorator(func): 112 if distro in skip_on_distros and (not skip_on_version or skip_on_version == version): 113 msg = "not supported on this distribution in this version" + (": %s" % reason if reason else "") 114 return unittest.skip(msg)(func) 115 else: 116 return func 117 118 return decorator 119 120def unstable_test(test): 121 """Decorator for unstable tests 122 123 Failures of tests decorated with this decorator are silently ignored unless 124 the ``UNSTABLE_TESTS_FATAL`` environment variable is defined. 125 """ 126 127 def decorated_test(*args): 128 try: 129 test(*args) 130 except unittest.SkipTest: 131 # make sure skipped tests are just skipped as usual 132 raise 133 except Exception as e: 134 # and swallow everything else, just report a failure of an unstable 135 # test, unless told otherwise 136 if "UNSTABLE_TESTS_FATAL" in os.environ: 137 raise 138 print("unstable-fail: Ignoring exception '%s'\n" % e, end="", file=sys.stderr) 139 140 return decorated_test 141 142 143class DBusProperty(object): 144 145 TIMEOUT = 5 146 147 def __init__(self, obj, iface, prop): 148 self.obj = obj 149 self.iface = iface 150 self.prop = prop 151 152 self._value = None 153 154 @property 155 def value(self): 156 if self._value is None: 157 self._update_value() 158 return self._value 159 160 def _update_value(self): 161 self._value = self.obj.Get(self.iface, self.prop, dbus_interface=dbus.PROPERTIES_IFACE) 162 163 def _check(self, timeout, check_fn): 164 for _ in range(int(timeout / 0.5)): 165 try: 166 self._update_value() 167 if check_fn(self.value): 168 return True 169 except Exception: 170 # ignore all exceptions -- they might be result of property 171 # not having the expected type (e.g. 'None' when checking for len) 172 pass 173 time.sleep(0.5) 174 175 return False 176 177 def assertEqual(self, value, timeout=TIMEOUT, getter=None): 178 if getter is not None: 179 check_fn = lambda x: getter(x) == value 180 else: 181 check_fn = lambda x: x == value 182 ret = self._check(timeout, check_fn) 183 184 if not ret: 185 if getter is not None: 186 raise AssertionError('%s != %s' % (getter(self._value), value)) 187 else: 188 raise AssertionError('%s != %s' % (self._value, value)) 189 190 def assertNotEqual(self, value, timeout=TIMEOUT, getter=None): 191 if getter is not None: 192 check_fn = lambda x: getter(x) != value 193 else: 194 check_fn = lambda x: x != value 195 ret = self._check(timeout, check_fn) 196 197 if not ret: 198 if getter is not None: 199 raise AssertionError('%s == %s' % (getter(self._value), value)) 200 else: 201 raise AssertionError('%s == %s' % (self._value, value)) 202 203 def assertAlmostEqual(self, value, delta, timeout=TIMEOUT, getter=None): 204 if getter is not None: 205 check_fn = lambda x: abs(getter(x) - value) <= delta 206 else: 207 check_fn = lambda x: abs(x - value) <= delta 208 ret = self._check(timeout, check_fn) 209 210 if not ret: 211 if getter is not None: 212 raise AssertionError('%s is not almost equal to %s (delta = %s)' % (getter(self._value), 213 value, delta)) 214 else: 215 raise AssertionError('%s is not almost equal to %s (delta = %s)' % (self._value, 216 value, delta)) 217 218 219 def assertGreater(self, value, timeout=TIMEOUT): 220 check_fn = lambda x: x > value 221 ret = self._check(timeout, check_fn) 222 223 if not ret: 224 raise AssertionError('%s is not greater than %s' % (self._value, value)) 225 226 def assertLess(self, value, timeout=TIMEOUT): 227 check_fn = lambda x: x < value 228 ret = self._check(timeout, check_fn) 229 230 if not ret: 231 raise AssertionError('%s is not less than %s' % (self._value, value)) 232 233 def assertIn(self, lst, timeout=TIMEOUT): 234 check_fn = lambda x: x in lst 235 ret = self._check(timeout, check_fn) 236 237 if not ret: 238 raise AssertionError('%s not found in %s' % (self._value, lst)) 239 240 def assertNotIn(self, lst, timeout=TIMEOUT): 241 check_fn = lambda x: x not in lst 242 ret = self._check(timeout, check_fn) 243 244 if not ret: 245 raise AssertionError('%s unexpectedly found in %s' % (self._value, lst)) 246 247 def assertTrue(self, timeout=TIMEOUT): 248 check_fn = lambda x: bool(x) 249 ret = self._check(timeout, check_fn) 250 251 if not ret: 252 raise AssertionError('%s is not true' % self._value) 253 254 def assertFalse(self, timeout=TIMEOUT): 255 check_fn = lambda x: not bool(x) 256 ret = self._check(timeout, check_fn) 257 258 if not ret: 259 raise AssertionError('%s is not false' % self._value) 260 261 def assertIsNone(self, timeout=TIMEOUT): 262 check_fn = lambda x: x is None 263 ret = self._check(timeout, check_fn) 264 265 if not ret: 266 raise AssertionError('%s is not None' % self._value) 267 268 def assertIsNotNone(self, timeout=TIMEOUT): 269 check_fn = lambda x: x is not None 270 ret = self._check(timeout, check_fn) 271 272 if not ret: 273 raise AssertionError('unexpectedly None') 274 275 def assertLen(self, length, timeout=TIMEOUT): 276 check_fn = lambda x: len(x) == length 277 ret = self._check(timeout, check_fn) 278 279 if not ret: 280 if not hasattr(self._value, '__len__'): 281 raise AssertionError('%s has no length' % type(self._value)) 282 else: 283 raise AssertionError('Expected length %d, but %s has length %d' % (length, 284 self._value, 285 len(self._value))) 286 def assertContains(self, member, timeout=TIMEOUT): 287 check_fn = lambda x: member in x 288 ret = self._check(timeout, check_fn) 289 290 if not ret: 291 raise AssertionError('%s does not contain %s' % (self._value, member)) 292 293class UdisksTestCase(unittest.TestCase): 294 iface_prefix = None 295 path_prefix = None 296 bus = None 297 vdevs = None 298 distro = (None, None, None) # (project, distro_name, version) 299 no_options = dbus.Dictionary(signature="sv") 300 301 302 @classmethod 303 def setUpClass(self): 304 self.iface_prefix = 'org.freedesktop.UDisks2' 305 self.path_prefix = '/org/freedesktop/UDisks2' 306 self.bus = dbus.SystemBus() 307 308 self.distro = get_version() 309 310 self._orig_call_async = self.bus.call_async 311 self._orig_call_blocking = self.bus.call_blocking 312 self.bus.call_async = get_call_long(self._orig_call_async) 313 self.bus.call_blocking = get_call_long(self._orig_call_blocking) 314 self.vdevs = test_devs 315 assert len(self.vdevs) > 3; 316 317 318 @classmethod 319 def tearDownClass(self): 320 self.bus.call_async = self._orig_call_async 321 self.bus.call_blocking = self._orig_call_blocking 322 323 324 def run(self, *args): 325 record = [] 326 now = datetime.now() 327 now_mono = monotonic() 328 with open(FLIGHT_RECORD_FILE, "a") as record_f: 329 record_f.write("================%s[%0.8f] %s.%s.%s================\n" % (now.strftime('%Y-%m-%d %H:%M:%S'), 330 now_mono, 331 self.__class__.__module__, 332 self.__class__.__name__, 333 self._testMethodName)) 334 with JournalRecorder("journal", record): 335 with CmdFlightRecorder("udisksctl monitor", ["udisksctl", "monitor"], record): 336 with CmdFlightRecorder("udevadm monitor", ["udevadm", "monitor"], record): 337 super(UdisksTestCase, self).run(*args) 338 record_f.write("".join(record)) 339 self.udev_settle() 340 341 @classmethod 342 def get_object(self, path_suffix): 343 # if given full path, just use it, otherwise prepend the prefix 344 if path_suffix.startswith(self.path_prefix): 345 path = path_suffix 346 else: 347 path = self.path_prefix + path_suffix 348 try: 349 # self.iface_prefix is the same as the DBus name we acquire 350 obj = self.bus.get_object(self.iface_prefix, path) 351 except: 352 obj = None 353 return obj 354 355 @classmethod 356 def get_interface(self, obj, iface_suffix): 357 """Get interface for the given object either specified by an object path suffix 358 (appended to the common UDisks2 prefix) or given as the object 359 itself. 360 361 :param obj: object to get the interface for 362 :type obj: str or dbus.proxies.ProxyObject 363 :param iface_suffix: suffix appended to the common UDisks2 interface prefix 364 :type iface_suffix: str 365 366 """ 367 if isinstance(obj, str): 368 obj = self.get_object(obj) 369 return dbus.Interface(obj, self.iface_prefix + iface_suffix) 370 371 372 @classmethod 373 def get_property(self, obj, iface_suffix, prop): 374 return DBusProperty(obj, self.iface_prefix + iface_suffix, prop) 375 376 377 @classmethod 378 def get_property_raw(self, obj, iface_suffix, prop): 379 res = obj.Get(self.iface_prefix + iface_suffix, prop, dbus_interface=dbus.PROPERTIES_IFACE) 380 return res 381 382 @classmethod 383 def get_device(self, dev_name): 384 """Get block device object for a given device (e.g. "sda")""" 385 dev = self.get_object('/block_devices/' + os.path.basename(dev_name)) 386 return dev 387 388 @classmethod 389 def get_drive_name(self, device): 390 """Get drive name for the given block device object""" 391 drive_name = self.get_property_raw(device, '.Block', 'Drive').split('/')[-1] 392 return drive_name 393 394 @classmethod 395 def udev_settle(self): 396 self.run_command('udevadm settle') 397 398 @classmethod 399 def wipe_fs(self, device): 400 for _ in range(10): 401 ret, _out = self.run_command('wipefs -a %s' % device) 402 if ret == 0: 403 return True 404 time.sleep(1) 405 406 return False 407 408 def try_unmount(self, path): 409 """Handle unmount and retry if busy""" 410 for _ in range(10): 411 ret, out = self.run_command('umount %s' % path) 412 # the mount may either be unmounted already or not exist anymore 413 if ret == 0 or "not mounted" in out or "no mount point specified" in out or "mountpoint not found" in out: 414 return 415 if "target is busy" not in out: 416 break 417 time.sleep(0.5) 418 self.fail('Failed to unmount %s: %s' % (path, out)) 419 420 @classmethod 421 def read_file(self, filename): 422 with open(filename, 'r') as f: 423 content = f.read() 424 return content 425 426 427 @classmethod 428 def write_file(self, filename, content, ignore_nonexistent=False, binary=False): 429 try: 430 with open(filename, 'wb' if binary else 'w') as f: 431 f.write(content) 432 except OSError as e: 433 if not ignore_nonexistent: 434 raise e 435 436 @classmethod 437 def remove_file(self, filename, ignore_nonexistent=False): 438 try: 439 os.remove(filename) 440 except OSError as e: 441 if not ignore_nonexistent: 442 raise e 443 444 @classmethod 445 def run_command(self, command): 446 return run_command(command) 447 448 @classmethod 449 def module_available(cls, module): 450 ret, _out = cls.run_command('modprobe %s' % module) 451 return ret == 0 452 453 @classmethod 454 def check_module_loaded(self, module): 455 """Tries to load specified module. No checks for extra Manager interface are done. 456 Returns False when module is not available, True when the module initialized 457 successfully, raises an exception otherwise. 458 """ 459 manager_obj = self.get_object('/Manager') 460 manager = self.get_interface(manager_obj, '.Manager') 461 try: 462 manager.EnableModule(module, dbus.Boolean(True)) 463 return True 464 except dbus.exceptions.DBusException as e: 465 msg = r"Error initializing module '%s': .*\.so: cannot open shared object file: No such file or directory" % module 466 if re.search(msg, e.get_dbus_message()): 467 return False 468 else: 469 raise 470 471 @classmethod 472 def ay_to_str(self, ay): 473 """Convert a bytearray (terminated with '\0') to a string""" 474 475 return ''.join(chr(x) for x in ay[:-1]) 476 477 @classmethod 478 def str_to_ay(self, string, terminate=True): 479 """Convert a string to a bytearray (terminated with '\0')""" 480 481 if terminate: 482 string += '\0' 483 484 return dbus.Array([dbus.Byte(ord(c)) for c in string], 485 signature=dbus.Signature('y'), variant_level=1) 486 487 @classmethod 488 def bytes_to_ay(self, bytes): 489 """Convert Python bytes to a DBus bytearray""" 490 491 return dbus.Array([dbus.Byte(b) for b in bytes], 492 signature=dbus.Signature('y'), variant_level=1) 493 494 @classmethod 495 def set_udev_properties(self, device, props): 496 """Sets one or more udev properties for the 'device' identified by its serial number. 497 Note that this overwrites previously set properties. Pass props=None to remove 498 the rules. 499 500 :type props: dict 501 """ 502 UDISKS_UDEV_RULES = "/run/udev/rules.d/99-udisks_test.rules" 503 504 udev = GUdev.Client() 505 dev = udev.query_by_device_file(device) 506 serial = dev.get_property("ID_SERIAL") 507 508 try: 509 os.makedirs("/run/udev/rules.d/") 510 except OSError: 511 # already exists 512 pass 513 514 if props: 515 rules = "" 516 for i in props: 517 rules += ', ENV{%s}="%s"' % (i, props[i]) 518 self.write_file(UDISKS_UDEV_RULES, 519 'ENV{ID_SERIAL}=="%s"%s\n' % (serial, rules)) 520 else: 521 self.remove_file(UDISKS_UDEV_RULES, ignore_nonexistent=True) 522 self.run_command("udevadm control --reload") 523 uevent_path = os.path.join(dev.get_sysfs_path(), "uevent") 524 self.write_file(uevent_path, "change\n") 525 self.udev_settle() 526 # FIXME: need to give udisksd some time to process the uevent 527 time.sleep(1) 528 529 @classmethod 530 def assertHasIface(self, obj, iface): 531 obj_intro = dbus.Interface(obj, "org.freedesktop.DBus.Introspectable") 532 intro_data = obj_intro.Introspect() 533 534 for _ in range(20): 535 if ('interface name="%s"' % iface) in intro_data: 536 return 537 time.sleep(0.5) 538 539 raise AssertionError("Object '%s' has no interface '%s'" % (obj.object_path, iface)) 540 541 def assertStartswith(self, val, prefix): 542 if not val.startswith(prefix): 543 raise AssertionError("'%s' does not start with '%s'" % (val, prefix)) 544 545 546class FlightRecorder(object): 547 """Context manager for recording data/logs 548 549 This is the abstract implementation that does nothing. Subclasses are 550 expected to override the methods below to actually do something useful. 551 552 """ 553 554 def __init__(self, desc): 555 """ 556 :param str desc: description of the recorder 557 558 """ 559 self._desc = desc 560 561 def _start(self): 562 """Start recording""" 563 564 def _stop(self): 565 """Stop recording""" 566 pass 567 568 def _save(self): 569 """Save the record""" 570 pass 571 572 def __enter__(self): 573 self._start() 574 575 def __exit__(self, exc_type, exc_val, exc_tb): 576 self._stop() 577 self._save() 578 579 # Returning False means that the exception we have potentially been 580 # given as arguments was not handled 581 return False 582 583class CmdFlightRecorder(FlightRecorder): 584 """Flight recorder running a command and gathering its standard and error output""" 585 586 def __init__(self, desc, argv, store): 587 """ 588 :param str desc: description of the recorder 589 :param argv: command and arguments to run 590 :type argv: list of str 591 :param store: a list-like object to append the data/logs to 592 593 """ 594 super(CmdFlightRecorder, self).__init__(desc) 595 self._argv = argv 596 self._store = store 597 self._proc = None 598 599 def _start(self): 600 self._proc = subprocess.Popen(self._argv, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 601 602 def _stop(self): 603 self._proc.terminate() 604 605 def _save(self): 606 # err is in out (see above) 607 out, _err = self._proc.communicate() 608 rec = '<<<<< ' + self._desc + ' >>>>>' + '\n' + out.decode() + '\n\n' 609 self._store.append(rec) 610 611class JournalRecorder(FlightRecorder): 612 """Flight recorder for gathering logs (journal)""" 613 614 def __init__(self, desc, store): 615 """ 616 :param str desc: description of the recorder 617 :param store: a list-like object to append the data/logs to 618 619 """ 620 super(JournalRecorder, self).__init__(desc) 621 self._store = store 622 self._started = None 623 self._stopped = None 624 625 def _start(self): 626 self._started = monotonic() 627 628 def _stop(self): 629 self._stopped = monotonic() 630 631 def _save(self): 632 j = journal.Reader(converters={"MESSAGE": lambda x: x.decode(errors="replace")}) 633 j.this_boot() 634 j.seek_monotonic(self._started) 635 journal_data = "" 636 637 entry = j.get_next() 638 # entry["__MONOTONIC_TIMESTAMP"] is a tuple of (datetime.timedelta, boot_id) 639 while entry and entry["__MONOTONIC_TIMESTAMP"][0].seconds <= int(self._stopped): 640 if "_COMM" in entry and "_PID" in entry: 641 source = "%s[%d]" % (entry["_COMM"], entry["_PID"]) 642 else: 643 source = "kernel" 644 journal_data += "%s[%0.8f] %s: %s\n" % (entry["__REALTIME_TIMESTAMP"].strftime("%H:%M:%S"), 645 entry["__MONOTONIC_TIMESTAMP"][0].total_seconds(), 646 source, entry["MESSAGE"]) 647 entry = j.get_next() 648 rec = '<<<<< ' + self._desc + ' >>>>>' + '\n' + journal_data + '\n\n\n' 649 self._store.append(rec) 650 651 652class TestTags(Enum): 653 ALL = "all" # "default" tag for running all tests 654 SLOW = "slow" # slow tests 655 UNSTABLE = "unstable" # randomly failing tests 656 UNSAFE = "unsafe" # tests that change system configuration 657 NOSTORAGE = "nostorage" # tests that don't work with storage 658 EXTRADEPS = "extradeps" # tests that require special configuration and/or device to run 659 660 @classmethod 661 def get_tags(cls): 662 return [t.value for t in cls.__members__.values()] 663 664 @classmethod 665 def get_tag_by_value(cls, value): 666 tag = next((t for t in cls.__members__.values() if t.value == value), None) 667 668 if not tag: 669 raise ValueError('Unknown value "%s"' % value) 670 671 return tag 672 673 674def tag_test(*tags): 675 def decorator(func): 676 func.slow = TestTags.SLOW in tags 677 func.unstable = TestTags.UNSTABLE in tags 678 func.unsafe = TestTags.UNSAFE in tags 679 func.nostorage = TestTags.NOSTORAGE in tags 680 func.extradeps = TestTags.EXTRADEPS in tags 681 682 return func 683 684 return decorator 685