1# Copyright (C) 2009 Nokia Corporation
2# Copyright (C) 2009 Collabora Ltd.
3#
4# This library is free software; you can redistribute it and/or
5# modify it under the terms of the GNU Lesser General Public
6# License as published by the Free Software Foundation; either
7# version 2.1 of the License, or (at your option) any later version.
8#
9# This library is distributed in the hope that it will be useful, but
10# WITHOUT ANY WARRANTY; without even the implied warranty of
11# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12# Lesser General Public License for more details.
13#
14# You should have received a copy of the GNU Lesser General Public
15# License along with this library; if not, write to the Free Software
16# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
17# 02110-1301 USA
18
19"""
20Infrastructure code for testing Mission Control
21"""
22
23from twisted.internet import gireactor
24from twisted.internet.protocol import Protocol, Factory, ClientFactory
25gireactor.install()
26import sys
27
28import pprint
29import unittest
30
31import dbus
32import dbus.lowlevel
33import dbus.glib
34
35from twisted.internet import reactor
36
# Common prefixes for Telepathy D-Bus well-known names and object paths.
# The path prefix is simply the name prefix with '.' turned into '/'.
tp_name_prefix = 'org.freedesktop.Telepathy'
tp_path_prefix = '/' + tp_name_prefix.replace('.', '/')
39
class Event:
    """A single queued event: a ``type`` plus arbitrary keyword attributes."""

    def __init__(self, type, **kw):
        # Every extra keyword becomes an instance attribute.
        self.__dict__.update(kw)
        self.type = type

    def __str__(self):
        lines = [str(type(self))]
        lines.extend(format_event(self))
        return '\n'.join(lines)
47
def format_event(event):
    """Return a list of human-readable lines describing ``event``.

    The first line gives the event type; then every public attribute
    (name not starting with '_', excluding 'type') follows, in dir()
    order, pretty-printed. An 'error' attribute additionally gets its
    str() form on a line of its own.
    """
    lines = ['- type %s' % event.type]

    public_keys = [k for k in dir(event)
                   if k != 'type' and not k.startswith('_')]

    for key in public_keys:
        value = getattr(event, key)
        lines.append('- %s: %s' % (key, pprint.pformat(value)))
        if key == 'error':
            lines.append('%s' % value)

    return lines
60
class EventPattern:
    """A template that queued events are matched against.

    An event matches when its ``type`` equals the pattern's type, every
    keyword property given to the constructor exists on the event with an
    equal value, and the optional ``predicate`` callable (passed as the
    ``predicate=`` keyword) returns a true value for the event.
    """

    @staticmethod
    def _match_anything(event):
        # Default predicate, used when the caller supplies none.
        return True

    def __init__(self, type, **properties):
        self.type = type
        # 'predicate' is not an event property; pull it out of the dict.
        # predicate=None is treated as "no predicate" instead of being
        # stored and crashing later in match().
        predicate = properties.pop('predicate', None)
        if predicate is None:
            predicate = self._match_anything
        self.predicate = predicate
        self.properties = properties

    def __repr__(self):
        properties = dict(self.properties)

        # Only show user-supplied predicates; the default one carries no
        # information and would clutter every log line.
        if self.predicate is not self._match_anything:
            properties['predicate'] = self.predicate

        return '%s(%r, **%r)' % (
            self.__class__.__name__, self.type, properties)

    def match(self, event):
        """Return True if ``event`` satisfies this pattern, else False."""
        if event.type != self.type:
            return False

        missing = object()
        for key, value in self.properties.items():
            actual = getattr(event, key, missing)
            if actual is missing or actual != value:
                return False

        return bool(self.predicate(event))
94
95
class TimeoutError(Exception):
    """Raised by an event queue's wait() when no event arrives in time.

    NOTE: this intentionally shadows the builtin TimeoutError inside this
    module; the ``except TimeoutError`` clauses in this module refer to
    this class.
    """
98
class ForbiddenEventOccurred(Exception):
    """Raised when a received event matches one of the forbidden patterns.

    The offending event is kept in ``self.event``; str() renders it with
    format_event(), one attribute per line.
    """

    def __init__(self, event):
        super().__init__()
        self.event = event

    def __str__(self):
        body = '\n'.join(format_event(self.event))
        return '\n' + body
106
class BaseEventQueue:
    """Abstract event queue base class.

    Implement the wait() method to have something that works: it must
    return the next event or raise TimeoutError.

    Events that arrive while expect()/expect_many() is looking for
    something else are kept in ``past_events`` so that expect_racy() can
    still find them later.
    """

    def __init__(self, timeout=None):
        self.verbose = False
        self.past_events = []
        self.forbidden_events = set()
        # Defaults to 5 when not given; how it is applied is up to the
        # subclass's wait().
        self.timeout = 5 if timeout is None else timeout

    def log(self, s):
        """Print ``s`` when the queue is in verbose mode."""
        if self.verbose:
            print(s)

    def log_event(self, event):
        """Log a freshly received event (verbose mode only)."""
        if self.verbose:
            self.log('got event:')
            for line in format_event(event):
                self.log(line)

    def flush_past_events(self):
        """Drop every event that was received but not matched."""
        self.past_events = []

    def expect_racy(self, type, **kw):
        """Like expect(), but first look among past (unmatched) events.

        A matching past event is removed from ``past_events`` and returned;
        otherwise this falls back to expect().
        """
        pattern = EventPattern(type, **kw)

        for index, event in enumerate(self.past_events):
            if pattern.match(event):
                self.log('past event handled')
                for line in format_event(event):
                    self.log(line)
                self.log('')
                # Delete by position: this is exactly the element that
                # matched, regardless of how events compare for equality.
                del self.past_events[index]
                return event

        return self.expect(type, **kw)

    def forbid_events(self, patterns):
        """
        Add patterns (an iterable of EventPattern) to the set of forbidden
        events. If a forbidden event occurs during an expect or expect_many,
        the test will fail.
        """
        self.forbidden_events.update(patterns)

    def unforbid_events(self, patterns):
        """
        Remove 'patterns' (an iterable of EventPattern) from the set of
        forbidden events. These must be the same EventPattern pointers that
        were passed to forbid_events.
        """
        self.forbidden_events.difference_update(patterns)

    def _check_forbidden(self, event):
        # Raise ForbiddenEventOccurred if any forbidden pattern matches.
        for e in self.forbidden_events:
            if e.match(event):
                raise ForbiddenEventOccurred(event)

    def expect(self, type, **kw):
        """Wait for an event matching (type, **kw) and return it.

        Non-matching events are appended to ``past_events``. Raises
        TimeoutError (after logging the outstanding pattern) if wait()
        times out, and ForbiddenEventOccurred if a forbidden event arrives.
        """
        pattern = EventPattern(type, **kw)

        while True:
            try:
                event = self.wait()
            except TimeoutError:
                # Same diagnostics as expect_many() gives on timeout.
                self.log('timeout')
                self.log('still expecting:')
                self.log(' - %r' % pattern)
                raise
            self.log_event(event)
            self._check_forbidden(event)

            if pattern.match(event):
                self.log('handled')
                self.log('')
                return event

            self.past_events.append(event)
            self.log('not handled')
            self.log('')

    def expect_many(self, *patterns):
        """Wait until every pattern has matched a distinct event.

        Returns the matched events in the same order as ``patterns``; each
        event is matched against the first still-unmatched pattern it fits.
        Non-matching events go to ``past_events``.
        """
        ret = [None] * len(patterns)

        while None in ret:
            try:
                event = self.wait()
            except TimeoutError:
                self.log('timeout')
                self.log('still expecting:')
                for i, pattern in enumerate(patterns):
                    if ret[i] is None:
                        self.log(' - %r' % pattern)
                raise
            self.log_event(event)
            self._check_forbidden(event)

            for i, pattern in enumerate(patterns):
                if ret[i] is None and pattern.match(event):
                    self.log('handled')
                    self.log('')
                    ret[i] = event
                    break
            else:
                self.past_events.append(event)
                self.log('not handled')
                self.log('')

        return ret

    def demand(self, type, **kw):
        """Require the very next event to match (type, **kw).

        Raises RuntimeError if it does not. Forbidden patterns are not
        consulted here.
        """
        pattern = EventPattern(type, **kw)

        event = self.wait()
        self.log_event(event)

        if pattern.match(event):
            self.log('handled')
            self.log('')
            return event

        self.log('not handled')
        raise RuntimeError('expected %r, got %r' % (pattern, event))
230
class IteratingEventQueue(BaseEventQueue):
    """Event queue that works by iterating the Twisted reactor.

    Producers (socket protocols, the D-Bus signal receiver and the D-Bus
    message filter installed by attach_to_bus()) call append() /
    handle_event(); wait() spins the reactor until something is queued.
    """

    def __init__(self, timeout=None):
        BaseEventQueue.__init__(self, timeout)
        self.events = []
        # (EventPattern, callback) pairs consulted by _dbus_filter().
        self._dbus_method_impls = []
        self._buses = []
        # a message filter which will claim we handled everything
        self._dbus_dev_null = \
                lambda bus, message: dbus.lowlevel.HANDLER_RESULT_HANDLED

    def wait(self):
        """Return the oldest queued event, iterating the reactor until one
        arrives. Raises TimeoutError after self.timeout seconds without one.
        """
        stop = [False]

        def later():
            stop[0] = True

        delayed_call = reactor.callLater(self.timeout, later)

        try:
            while (not self.events) and (not stop[0]):
                reactor.iterate(0.1)
        finally:
            # The timeout can fire in the same reactor iteration that
            # delivers an event; cancelling an already-called DelayedCall
            # raises AlreadyCalled, so only cancel a still-pending one.
            # Doing it in 'finally' also avoids leaking the call if
            # iterate() raises.
            if delayed_call.active():
                delayed_call.cancel()

        if self.events:
            return self.events.pop(0)
        raise TimeoutError

    def append(self, event):
        self.events.append(event)

    # compatibility
    handle_event = append

    def _bus_or_default(self, bus):
        # Resolve an optional bus argument, defaulting to the first bus.
        return self._buses[0] if bus is None else bus

    def add_dbus_method_impl(self, cb, bus=None, **kwargs):
        """Call cb(event) for D-Bus method calls matching **kwargs.

        kwargs are EventPattern properties checked against the
        'dbus-method-call' event built by _dbus_filter() (e.g. method=,
        interface=, path=). After cb runs, the event is still queued, with
        handled=True.

        ``bus`` is accepted for backwards compatibility but not used:
        method calls are only observed by the filter on the first bus.
        """
        self._dbus_method_impls.append(
                (EventPattern('dbus-method-call', **kwargs), cb))

    def dbus_emit(self, path, iface, name, *a, **k):
        """Emit signal iface.name from object path ``path``.

        Positional args are the signal arguments; the 'signature' keyword is
        required. An optional 'bus' keyword picks the connection (default:
        the first attached bus).
        """
        bus = self._bus_or_default(k.pop('bus', None))
        assert 'signature' in k, k
        message = dbus.lowlevel.SignalMessage(path, iface, name)
        message.append(*a, **k)
        bus.send_message(message)

    def dbus_return(self, in_reply_to, *a, **k):
        """Send a method return for ``in_reply_to`` carrying ``*a``.

        The 'signature' keyword is required; 'bus' optionally picks the
        connection (default: the first attached bus).
        """
        bus = self._bus_or_default(k.pop('bus', None))
        assert 'signature' in k, k
        reply = dbus.lowlevel.MethodReturnMessage(in_reply_to)
        reply.append(*a, **k)
        bus.send_message(reply)

    def dbus_raise(self, in_reply_to, name, message=None, bus=None):
        """Send D-Bus error ``name`` (with optional text) for in_reply_to."""
        bus = self._bus_or_default(bus)
        reply = dbus.lowlevel.ErrorMessage(in_reply_to, name, message)
        bus.send_message(reply)

    def attach_to_bus(self, bus):
        if not self._buses:
            # first-time setup
            self._dbus_filter_bound_method = self._dbus_filter

        self._buses.append(bus)

        # Only subscribe to messages on the first bus connection (assumed to
        # be the shared session bus connection used by the simulated connection
        # manager and most of the test suite), not on subsequent bus
        # connections (assumed to represent extra clients).
        #
        # When we receive a method call on the other bus connections, ignore
        # it - the eavesdropping filter installed on the first bus connection
        # will see it too.
        #
        # This is highly counter-intuitive, but it means our messages are in
        # a guaranteed order (we don't have races between messages arriving on
        # various connections).
        if len(self._buses) > 1:
            bus.add_message_filter(self._dbus_dev_null)
            return

        try:
            # for dbus > 1.5
            bus.add_match_string("eavesdrop=true,type='signal'")
        except dbus.DBusException:
            bus.add_match_string("type='signal'")
            bus.add_match_string("type='method_call'")
        else:
            bus.add_match_string("eavesdrop=true,type='method_call'")

        bus.add_message_filter(self._dbus_filter_bound_method)

        bus.add_signal_receiver(
                lambda *args, **kw:
                    self.append(
                        Event('dbus-signal',
                            path=unwrap(kw['path']),
                            signal=kw['member'],
                            args=list(map(unwrap, args)),
                            interface=kw['interface'])),
                None,
                None,
                None,
                path_keyword='path',
                member_keyword='member',
                interface_keyword='interface',
                byte_arrays=True,
                )

    def cleanup(self):
        """Remove the message filters installed by attach_to_bus()."""
        if self._buses:
            self._buses[0].remove_message_filter(self._dbus_filter_bound_method)
        for bus in self._buses[1:]:
            bus.remove_message_filter(self._dbus_dev_null)

        self._buses = []
        self._dbus_method_impls = []

    def _dbus_filter(self, bus, message):
        # Turn incoming method calls into 'dbus-method-call' events; let
        # every other message type through untouched.
        if not isinstance(message, dbus.lowlevel.MethodCallMessage):
            return dbus.lowlevel.HANDLER_RESULT_NOT_YET_HANDLED

        destination = message.get_destination()
        sender = message.get_sender()

        if (destination == 'org.freedesktop.DBus' or
                sender == self._buses[0].get_unique_name()):
            # suppress reply and don't make an Event
            return dbus.lowlevel.HANDLER_RESULT_HANDLED

        # Decode the arguments once; 'args' is the unwrapped copy.
        raw_args = message.get_args_list(byte_arrays=True)
        e = Event('dbus-method-call', message=message,
            interface=message.get_interface(), path=message.get_path(),
            raw_args=raw_args,
            args=list(map(unwrap, raw_args)),
            destination=str(destination),
            method=message.get_member(),
            sender=sender,
            handled=False)

        # At most one registered implementation runs per call.
        for pattern, cb in self._dbus_method_impls:
            if pattern.match(e):
                cb(e)
                e.handled = True
                break

        self.append(e)

        return dbus.lowlevel.HANDLER_RESULT_HANDLED
386
class TestEventQueue(BaseEventQueue):
    """Event queue fed from a pre-built list; used by the self-tests.

    wait() hands out the list's events front-to-back and raises
    TimeoutError once it is empty.
    """

    def __init__(self, events):
        BaseEventQueue.__init__(self)
        self.events = events

    def wait(self):
        if not self.events:
            raise TimeoutError
        return self.events.pop(0)
397
class EventQueueTest(unittest.TestCase):
    """Self-tests for the BaseEventQueue matching logic."""

    def test_expect(self):
        q = TestEventQueue([Event('foo'), Event('bar')])
        self.assertEqual(q.expect('foo').type, 'foo')
        self.assertEqual(q.expect('bar').type, 'bar')

    def test_expect_many(self):
        q = TestEventQueue([Event('foo'), Event('bar')])
        bar, foo = q.expect_many(EventPattern('bar'), EventPattern('foo'))
        self.assertEqual(bar.type, 'bar')
        self.assertEqual(foo.type, 'foo')

    def test_expect_many2(self):
        # Regression test: an event may only be matched against patterns
        # that have not already been satisfied.
        q = TestEventQueue([Event('foo', x=1), Event('foo', x=2)])
        first, second = q.expect_many(EventPattern('foo'),
                                      EventPattern('foo'))
        self.assertEqual((first.type, first.x), ('foo', 1))
        self.assertEqual((second.type, second.x), ('foo', 2))

    def test_timeout(self):
        q = TestEventQueue([])
        with self.assertRaises(TimeoutError):
            q.expect('foo')

    def test_demand(self):
        q = TestEventQueue([Event('foo'), Event('bar')])
        self.assertEqual(q.demand('foo').type, 'foo')

    def test_demand_fail(self):
        q = TestEventQueue([Event('foo'), Event('bar')])
        with self.assertRaises(RuntimeError):
            q.demand('bar')
434
def unwrap(x):
    """Hack to unwrap D-Bus values, so that they're easier to read when
    printed.

    Lists, tuples and dicts (including subclasses) are rebuilt recursively
    as plain list/tuple/dict. Booleans (Python or dbus.Boolean) become bool.
    Instances of str, int or float subclasses are converted to the plain
    base type. Anything else is returned unchanged.
    """
    if isinstance(x, list):
        return [unwrap(item) for item in x]

    if isinstance(x, tuple):
        return tuple(unwrap(item) for item in x)

    if isinstance(x, dict):
        return {unwrap(k): unwrap(v) for k, v in x.items()}

    # Must come before the int check: bool is an int subclass, so a plain
    # True would otherwise be turned into 1.
    if isinstance(x, (bool, dbus.Boolean)):
        return bool(x)

    for t in (str, int, float):
        if isinstance(x, t):
            return t(x)

    return x
456
def call_async(test, proxy, method, *args, **kw):
    """Call a D-Bus method asynchronously and generate an event for the
    resulting method return/error.

    On success a 'dbus-return' event (value = unwrapped reply tuple) is
    passed to test.handle_event(); on failure a 'dbus-error' event carrying
    the exception, its D-Bus error name and its message.
    """
    bound = getattr(proxy, method)

    def on_reply(*ret):
        test.handle_event(
            Event('dbus-return', method=method, value=unwrap(ret)))

    def on_error(err):
        test.handle_event(
            Event('dbus-error', method=method, error=err,
                  name=err.get_dbus_name(), message=str(err)))

    kw['reply_handler'] = on_reply
    kw['error_handler'] = on_error
    bound(*args, **kw)
472
def sync_dbus(bus, q, proxy):
    """Round-trip a dummy call to ``proxy`` and wait for its error reply.

    ``bus`` is accepted but not used.
    """
    # Dummy D-Bus method call. We can't use DBus.Peer.Ping() because libdbus
    # replies to that message immediately, rather than handing it up to
    # dbus-glib and thence the application, which means that Ping()ing the
    # application doesn't ensure that it's processed all D-Bus messages prior
    # to our ping.
    tests_iface = dbus.Interface(proxy, 'org.freedesktop.Telepathy.Tests')
    call_async(q, tests_iface, 'DummySyncDBus')
    q.expect('dbus-error', method='DummySyncDBus')
482
class ProxyWrapper:
    """Convenience wrapper around a D-Bus proxy object.

    Besides the explicit ``object``, ``default_interface``, ``Properties``
    and ``TpProperties`` attributes, unknown attribute lookups are resolved
    in this order:
      1. the named interfaces passed in ``others`` (name -> interface);
      2. attributes stored in the wrapped object's ``__dict__``;
      3. the ``default`` interface.
    """

    def __init__(self, object, default, others):
        self.object = object
        self.default_interface = dbus.Interface(object, default)
        self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE)
        self.TpProperties = \
            dbus.Interface(object, tp_name_prefix + '.Properties')
        self.interfaces = {
            name: dbus.Interface(object, iface)
            for name, iface in others.items()}

    def __getattr__(self, name):
        # Read our own state straight from __dict__: if __init__ never ran
        # (copy/pickle, __new__), going through self.<attr> would re-enter
        # __getattr__ and recurse until RecursionError.
        state = self.__dict__
        try:
            interfaces = state['interfaces']
            wrapped = state['object']
            default_interface = state['default_interface']
        except KeyError:
            raise AttributeError(name) from None

        if name in interfaces:
            return interfaces[name]

        if name in wrapped.__dict__:
            return getattr(wrapped, name)

        return getattr(default_interface, name)
502
def wrap_channel(chan, type_, extra=None):
    """Wrap a channel proxy in a ProxyWrapper.

    The wrapper defaults to the Channel interface and exposes
    Channel.Type.<type_> as attribute ``type_``, Channel.Interface.Group as
    ``Group``, and Channel.Interface.<name> for each name in ``extra``.
    """
    chan_prefix = tp_name_prefix + '.Channel'

    interfaces = {
        type_: chan_prefix + '.Type.' + type_,
        'Group': chan_prefix + '.Interface.Group',
    }
    for name in extra or ():
        interfaces[name] = chan_prefix + '.Interface.' + name

    return ProxyWrapper(chan, chan_prefix, interfaces)
515
def make_connection(bus, event_func, name, proto, params):
    """Ask connection manager ``name`` for a new ``proto`` connection.

    Calls RequestConnection(proto, params) on the manager and returns the
    resulting connection object wrapped in a ProxyWrapper whose default
    interface is Connection. ``event_func`` is accepted but not used.
    """
    cm = bus.get_object(
        tp_name_prefix + '.ConnectionManager.%s' % name,
        tp_path_prefix + '/ConnectionManager/%s' % name)
    cm_iface = dbus.Interface(cm, tp_name_prefix + '.ConnectionManager')

    connection_name, connection_path = cm_iface.RequestConnection(
        proto, params)

    # This previously called wrap_connection(), which is not defined in
    # this module, so every call failed with NameError. Wrap the proxy
    # directly instead, with Connection as the default interface.
    conn = ProxyWrapper(bus.get_object(connection_name, connection_path),
                        tp_name_prefix + '.Connection', {})

    return conn
527
def make_channel_proxy(conn, path, iface):
    """Return a dbus.Interface for the channel at ``path`` on ``conn``.

    The channel object is looked up on the session bus under the bus name
    of ``conn.object``; ``iface`` is appended to the Telepathy name prefix.
    """
    session = dbus.SessionBus()
    proxy = session.get_object(conn.object.bus_name, path)
    return dbus.Interface(proxy, '%s.%s' % (tp_name_prefix, iface))
533
# block_reading can be set if the test wants to choose when we start to read
# data from the socket.
class EventProtocol(Protocol):
    """Twisted protocol that reports socket activity to an event queue.

    Incoming data becomes a 'socket-data' event and a lost connection a
    'socket-disconnected' event (both only when a queue was given). With
    block_reading set, the transport stops reading as soon as the
    connection is made.
    """

    def __init__(self, queue=None, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def dataReceived(self, data):
        if self.queue is None:
            return
        self.queue.handle_event(
            Event('socket-data', protocol=self, data=data))

    def sendData(self, data):
        self.transport.write(data)

    def connectionMade(self):
        if not self.block_reading:
            return
        self.transport.stopReading()

    def connectionLost(self, reason=None):
        if self.queue is None:
            return
        self.queue.handle_event(
            Event('socket-disconnected', protocol=self))
556
class EventProtocolFactory(Factory):
    """Twisted factory building EventProtocol instances for ``queue``.

    Each newly built protocol is announced with a 'socket-connected' event.
    """

    def __init__(self, queue, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def _create_protocol(self):
        return EventProtocol(queue=self.queue,
                             block_reading=self.block_reading)

    def buildProtocol(self, addr):
        protocol = self._create_protocol()
        connected = Event('socket-connected', protocol=protocol)
        self.queue.handle_event(connected)
        return protocol
569
class EventProtocolClientFactory(EventProtocolFactory, ClientFactory):
    """EventProtocolFactory that is also a Twisted ClientFactory."""
572
def watch_tube_signals(q, tube):
    """Queue a 'tube-signal' event on ``q`` for every signal seen on tube.

    Each event records the object path, the signal (member) name, the
    unwrapped arguments and the ``tube`` itself.
    """
    def on_signal(*args, **kwargs):
        event = Event('tube-signal',
                      path=kwargs['path'],
                      signal=kwargs['member'],
                      args=[unwrap(arg) for arg in args],
                      tube=tube)
        q.handle_event(event)

    tube.add_signal_receiver(on_signal, path_keyword='path',
                             member_keyword='member', byte_arrays=True)
584
def pretty(x):
    """Return a pprint-formatted string for ``x`` after unwrap()."""
    plain = unwrap(x)
    return pprint.pformat(plain)
587
def assertEquals(expected, value):
    """Raise AssertionError, showing both values, if they differ."""
    mismatch = expected != value
    if mismatch:
        message = "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value))
        raise AssertionError(message)
592
def assertNotEquals(expected, value):
    """Raise AssertionError if ``value`` equals ``expected``."""
    same = expected == value
    if same:
        message = "expected something other than:\n%s" % pretty(value)
        raise AssertionError(message)
597
def assertContains(element, value):
    """Raise AssertionError unless ``element in value``."""
    if element in value:
        return
    raise AssertionError(
        "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value)))
602
def assertDoesNotContain(element, value):
    """Raise AssertionError if ``element in value``."""
    if element not in value:
        return
    raise AssertionError(
        "expected:\n%s\nnot in:\n%s" % (pretty(element), pretty(value)))
607
def assertLength(length, value):
    """Raise AssertionError unless ``len(value) == length``."""
    actual = len(value)
    if actual == length:
        return
    raise AssertionError("expected: length %d, got length %d:\n%s" % (
        length, actual, pretty(value)))
612
def assertFlagsSet(flags, value):
    """Raise AssertionError unless every bit of ``flags`` is set in value."""
    present = value & flags
    if present == flags:
        return
    raise AssertionError(
        "expected flags %u, of which only %u are set in %u"
        % (flags, present, value))
619
def assertFlagsUnset(flags, value):
    """Raise AssertionError if any bit of ``flags`` is set in ``value``."""
    overlap = value & flags
    if not overlap:
        return
    raise AssertionError(
        "expected none of flags %u, but %u are set in %u"
        % (flags, overlap, value))
626
def assertSameSets(expected, value):
    """Raise AssertionError unless both iterables hold the same elements
    (duplicates and order ignored)."""
    exp_set, val_set = set(expected), set(value)
    if exp_set == val_set:
        return
    raise AssertionError("expected contents:\n%s\ngot:\n%s" % (
        pretty(exp_set), pretty(val_set)))
635
636
def install_colourer():
    """Replace sys.stdout with a wrapper that colours queue log output.

    A write of exactly 'handled' is wrapped in green ANSI escapes and a
    write of exactly 'not handled' in red; every other write passes through
    unchanged. Returns the wrapper, which is also the new sys.stdout.
    """
    def red(s):
        return '\x1b[31m%s\x1b[0m' % s

    def green(s):
        return '\x1b[32m%s\x1b[0m' % s

    patterns = {
        'handled': green,
        'not handled': red,
        }

    class Colourer:
        def __init__(self, fh, patterns):
            self.fh = fh
            self.patterns = patterns

        def write(self, s):
            # Only whole-string writes are matched; print() hands its
            # argument and its line terminator to write() separately.
            f = self.patterns.get(s, lambda x: x)
            return self.fh.write(f(s))

        def flush(self):
            # Required of any sys.stdout replacement: the interpreter
            # flushes sys.stdout at exit, as does print(..., flush=True).
            return self.fh.flush()

        def __getattr__(self, name):
            # Delegate other stream attributes (encoding, isatty, ...) to
            # the wrapped stream. 'fh' is guarded so a half-initialised
            # instance raises AttributeError instead of recursing.
            if name == 'fh':
                raise AttributeError(name)
            return getattr(self.fh, name)

    sys.stdout = Colourer(sys.stdout, patterns)
    return sys.stdout
660
661
662
if __name__ == "__main__":
    # Running this module directly executes its own unittest cases.
    unittest.main()
665
666