# Copyright (C) 2009 Nokia Corporation
# Copyright (C) 2009 Collabora Ltd.
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
# 02110-1301 USA

"""
Infrastructure code for testing Mission Control
"""

from twisted.internet import gireactor
from twisted.internet.protocol import Protocol, Factory, ClientFactory
# The GI reactor must be installed before anything imports
# twisted.internet.reactor (done further down).
gireactor.install()
import sys

import pprint
import unittest

import dbus
import dbus.lowlevel
import dbus.glib

from twisted.internet import reactor

tp_name_prefix = 'org.freedesktop.Telepathy'
tp_path_prefix = '/org/freedesktop/Telepathy'


class Event:
    """Something that happened during a test.

    ``type`` is a short string such as 'dbus-signal' or 'socket-data'; every
    extra keyword argument becomes an attribute of the event.
    """

    def __init__(self, type, **kw):
        self.__dict__.update(kw)
        self.type = type

    def __str__(self):
        return '\n'.join([str(type(self))] + format_event(self))


def format_event(event):
    """Return a list of human-readable lines describing *event*.

    The first line gives the event type; each following line shows one
    public (non-underscore) attribute, pretty-printed. An 'error' attribute
    is additionally printed through '%s' so its string form is visible too.
    """
    ret = ['- type %s' % event.type]

    for key in dir(event):
        if key != 'type' and not key.startswith('_'):
            ret.append('- %s: %s' % (
                key, pprint.pformat(getattr(event, key))))

            if key == 'error':
                ret.append('%s' % getattr(event, key))

    return ret


class EventPattern:
    """Describes an event a test is waiting for.

    An event matches when its ``type`` equals *type*, every keyword
    property equals the event attribute of the same name (a missing
    attribute never matches), and the optional ``predicate`` callable
    returns a true value for the event.
    """

    def __init__(self, type, **properties):
        self.type = type
        # 'predicate' is not an attribute to compare, so take it out of the
        # property dict before storing it.
        self.predicate = properties.pop('predicate', lambda x: True)
        self.properties = properties

    def __repr__(self):
        properties = dict(self.properties)

        if self.predicate:
            properties['predicate'] = self.predicate

        return '%s(%r, **%r)' % (
            self.__class__.__name__, self.type, properties)

    def match(self, event):
        """Return True if *event* satisfies this pattern, else False."""
        if event.type != self.type:
            return False

        for key, value in self.properties.items():
            try:
                if getattr(event, key) != value:
                    return False
            except AttributeError:
                return False

        return bool(self.predicate(event))


class TimeoutError(Exception):
    """Raised by an event queue's wait() when no event arrives in time."""
    pass


class ForbiddenEventOccurred(Exception):
    """Raised when an event matching a forbidden pattern is received."""

    def __init__(self, event):
        Exception.__init__(self)
        self.event = event

    def __str__(self):
        return '\n' + '\n'.join(format_event(self.event))


class BaseEventQueue:
    """Abstract event queue base class.

    Implement the wait() method to have something that works.
    """

    def __init__(self, timeout=None):
        self.verbose = False
        # Events received while waiting for something else; expect_racy()
        # looks here first.
        self.past_events = []
        self.forbidden_events = set()

        if timeout is None:
            self.timeout = 5
        else:
            self.timeout = timeout

    def log(self, s):
        if self.verbose:
            print(s)

    def log_event(self, event):
        if self.verbose:
            self.log('got event:')
            for line in format_event(event):
                self.log(line)

    def flush_past_events(self):
        self.past_events = []

    def expect_racy(self, type, **kw):
        """Like expect(), but first consume a matching event that was
        already received (and set aside) by an earlier expect call."""
        pattern = EventPattern(type, **kw)

        for event in self.past_events:
            if pattern.match(event):
                self.log('past event handled')
                for line in format_event(event):
                    self.log(line)
                self.log('')
                # Safe despite iterating: we return immediately.
                self.past_events.remove(event)
                return event

        return self.expect(type, **kw)

    def forbid_events(self, patterns):
        """
        Add patterns (an iterable of EventPattern) to the set of forbidden
        events. If a forbidden event occurs during an expect or expect_many,
        the test will fail.
        """
        self.forbidden_events.update(set(patterns))

    def unforbid_events(self, patterns):
        """
        Remove 'patterns' (an iterable of EventPattern) from the set of
        forbidden events. These must be the same EventPattern pointers that
        were passed to forbid_events.
        """
        self.forbidden_events.difference_update(set(patterns))

    def _check_forbidden(self, event):
        for e in self.forbidden_events:
            if e.match(event):
                raise ForbiddenEventOccurred(event)

    def expect(self, type, **kw):
        """Wait for and return the next event matching (type, **kw).

        Non-matching events are appended to past_events. Raises
        TimeoutError (from wait()) or ForbiddenEventOccurred.
        """
        pattern = EventPattern(type, **kw)

        while True:
            event = self.wait()
            self.log_event(event)
            self._check_forbidden(event)

            if pattern.match(event):
                self.log('handled')
                self.log('')
                return event

            self.past_events.append(event)
            self.log('not handled')
            self.log('')

    def expect_many(self, *patterns):
        """Wait until every pattern has matched one event, in any order.

        Returns the matched events in the same order as *patterns*. Each
        event satisfies at most one still-unmatched pattern.
        """
        ret = [None] * len(patterns)

        while None in ret:
            try:
                event = self.wait()
            except TimeoutError:
                self.log('timeout')
                self.log('still expecting:')
                for i, pattern in enumerate(patterns):
                    if ret[i] is None:
                        self.log(' - %r' % pattern)
                raise
            self.log_event(event)
            self._check_forbidden(event)

            for i, pattern in enumerate(patterns):
                if ret[i] is None and pattern.match(event):
                    self.log('handled')
                    self.log('')
                    ret[i] = event
                    break
            else:
                self.past_events.append(event)
                self.log('not handled')
                self.log('')

        return ret

    def demand(self, type, **kw):
        """Require that the very next event matches (type, **kw).

        Raises RuntimeError if it does not. Forbidden patterns are not
        checked here.
        """
        pattern = EventPattern(type, **kw)

        event = self.wait()
        self.log_event(event)

        if pattern.match(event):
            self.log('handled')
            self.log('')
            return event

        self.log('not handled')
        raise RuntimeError('expected %r, got %r' % (pattern, event))


class IteratingEventQueue(BaseEventQueue):
    """Event queue that works by iterating the Twisted reactor."""

    def __init__(self, timeout=None):
        BaseEventQueue.__init__(self, timeout)
        self.events = []
        # list of (EventPattern, callback) pairs, see add_dbus_method_impl
        self._dbus_method_impls = []
        self._buses = []
        # a message filter which will claim we handled everything
        self._dbus_dev_null = \
            lambda bus, message: dbus.lowlevel.HANDLER_RESULT_HANDLED

    def wait(self):
        """Iterate the reactor until an event arrives or self.timeout
        seconds pass; return the oldest event or raise TimeoutError."""
        stop = [False]

        def later():
            stop[0] = True

        delayed_call = reactor.callLater(self.timeout, later)

        try:
            while (not self.events) and (not stop[0]):
                reactor.iterate(0.1)
        finally:
            # The timeout may already have fired (possibly in the same
            # iteration that delivered an event); cancelling a DelayedCall
            # that has run raises AlreadyCalled, so check first.
            if delayed_call.active():
                delayed_call.cancel()

        if self.events:
            return self.events.pop(0)
        else:
            raise TimeoutError

    def append(self, event):
        self.events.append(event)

    # compatibility
    handle_event = append

    def add_dbus_method_impl(self, cb, bus=None, **kwargs):
        """Call cb(event) for incoming method calls matching **kwargs
        (an EventPattern on 'dbus-method-call'); the event is then still
        queued, with handled=True."""
        if bus is None:
            # Not otherwise used; kept so that calling this before any bus
            # is attached fails (IndexError) as it always has.
            bus = self._buses[0]

        self._dbus_method_impls.append(
            (EventPattern('dbus-method-call', **kwargs), cb))

    def dbus_emit(self, path, iface, name, *a, **k):
        bus = k.pop('bus', self._buses[0])
        assert 'signature' in k, k
        message = dbus.lowlevel.SignalMessage(path, iface, name)
        message.append(*a, **k)
        bus.send_message(message)

    def dbus_return(self, in_reply_to, *a, **k):
        bus = k.pop('bus', self._buses[0])
        assert 'signature' in k, k
        reply = dbus.lowlevel.MethodReturnMessage(in_reply_to)
        reply.append(*a, **k)
        bus.send_message(reply)

    def dbus_raise(self, in_reply_to, name, message=None, bus=None):
        if bus is None:
            bus = self._buses[0]

        reply = dbus.lowlevel.ErrorMessage(in_reply_to, name, message)
        bus.send_message(reply)

    def attach_to_bus(self, bus):
        if not self._buses:
            # first-time setup
            self._dbus_filter_bound_method = self._dbus_filter

        self._buses.append(bus)

        # Only subscribe to messages on the first bus connection (assumed to
        # be the shared session bus connection used by the simulated
        # connection manager and most of the test suite), not on subsequent
        # bus connections (assumed to represent extra clients).
        #
        # When we receive a method call on the other bus connections, ignore
        # it - the eavesdropping filter installed on the first bus connection
        # will see it too.
        #
        # This is highly counter-intuitive, but it means our messages are in
        # a guaranteed order (we don't have races between messages arriving
        # on various connections).
        if len(self._buses) > 1:
            bus.add_message_filter(self._dbus_dev_null)
            return

        try:
            # for dbus > 1.5
            bus.add_match_string("eavesdrop=true,type='signal'")
        except dbus.DBusException:
            bus.add_match_string("type='signal'")
            bus.add_match_string("type='method_call'")
        else:
            bus.add_match_string("eavesdrop=true,type='method_call'")

        bus.add_message_filter(self._dbus_filter_bound_method)

        bus.add_signal_receiver(
                lambda *args, **kw:
                    self.append(
                        Event('dbus-signal',
                            path=unwrap(kw['path']),
                            signal=kw['member'],
                            args=list(map(unwrap, args)),
                            interface=kw['interface'])),
                None,
                None,
                None,
                path_keyword='path',
                member_keyword='member',
                interface_keyword='interface',
                byte_arrays=True,
                )

    def cleanup(self):
        if self._buses:
            self._buses[0].remove_message_filter(
                self._dbus_filter_bound_method)
        for bus in self._buses[1:]:
            bus.remove_message_filter(self._dbus_dev_null)

        self._buses = []
        self._dbus_method_impls = []

    def _dbus_filter(self, bus, message):
        """Turn observed method calls into 'dbus-method-call' events."""
        if isinstance(message, dbus.lowlevel.MethodCallMessage):

            destination = message.get_destination()
            sender = message.get_sender()

            if (destination == 'org.freedesktop.DBus' or
                    sender == self._buses[0].get_unique_name()):
                # suppress reply and don't make an Event
                return dbus.lowlevel.HANDLER_RESULT_HANDLED

            # Decode the arguments once; 'args' is an unwrapped copy.
            raw_args = message.get_args_list(byte_arrays=True)
            e = Event('dbus-method-call', message=message,
                interface=message.get_interface(), path=message.get_path(),
                raw_args=raw_args,
                args=list(map(unwrap, raw_args)),
                destination=str(destination),
                method=message.get_member(),
                sender=sender,
                handled=False)

            for pattern, cb in self._dbus_method_impls:
                if pattern.match(e):
                    cb(e)
                    e.handled = True
                    break

            self.append(e)

            return dbus.lowlevel.HANDLER_RESULT_HANDLED

        return dbus.lowlevel.HANDLER_RESULT_NOT_YET_HANDLED


class TestEventQueue(BaseEventQueue):
    """Event queue fed from a fixed list, for self-tests."""

    def __init__(self, events):
        BaseEventQueue.__init__(self)
        self.events = events

    def wait(self):
        if self.events:
            return self.events.pop(0)
        else:
            raise TimeoutError


class EventQueueTest(unittest.TestCase):
    def test_expect(self):
        queue = TestEventQueue([Event('foo'), Event('bar')])
        assert queue.expect('foo').type == 'foo'
        assert queue.expect('bar').type == 'bar'

    def test_expect_many(self):
        queue = TestEventQueue([Event('foo'), Event('bar')])
        bar, foo = queue.expect_many(
            EventPattern('bar'),
            EventPattern('foo'))
        assert bar.type == 'bar'
        assert foo.type == 'foo'

    def test_expect_many2(self):
        # Test that events are only matched against patterns that haven't yet
        # been matched. This tests a regression.
        queue = TestEventQueue([Event('foo', x=1), Event('foo', x=2)])
        foo1, foo2 = queue.expect_many(
            EventPattern('foo'),
            EventPattern('foo'))
        assert foo1.type == 'foo' and foo1.x == 1
        assert foo2.type == 'foo' and foo2.x == 2

    def test_timeout(self):
        queue = TestEventQueue([])
        self.assertRaises(TimeoutError, queue.expect, 'foo')

    def test_demand(self):
        queue = TestEventQueue([Event('foo'), Event('bar')])
        foo = queue.demand('foo')
        assert foo.type == 'foo'

    def test_demand_fail(self):
        queue = TestEventQueue([Event('foo'), Event('bar')])
        self.assertRaises(RuntimeError, queue.demand, 'bar')


def unwrap(x):
    """Hack to unwrap D-Bus values, so that they're easier to read when
    printed.

    Containers are converted recursively; dbus.Boolean becomes bool and
    other str/bytes/int/float subclasses become the plain builtin type.
    Anything else is returned unchanged.
    """

    if isinstance(x, list):
        return list(map(unwrap, x))

    if isinstance(x, tuple):
        return tuple(map(unwrap, x))

    if isinstance(x, dict):
        return {unwrap(k): unwrap(v) for k, v in x.items()}

    # Must precede the int check: dbus.Boolean subclasses int.
    if isinstance(x, dbus.Boolean):
        return bool(x)

    # bytes covers dbus.ByteArray (a bytes subclass under Python 3).
    for t in (str, bytes, int, float):
        if isinstance(x, t):
            return t(x)

    return x


def call_async(test, proxy, method, *args, **kw):
    """Call a D-Bus method asynchronously and generate an event for the
    resulting method return/error."""

    def reply_func(*ret):
        test.handle_event(Event('dbus-return', method=method,
            value=unwrap(ret)))

    def error_func(err):
        test.handle_event(Event('dbus-error', method=method, error=err,
            name=err.get_dbus_name(), message=str(err)))

    method_proxy = getattr(proxy, method)
    kw.update({'reply_handler': reply_func, 'error_handler': error_func})
    method_proxy(*args, **kw)


def sync_dbus(bus, q, proxy):
    # Dummy D-Bus method call. We can't use DBus.Peer.Ping() because libdbus
    # replies to that message immediately, rather than handing it up to
    # dbus-glib and thence the application, which means that Ping()ing the
    # application doesn't ensure that it's processed all D-Bus messages prior
    # to our ping.
    call_async(q, dbus.Interface(proxy, 'org.freedesktop.Telepathy.Tests'),
        'DummySyncDBus')
    q.expect('dbus-error', method='DummySyncDBus')


class ProxyWrapper:
    """Wrap a D-Bus proxy so its interfaces are reachable as attributes.

    ``wrapper.<name>`` resolves to the interface registered under <name> in
    *others*, then to an attribute of the proxy object itself, and finally
    to the *default* interface.
    """

    def __init__(self, object, default, others):
        self.object = object
        self.default_interface = dbus.Interface(object, default)
        self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE)
        self.TpProperties = \
            dbus.Interface(object, tp_name_prefix + '.Properties')
        self.interfaces = {
            name: dbus.Interface(object, iface)
            for name, iface in others.items()}

    def __getattr__(self, name):
        if name in self.interfaces:
            return self.interfaces[name]

        if name in self.object.__dict__:
            return getattr(self.object, name)

        return getattr(self.default_interface, name)


def wrap_channel(chan, type_, extra=None):
    """Wrap a Channel proxy, exposing its Type.<type_> and Group interfaces
    plus Channel.Interface.<name> for each name in *extra*."""
    interfaces = {
        type_: tp_name_prefix + '.Channel.Type.' + type_,
        'Group': tp_name_prefix + '.Channel.Interface.Group',
        }

    if extra:
        interfaces.update({
            name: tp_name_prefix + '.Channel.Interface.' + name
            for name in extra})

    return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces)


def wrap_connection(conn, extra=None):
    """Wrap a Connection proxy in a ProxyWrapper.

    Unknown attributes fall back to the Connection interface; each name in
    *extra* is exposed as Connection.Interface.<name>.
    """
    interfaces = {}

    if extra:
        interfaces.update({
            name: tp_name_prefix + '.Connection.Interface.' + name
            for name in extra})

    return ProxyWrapper(conn, tp_name_prefix + '.Connection', interfaces)


def make_connection(bus, event_func, name, proto, params):
    """Ask connection manager *name* for a connection and wrap its proxy.

    *event_func* is accepted for compatibility but not used.
    """
    cm = bus.get_object(
        tp_name_prefix + '.ConnectionManager.%s' % name,
        tp_path_prefix + '/ConnectionManager/%s' % name)
    cm_iface = dbus.Interface(cm, tp_name_prefix + '.ConnectionManager')

    connection_name, connection_path = cm_iface.RequestConnection(
        proto, params)
    conn = wrap_connection(bus.get_object(connection_name, connection_path))

    return conn


def make_channel_proxy(conn, path, iface):
    bus = dbus.SessionBus()
    chan = bus.get_object(conn.object.bus_name, path)
    chan = dbus.Interface(chan, tp_name_prefix + '.' + iface)
    return chan


# block_reading can be used if the test want to choose when we start to read
# data from the socket.
class EventProtocol(Protocol):
    """Twisted protocol that reports socket activity as queue events."""

    def __init__(self, queue=None, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def dataReceived(self, data):
        if self.queue is not None:
            self.queue.handle_event(Event('socket-data', protocol=self,
                data=data))

    def sendData(self, data):
        self.transport.write(data)

    def connectionMade(self):
        if self.block_reading:
            self.transport.stopReading()

    def connectionLost(self, reason=None):
        if self.queue is not None:
            self.queue.handle_event(Event('socket-disconnected', protocol=self))


class EventProtocolFactory(Factory):
    """Builds EventProtocols and emits a 'socket-connected' event for each."""

    def __init__(self, queue, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def _create_protocol(self):
        return EventProtocol(self.queue, self.block_reading)

    def buildProtocol(self, addr):
        proto = self._create_protocol()
        self.queue.handle_event(Event('socket-connected', protocol=proto))
        return proto


class EventProtocolClientFactory(EventProtocolFactory, ClientFactory):
    pass
def watch_tube_signals(q, tube):
    """Forward every signal seen on *tube* to queue *q* as a
    'tube-signal' event."""
    def got_signal_cb(*args, **kwargs):
        q.handle_event(Event('tube-signal',
            path=kwargs['path'],
            signal=kwargs['member'],
            args=list(map(unwrap, args)),
            tube=tube))

    tube.add_signal_receiver(got_signal_cb,
        path_keyword='path', member_keyword='member',
        byte_arrays=True)


def pretty(x):
    """Pretty-print *x* after unwrapping any D-Bus types."""
    return pprint.pformat(unwrap(x))


def assertEquals(expected, value):
    if expected != value:
        raise AssertionError(
            "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value)))


def assertNotEquals(expected, value):
    if expected == value:
        raise AssertionError(
            "expected something other than:\n%s" % pretty(value))


def assertContains(element, value):
    if element not in value:
        raise AssertionError(
            "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value)))


def assertDoesNotContain(element, value):
    if element in value:
        raise AssertionError(
            "expected:\n%s\nnot in:\n%s" % (pretty(element), pretty(value)))


def assertLength(length, value):
    if len(value) != length:
        raise AssertionError("expected: length %d, got length %d:\n%s" % (
            length, len(value), pretty(value)))


def assertFlagsSet(flags, value):
    """Fail unless every bit of *flags* is set in *value*."""
    masked = value & flags
    if masked != flags:
        raise AssertionError(
            "expected flags %u, of which only %u are set in %u" % (
            flags, masked, value))


def assertFlagsUnset(flags, value):
    """Fail if any bit of *flags* is set in *value*."""
    masked = value & flags
    if masked != 0:
        raise AssertionError(
            "expected none of flags %u, but %u are set in %u" % (
            flags, masked, value))


def assertSameSets(expected, value):
    """Fail unless *expected* and *value* contain the same elements,
    ignoring order and duplicates."""
    exp_set = set(expected)
    val_set = set(value)

    if exp_set != val_set:
        raise AssertionError(
            "expected contents:\n%s\ngot:\n%s" % (
                pretty(exp_set), pretty(val_set)))


def install_colourer():
    """Replace sys.stdout with a wrapper that colours the exact strings
    'handled' (green) and 'not handled' (red); returns the wrapper."""
    def red(s):
        return '\x1b[31m%s\x1b[0m' % s

    def green(s):
        return '\x1b[32m%s\x1b[0m' % s

    patterns = {
        'handled': green,
        'not handled': red,
        }

    class Colourer:
        def __init__(self, fh, patterns):
            self.fh = fh
            self.patterns = patterns

        def write(self, s):
            # print() writes its argument and the line ending in separate
            # calls, so whole log messages arrive here unsplit.
            f = self.patterns.get(s, lambda x: x)
            return self.fh.write(f(s))

        def __getattr__(self, name):
            # Delegate flush(), encoding, fileno(), ... to the wrapped stream
            # so this is a drop-in sys.stdout (Python flushes it at exit).
            # Guard our own attributes to avoid infinite recursion if they
            # are looked up before __init__ has set them.
            if name in ('fh', 'patterns'):
                raise AttributeError(name)
            return getattr(self.fh, name)

    sys.stdout = Colourer(sys.stdout, patterns)
    return sys.stdout


if __name__ == '__main__':
    unittest.main()