1
2"""
3Infrastructure code for testing connection managers.
4"""
5
6from twisted.internet import glib2reactor
7from twisted.internet.protocol import Protocol, Factory, ClientFactory
8glib2reactor.install()
9import sys
10import time
11
12import pprint
13import unittest
14
15import dbus.glib
16
17from twisted.internet import reactor
18
19import constants as cs
20
# Prefix shared by the Telepathy bus names and interface names built below
# (e.g. '<prefix>.ConnectionManager.<name>', '<prefix>.Channel').
tp_name_prefix = "org.freedesktop.Telepathy"
# Prefix shared by the Telepathy object paths built below
# (e.g. '<prefix>/ConnectionManager/<name>').
tp_path_prefix = "/org/freedesktop/Telepathy"
23
class DictionarySupersetOf (object):
    """Utility class for expecting "a dictionary with at least these keys".

    Instances compare equal to any container that has every key of
    ``dictionary`` mapped to an equal value; extra keys are ignored. Objects
    that cannot be searched/indexed that way (raising TypeError) compare
    unequal.
    """
    def __init__(self, dictionary):
        self._dictionary = dictionary

    def __repr__(self):
        return "DictionarySupersetOf(%s)" % self._dictionary

    def __eq__(self, other):
        """would like to just do:
        return set(other.items()).issuperset(self._dictionary.items())
        but it turns out that this doesn't work if you have another dict
        nested in the values of your dicts"""
        try:
            for k, v in self._dictionary.items():
                if k not in other or other[k] != v:
                    return False
            return True
        except TypeError: # other is not iterable
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__: without this method, "!="
        # falls back to the default comparison, which treats distinct objects
        # as unequal, so e.g. EventPattern.match() could never succeed.
        return not self.__eq__(other)
42
class Event:
    """A single event delivered through an event queue.

    ``type`` is a string of the form "<subqueue>-<subtype>"; it is split at
    its first '-' into the ``subqueue`` and ``subtype`` attributes (a type
    with no '-' therefore raises ValueError). Every extra keyword argument
    becomes an attribute of the event.
    """

    def __init__(self, type, **kw):
        vars(self).update(kw)
        self.type = type
        parts = type.split("-", 1)
        self.subqueue, self.subtype = parts
48
def format_event(event):
    """Describe ``event`` as a list of human-readable lines.

    The first line gives the event type; then each public attribute (in
    dir() order, skipping 'type' and names starting with '_') is listed with
    its pprint-formatted value. For an attribute named 'error' its plain
    str() form is added on an extra line as well.
    """
    lines = ['- type %s' % event.type]

    public_names = [name for name in dir(event)
                    if name != 'type' and not name.startswith('_')]
    for name in public_names:
        value = getattr(event, name)
        lines.append('- %s: %s' % (name, pprint.pformat(value)))
        if name == 'error':
            lines.append('%s' % value)

    return lines
61
class EventPattern:
    """Description of an event to wait for.

    ``type`` must equal the event's type exactly. Every other keyword
    argument names an attribute the event must have, with a value that
    compares equal. The special keyword ``predicate``, if given, is a
    callable taking the event, which must also return a true value for the
    pattern to match.
    """

    def __init__(self, type, **properties):
        self.type = type
        # 'predicate' is not an event attribute, so keep it out of the
        # properties compared in match().
        self.predicate = properties.pop('predicate', None)
        self.properties = properties
        (self.subqueue, self.subtype) = type.split("-", 1)

    def __repr__(self):
        properties = dict(self.properties)

        if self.predicate is not None:
            properties['predicate'] = self.predicate

        return '%s(%r, **%r)' % (
            self.__class__.__name__, self.type, properties)

    def match(self, event):
        """Return True if ``event`` satisfies this pattern, else False."""
        if event.type != self.type:
            return False

        for key, value in self.properties.items():
            # Only a missing attribute means "no match"; an AttributeError
            # raised by the comparison itself (e.g. a buggy matcher object)
            # must propagate rather than be silently read as a mismatch.
            try:
                actual = getattr(event, key)
            except AttributeError:
                return False

            if actual != value:
                return False

        return self.predicate is None or bool(self.predicate(event))
96
97
class TimeoutError(Exception):
    """Raised by an event queue's wait() when no suitable event is available."""
100
class ForbiddenEventOccurred(Exception):
    """Raised when a dequeued event matches one of the forbidden patterns.

    The offending event is kept in ``self.event``; str() renders it with
    format_event(), one attribute per line.
    """

    def __init__(self, event):
        super(ForbiddenEventOccurred, self).__init__()
        self.event = event

    def __str__(self):
        described = format_event(self.event)
        return '\n%s' % '\n'.join(described)
108
class BaseEventQueue:
    """Abstract event queue base class.

    Pending events are kept per subqueue (Event.subqueue) in
    self.event_queues, a dict mapping subqueue name to a non-empty list of
    events, oldest first.

    Implement the wait() method to have something that works: it receives
    an iterable of subqueue names (or None, meaning any subqueue) and must
    return the next event from one of them or raise TimeoutError.
    """

    def __init__(self, timeout=None):
        self.verbose = False
        # EventPatterns checked against every event expect()/expect_many()
        # dequeue; a match raises ForbiddenEventOccurred.
        self.forbidden_events = set()
        # subqueue name -> list of pending events (never left empty).
        self.event_queues = {}

        # Seconds a wait() implementation may block; defaults to 5.
        if timeout is None:
            self.timeout = 5
        else:
            self.timeout = timeout

    def log(self, s):
        """Print ``s``, but only when self.verbose is set."""
        if self.verbose:
            print(s)

    def log_queues(self, queues):
        """Log which subqueues are being waited on (None means any)."""
        # Bail out early: building the message is pointless when not verbose,
        # and ", ".join(None) would raise for wait(queues=None).
        if not self.verbose:
            return

        if queues is None:
            self.log("Waiting for event on: <any queue>")
        else:
            self.log("Waiting for event on: %s" % ", ".join(queues))

    def log_event(self, event):
        self.log('got event:')

        if self.verbose:
            # A plain loop: map() is lazy on Python 3 and would log nothing.
            for line in format_event(event):
                self.log(line)

    def _log_timeout(self, pending):
        """Log a timeout together with the patterns still unmatched."""
        self.log('timeout')
        self.log('still expecting:')
        for pattern in pending:
            self.log(' - %r' % pattern)

    def forbid_events(self, patterns):
        """
        Add patterns (an iterable of EventPattern) to the set of forbidden
        events. If a forbidden event occurs during an expect or expect_many,
        the test will fail.
        """
        self.forbidden_events.update(patterns)

    def unforbid_events(self, patterns):
        """
        Remove 'patterns' (an iterable of EventPattern) from the set of
        forbidden events. These must be the same EventPattern pointers that
        were passed to forbid_events.
        """
        self.forbidden_events.difference_update(patterns)

    def _check_forbidden(self, event):
        """Raise ForbiddenEventOccurred if ``event`` matches a forbidden pattern."""
        for e in self.forbidden_events:
            if e.match(event):
                raise ForbiddenEventOccurred(event)

    def expect(self, type, **kw):
        """
        Waits for an event matching the supplied pattern to occur, and returns
        it. For example, to await a D-Bus signal with particular arguments:

            e = q.expect('dbus-signal', signal='Badgers', args=["foo", 42])

        Events on the same subqueue that do not match are discarded. Raises
        TimeoutError (after logging the pattern) if wait() times out.
        """
        pattern = EventPattern(type, **kw)
        t = time.time()

        while True:
            try:
                event = self.wait([pattern.subqueue])
            except TimeoutError:
                self._log_timeout([pattern])
                raise
            self._check_forbidden(event)

            if pattern.match(event):
                self.log('handled, took %0.3f ms'
                    % ((time.time() - t) * 1000.0) )
                self.log('')
                return event

            self.log('not handled')
            self.log('')

    def expect_many(self, *patterns):
        """
        Waits for events matching all of the supplied EventPattern instances to
        return, and returns a list of events in the same order as the patterns
        they matched. After a pattern is successfully matched, it is not
        considered for future events; if more than one unsatisfied pattern
        matches an event, the first "wins".

        Note that the expected events may occur in any order. If you're
        expecting a series of events in a particular order, use repeated calls
        to expect() instead.

        This method is useful when you're awaiting a number of events which may
        happen in any order. For instance, in telepathy-gabble, calling a D-Bus
        method often causes a value to be returned immediately, as well as a
        query to be sent to the server. Since these events may reach the test
        in either order, the following is incorrect and will fail if the IQ
        happens to reach the test first:

            ret = q.expect('dbus-return', method='Foo')
            query = q.expect('stream-iq', query_ns=ns.FOO)

        The following would be correct:

            ret, query = q.expect_many(
                EventPattern('dbus-return', method='Foo'),
                EventPattern('stream-iq', query_ns=ns.FOO),
            )
        """
        ret = [None] * len(patterns)
        t = time.time()

        while None in ret:
            # Only wait on subqueues that still have an unmatched pattern.
            queues = set()
            for i, pattern in enumerate(patterns):
                if ret[i] is None:
                    queues.add(pattern.subqueue)

            try:
                event = self.wait(queues)
            except TimeoutError:
                self._log_timeout([pattern for i, pattern in enumerate(patterns)
                                   if ret[i] is None])
                raise
            self._check_forbidden(event)

            for i, pattern in enumerate(patterns):
                if ret[i] is None and pattern.match(event):
                    self.log('handled, took %0.3f ms'
                        % ((time.time() - t) * 1000.0) )
                    self.log('')
                    ret[i] = event
                    break
            else:
                self.log('not handled')
                self.log('')

        return ret

    def demand(self, type, **kw):
        """Return the very next event on the pattern's subqueue, which must match.

        Raises RuntimeError if that event does not match. Unlike expect(),
        forbidden events are not checked here.
        """
        pattern = EventPattern(type, **kw)

        event = self.wait([pattern.subqueue])

        if pattern.match(event):
            self.log('handled')
            self.log('')
            return event

        self.log('not handled')
        raise RuntimeError('expected %r, got %r' % (pattern, event))

    def queues_available(self, queues):
        """Return the names of subqueues that have pending events.

        With queues=None every such subqueue is returned; otherwise only the
        names in ``queues`` that have events, in ``queues``' iteration order.
        """
        if queues is None:
            return list(self.event_queues.keys())
        return [name for name in queues if name in self.event_queues]

    def pop_next(self, queue):
        """Remove and return the oldest event on subqueue ``queue``."""
        events = self.event_queues[queue]
        e = events.pop(0)
        if not events:
            # Drop emptied subqueues so queues_available() ignores them.
            self.event_queues.pop(queue)
        return e

    def append(self, event):
        """Queue ``event`` at the end of its subqueue."""
        self.log("Adding to queue")
        self.log_event(event)
        # Append in place; rebuilding the list on every call was quadratic.
        self.event_queues.setdefault(event.subqueue, []).append(event)
276
class IteratingEventQueue(BaseEventQueue):
    """Event queue that works by iterating the Twisted reactor.

    wait() spins the reactor in 0.01 s slices until one of the requested
    subqueues has an event, or until self.timeout seconds have passed.
    """

    def __init__(self, timeout=None):
        BaseEventQueue.__init__(self, timeout)

    def wait(self, queues=None):
        """Return the next event from one of ``queues`` (None: any subqueue).

        Raises TimeoutError if no such event arrives within self.timeout.
        """
        stop = [False]

        def later():
            stop[0] = True

        delayed_call = reactor.callLater(self.timeout, later)

        try:
            self.log_queues(queues)

            qa = self.queues_available(queues)
            while not qa and (not stop[0]):
                reactor.iterate(0.01)
                qa = self.queues_available(queues)
        finally:
            # The timer can fire during the same iterate() that delivers an
            # event, and cancelling an already-run DelayedCall raises
            # AlreadyCalled; also make sure the timer never outlives this call
            # if iterate() raises.
            if delayed_call.active():
                delayed_call.cancel()

        if not qa:
            raise TimeoutError

        e = self.pop_next(qa[0])
        self.log_event(e)
        return e
305
class TestEventQueue(BaseEventQueue):
    """Event queue pre-loaded with a fixed list of events, for self-tests.

    wait() never blocks: it pops the next event from the first requested
    subqueue that has one, or raises TimeoutError straight away.
    """

    def __init__(self, events):
        BaseEventQueue.__init__(self)
        for queued in events:
            self.append(queued)

    def wait(self, queues = None):
        candidates = self.queues_available(queues)
        if not candidates:
            raise TimeoutError
        return self.pop_next(candidates[0])
319
class EventQueueTest(unittest.TestCase):
    """Self-tests for BaseEventQueue's matching logic, driven by TestEventQueue."""

    def test_expect(self):
        queue = TestEventQueue([Event('test-foo'), Event('test-bar')])
        self.assertEqual(queue.expect('test-foo').type, 'test-foo')
        self.assertEqual(queue.expect('test-bar').type, 'test-bar')

    def test_expect_many(self):
        # Patterns may be satisfied in a different order from the order in
        # which the events were queued.
        queue = TestEventQueue([Event('test-foo'),
            Event('test-bar')])
        bar, foo = queue.expect_many(
            EventPattern('test-bar'),
            EventPattern('test-foo'))
        self.assertEqual(bar.type, 'test-bar')
        self.assertEqual(foo.type, 'test-foo')

    def test_expect_many2(self):
        # Test that events are only matched against patterns that haven't yet
        # been matched. This tests a regression.
        queue = TestEventQueue([Event('test-foo', x=1), Event('test-foo', x=2)])
        foo1, foo2 = queue.expect_many(
            EventPattern('test-foo'),
            EventPattern('test-foo'))
        self.assertEqual(foo1.type, 'test-foo')
        self.assertEqual(foo1.x, 1)
        self.assertEqual(foo2.type, 'test-foo')
        self.assertEqual(foo2.x, 2)

    def test_expect_queueing(self):
        # Events live in per-subqueue lists ('foo', 'bar', 'baz'), so waiting
        # on one subqueue must neither consume nor reorder the others.
        queue = TestEventQueue([Event('foo-test', x=1),
            Event('foo-test', x=2)])

        queue.append(Event('bar-test', x=1))
        queue.append(Event('bar-test', x=2))

        queue.append(Event('baz-test', x=1))
        queue.append(Event('baz-test', x=2))

        # range(1, 3) covers both queued values of x in every subqueue.
        for x in range(1, 3):
            e = queue.expect('baz-test')
            self.assertEqual(x, e.x)

            e = queue.expect('bar-test')
            self.assertEqual(x, e.x)

            e = queue.expect('foo-test')
            self.assertEqual(x, e.x)

    def test_timeout(self):
        queue = TestEventQueue([])
        self.assertRaises(TimeoutError, queue.expect, 'test-foo')

    def test_demand(self):
        queue = TestEventQueue([Event('test-foo'), Event('test-bar')])
        foo = queue.demand('test-foo')
        self.assertEqual(foo.type, 'test-foo')

    def test_demand_fail(self):
        # demand() only looks at the very next event, which is test-foo.
        queue = TestEventQueue([Event('test-foo'), Event('test-bar')])
        self.assertRaises(RuntimeError, queue.demand, 'test-bar')
377
def unwrap(x):
    """Hack to unwrap D-Bus values, so that they're easier to read when
    printed.

    Lists, tuples and dicts (and subclasses of them) are converted
    recursively; booleans become plain bool; instances of (subclasses of)
    unicode/str/long/int/float are converted to that base type. Anything
    else is returned unchanged.
    """

    if isinstance(x, list):
        return [unwrap(item) for item in x]

    if isinstance(x, tuple):
        return tuple([unwrap(item) for item in x])

    if isinstance(x, dict):
        return dict([(unwrap(k), unwrap(v)) for k, v in x.items()])

    # Must come before the numeric checks: bool (and, presumably,
    # dbus.Boolean) are int subclasses and would otherwise print as 0/1.
    if isinstance(x, (bool, dbus.Boolean)):
        return bool(x)

    for t in [unicode, str, long, int, float]:
        if isinstance(x, t):
            return t(x)

    return x
399
def call_async(test, proxy, method, *args, **kw):
    """Invoke ``method`` on ``proxy`` without blocking.

    The outcome is appended to ``test`` (an event queue) as either a
    'dbus-return' event (``value`` = the unwrapped return values) or a
    'dbus-error' event (``error``, its D-Bus ``name`` and str() ``message``).
    Extra positional and keyword arguments are passed to the method call.
    """
    def on_reply(*ret):
        test.append(Event('dbus-return', method=method, value=unwrap(ret)))

    def on_error(err):
        test.append(Event('dbus-error', method=method, error=err,
            name=err.get_dbus_name(), message=str(err)))

    bound_method = getattr(proxy, method)
    kw['reply_handler'] = on_reply
    kw['error_handler'] = on_error
    bound_method(*args, **kw)
415
def sync_dbus(bus, q, conn):
    """Make a dummy DummySyncDBus call to the connection's bus name and wait
    for the resulting 'dbus-error' event on ``q``.

    Presumably used as a barrier so that messages sent earlier by that
    process have been processed by the time this returns (relies on D-Bus
    message ordering; not checked here).
    """
    # This won't do the right thing unless the proxy has a unique name.
    bus_name = conn.object.bus_name
    assert bus_name.startswith(':')

    root_object = bus.get_object(bus_name, '/')
    tests_iface = dbus.Interface(root_object, 'org.freedesktop.Telepathy.Tests')
    call_async(q, tests_iface, 'DummySyncDBus')
    q.expect('dbus-error', method='DummySyncDBus')
424
class ProxyWrapper:
    """Bundle a D-Bus proxy object with several interface views of it.

    ``Properties`` and ``TpProperties`` are always available. Any other
    attribute is looked up, in order, among the named interfaces built from
    ``others`` (short name -> interface name), then among attributes stored
    on the wrapped proxy itself, and finally on the ``default`` interface.
    """

    def __init__(self, object, default, others):
        self.object = object
        self.default_interface = dbus.Interface(object, default)
        self.Properties = dbus.Interface(object, dbus.PROPERTIES_IFACE)
        self.TpProperties = dbus.Interface(
            object, tp_name_prefix + '.Properties')
        self.interfaces = dict(
            (alias, dbus.Interface(object, iface_name))
            for alias, iface_name in others.items())

    def __getattr__(self, name):
        try:
            return self.interfaces[name]
        except KeyError:
            pass

        if name in self.object.__dict__:
            return getattr(self.object, name)

        return getattr(self.default_interface, name)
444
def wrap_connection(conn):
    """Wrap a Connection proxy in a ProxyWrapper.

    The default interface is Connection; short names such as
    ``wrapped.Aliasing`` or ``wrapped.Peer`` resolve to the interfaces
    listed below.
    """
    interfaces = {}

    for short_name in ('Aliasing', 'Avatars', 'Capabilities', 'Contacts',
                       'Presence', 'SimplePresence', 'Requests'):
        interfaces[short_name] = (
            tp_name_prefix + '.Connection.Interface.' + short_name)

    interfaces['Peer'] = 'org.freedesktop.DBus.Peer'
    interfaces['ContactCapabilities'] = cs.CONN_IFACE_CONTACT_CAPS
    interfaces['ContactInfo'] = cs.CONN_IFACE_CONTACT_INFO
    interfaces['Location'] = cs.CONN_IFACE_LOCATION
    interfaces['Future'] = tp_name_prefix + '.Connection.FUTURE'
    interfaces['MailNotification'] = cs.CONN_IFACE_MAIL_NOTIFICATION
    interfaces['ContactList'] = cs.CONN_IFACE_CONTACT_LIST
    interfaces['ContactGroups'] = cs.CONN_IFACE_CONTACT_GROUPS
    interfaces['PowerSaving'] = cs.CONN_IFACE_POWER_SAVING

    return ProxyWrapper(conn, tp_name_prefix + '.Connection', interfaces)
461
def wrap_channel(chan, type_, extra=None):
    """Wrap a Channel proxy in a ProxyWrapper.

    The default interface is Channel; ``type_`` resolves to
    Channel.Type.<type_>, 'Group' to Channel.Interface.Group, and each name
    in ``extra`` (if given) to Channel.Interface.<name>.
    """
    interfaces = {}
    interfaces[type_] = tp_name_prefix + '.Channel.Type.' + type_
    interfaces['Group'] = tp_name_prefix + '.Channel.Interface.Group'

    for name in (extra or ()):
        interfaces[name] = tp_name_prefix + '.Channel.Interface.' + name

    return ProxyWrapper(chan, tp_name_prefix + '.Channel', interfaces)
474
def make_connection(bus, event_func, name, proto, params):
    """Ask connection manager ``name`` for a new ``proto`` connection.

    Calls RequestConnection(proto, params) on the manager and returns the
    resulting connection object wrapped by wrap_connection(). ``event_func``
    is not used by this function.
    """
    manager_proxy = bus.get_object(
        tp_name_prefix + '.ConnectionManager.%s' % name,
        tp_path_prefix + '/ConnectionManager/%s' % name)
    manager = dbus.Interface(manager_proxy,
                             tp_name_prefix + '.ConnectionManager')

    conn_bus_name, conn_path = manager.RequestConnection(proto, params)
    return wrap_connection(bus.get_object(conn_bus_name, conn_path))
486
def make_channel_proxy(conn, path, iface):
    """Return the channel at ``path`` as a dbus.Interface for
    'org.freedesktop.Telepathy.<iface>', looked up on the session bus under
    the bus name of the wrapped connection ``conn``.
    """
    session_bus = dbus.SessionBus()
    channel_object = session_bus.get_object(conn.object.bus_name, path)
    return dbus.Interface(channel_object, tp_name_prefix + '.' + iface)
492
# block_reading can be used if a test wants to choose when we start to read
# data from the socket.
class EventProtocol(Protocol):
    """Twisted protocol that reports socket activity as queue events.

    When a queue is set, received data becomes a 'socket-data' event and a
    lost connection a 'socket-disconnected' event. With ``block_reading``
    the transport stops reading as soon as the connection is made.
    """

    def __init__(self, queue=None, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def dataReceived(self, data):
        if self.queue is None:
            return
        self.queue.append(Event('socket-data', protocol=self, data=data))

    def sendData(self, data):
        """Write ``data`` directly to the underlying transport."""
        self.transport.write(data)

    def connectionMade(self):
        if not self.block_reading:
            return
        self.transport.stopReading()

    def connectionLost(self, reason=None):
        if self.queue is None:
            return
        self.queue.append(Event('socket-disconnected', protocol=self))
515
class EventProtocolFactory(Factory):
    """Factory producing EventProtocol instances that share one queue.

    Every protocol built is announced with a 'socket-connected' event on the
    queue before buildProtocol() returns it.
    """

    def __init__(self, queue, block_reading=False):
        self.queue = queue
        self.block_reading = block_reading

    def _create_protocol(self):
        # Hook used by buildProtocol() to construct the protocol instance.
        return EventProtocol(self.queue, self.block_reading)

    def buildProtocol(self, addr):
        protocol = self._create_protocol()
        announcement = Event('socket-connected', protocol=protocol)
        self.queue.append(announcement)
        return protocol
528
class EventProtocolClientFactory(EventProtocolFactory, ClientFactory):
    """EventProtocolFactory combined with Twisted's ClientFactory."""
531
def watch_tube_signals(q, tube):
    """Register a signal receiver on ``tube`` that appends each signal it
    receives to ``q`` as a 'tube-signal' event (object path, member name,
    unwrapped arguments and the tube itself).
    """
    def on_signal(*args, **kwargs):
        path = kwargs['path']
        member = kwargs['member']
        unwrapped_args = [unwrap(arg) for arg in args]
        q.append(Event('tube-signal', path=path, signal=member,
            args=unwrapped_args, tube=tube))

    tube.add_signal_receiver(on_signal,
        path_keyword='path', member_keyword='member',
        byte_arrays=True)
543
def pretty(x):
    """Return ``x`` as a pprint-formatted string, after unwrap()ping it."""
    plain = unwrap(x)
    return pprint.pformat(plain)
546
def assertEquals(expected, value):
    """Raise AssertionError showing both values if ``expected != value``."""
    if expected != value:
        message = "expected:\n%s\ngot:\n%s" % (pretty(expected), pretty(value))
        raise AssertionError(message)
551
def assertSameSets(expected, value):
    """Assert that ``expected`` and ``value`` hold the same elements,
    ignoring order and duplicates."""
    wanted, actual = set(expected), set(value)

    if wanted != actual:
        message = "expected contents:\n%s\ngot:\n%s" % (
            pretty(wanted), pretty(actual))
        raise AssertionError(message)
560
def assertNotEquals(expected, value):
    """Raise AssertionError if ``expected == value``."""
    if expected == value:
        message = "expected something other than:\n%s" % pretty(value)
        raise AssertionError(message)
565
def assertContains(element, value):
    """Raise AssertionError unless ``element in value``."""
    if element not in value:
        message = "expected:\n%s\nin:\n%s" % (pretty(element), pretty(value))
        raise AssertionError(message)
570
def assertDoesNotContain(element, value):
    """Raise AssertionError if ``element in value``."""
    if element in value:
        message = "expected:\n%s\nnot in:\n%s" % (
            pretty(element), pretty(value))
        raise AssertionError(message)
575
def assertLength(length, value):
    """Raise AssertionError unless ``len(value) == length``."""
    actual = len(value)
    if actual != length:
        message = "expected: length %d, got length %d:\n%s" % (
            length, actual, pretty(value))
        raise AssertionError(message)
580
def assertFlagsSet(flags, value):
    """Raise AssertionError unless every bit of ``flags`` is set in ``value``."""
    present = value & flags
    if present != flags:
        message = "expected flags %u, of which only %u are set in %u" % (
            flags, present, value)
        raise AssertionError(message)
587
def assertFlagsUnset(flags, value):
    """Raise AssertionError if any bit of ``flags`` is set in ``value``."""
    overlap = value & flags
    if overlap != 0:
        message = "expected none of flags %u, but %u are set in %u" % (
            flags, overlap, value)
        raise AssertionError(message)
594
def assertDBusError(name, error):
    """Raise AssertionError unless ``error.get_dbus_name()`` equals ``name``.

    The failure message includes ``error.message``.
    """
    actual_name = error.get_dbus_name()
    if actual_name == name:
        return

    raise AssertionError(
        "expected DBus error named:\n  %s\ngot:\n  %s\n(with message: %s)"
        % (name, actual_name, error.message))
600
def install_colourer():
    """Replace sys.stdout with a wrapper that colours log prefixes.

    Any write() whose text starts with 'handled' gets that word in green,
    and one starting with 'not handled' gets it in red (ANSI escape codes);
    other text passes through unchanged. Every other attribute (flush,
    isatty, ...) is delegated to the original stream. Returns the new
    sys.stdout.
    """
    def red(s):
        return '\x1b[31m%s\x1b[0m' % s

    def green(s):
        return '\x1b[32m%s\x1b[0m' % s

    patterns = {
        'handled': green,
        'not handled': red,
        }

    class Colourer(object):
        def __init__(self, fh, patterns):
            self.fh = fh
            self.patterns = patterns

        def write(self, s):
            for p, f in self.patterns.items():
                if s.startswith(p):
                    self.fh.write(f(p) + s[len(p):])
                    return

            self.fh.write(s)

        def __getattr__(self, name):
            # Only reached for attributes Colourer lacks. Delegating them
            # keeps code that treats sys.stdout as a real file (e.g. calling
            # flush()) working. Guard 'fh' itself to avoid infinite recursion
            # if it has not been set yet.
            if name == 'fh':
                raise AttributeError(name)
            return getattr(self.fh, name)

    sys.stdout = Colourer(sys.stdout, patterns)
    return sys.stdout
628
# When run directly, execute this module's unittest test cases
# (e.g. EventQueueTest).
if __name__ == '__main__':
    unittest.main()
631
632