1# -*- test-case-name: foolscap.test.test_pb -*-
2
3from __future__ import print_function
4import six
5import time
6from zope.interface import implementer, implementer_only, implementedBy, Interface
7from twisted.python import log
8from twisted.internet import defer, reactor, task, protocol
9from twisted.application import internet
10from twisted.trial import unittest
11from foolscap import broker, eventual, negotiate
12from foolscap.api import Tub, Referenceable, RemoteInterface, \
13     eventually, fireEventually, flushEventualQueue
14from foolscap.remoteinterface import getRemoteInterface, RemoteMethodSchema, \
15     UnconstrainedMethod
16from foolscap.schema import Any, SetOf, DictOf, ListOf, TupleOf, \
17     NumberConstraint, ByteStringConstraint, IntegerConstraint, \
18     UnicodeConstraint, ChoiceOf
19from foolscap.referenceable import TubRef
20from foolscap.util import allocate_tcp_port, long_type
21
22from twisted.python import failure
23from twisted.internet.main import CONNECTION_DONE
24
def getRemoteInterfaceName(obj):
    """Return the ``__remote_name__`` of the RemoteInterface *obj* provides.

    The interface is found with ``getRemoteInterface(obj)``; its
    ``__remote_name__`` attribute is returned unchanged.
    """
    return getRemoteInterface(obj).__remote_name__
28
class Loopback:
    """In-memory transport that hands written data to a peer protocol.

    The caller must assign ``peer`` (the protocol that receives data written
    here) and ``protocol`` (the protocol that owns this transport) before
    the transport is used.
    """
    # The transport's promise is that write() can be treated as a
    # synchronous, isolated function call: specifically, the Protocol's
    # dataReceived() and connectionLost() methods shall not be called during
    # a call to write().

    connected = True

    def write(self, data):
        """Queue *data* for delivery to ``self.peer`` on a later turn."""
        eventually(self._write, data)

    def _write(self, data):
        if not self.connected:
            # the connection was dropped after this write was queued
            return
        try:
            # isolate exceptions: if one occurred on a regular TCP transport,
            # they would hang up, so duplicate that here.
            self.peer.dataReceived(data)
        except Exception:
            f = failure.Failure()
            log.err(f)
            print("Loopback.write exception:", f)
            self.loseConnection(f)

    def loseConnection(self, why=None):
        """Schedule connectionLost(why) on both protocols (first call only).

        :param why: a Failure describing the disconnect. When omitted, a new
            ``Failure(CONNECTION_DONE)`` is created for this call.
        """
        if why is None:
            # Build the Failure per call rather than sharing a single
            # instance that was created once, at import time, as a default.
            why = failure.Failure(CONNECTION_DONE)
        assert isinstance(why, failure.Failure), why
        if self.connected:
            self.connected = False
            # this one is slightly weird because 'why' is a Failure
            eventually(self._loseConnection, why)

    def _loseConnection(self, why):
        assert isinstance(why, failure.Failure), why
        self.protocol.connectionLost(why)
        self.peer.connectionLost(why)

    def flush(self):
        """Stop delivering queued writes; return a Deferred fired a turn later.

        Unlike loseConnection(), this does not call connectionLost().
        """
        self.connected = False
        return fireEventually()

    def getPeer(self):
        return broker.LoopbackAddress()
    def getHost(self):
        return broker.LoopbackAddress()
72
# Schema that combines most constraint kinds in one structure: a DictOf
# (maxKeys=5) from byte strings to a ListOf (maxLength=20) of tuples whose
# elements each use a different constraint type. Used by RIHelper.megaschema.
MegaSchema1 = DictOf(ByteStringConstraint(),
                     ListOf(TupleOf(SetOf(int, maxLength=10, mutable=True),
                                    six.binary_type, bool, int, long_type, float, None,
                                    UnicodeConstraint(),
                                    ByteStringConstraint(),
                                    Any(), NumberConstraint(),
                                    IntegerConstraint(),
                                    ByteStringConstraint(maxLength=100,
                                                         minLength=90),
                                    ),
                            maxLength=20),
                     maxKeys=5)
# Container constraints given bare Python types (int, str) are expected to
# convert those arguments into constraint objects themselves.
MegaSchema2 = TupleOf(SetOf(int),
                      ListOf(int),
                      DictOf(int, str),
                      )
# A list of (int, int) tuples; used by RIHelper.mega3.
MegaSchema3 = ListOf(TupleOf(int,int))
91
92
# Remote interface implemented by HelperTarget. Each method's default
# argument values declare its argument constraints, and its return
# expression declares the response constraint.
class RIHelper(RemoteInterface):
    def set(obj=Any()): return bool
    def set2(obj1=Any(), obj2=Any()): return bool
    def append(obj=Any()): return Any()
    def get(): return Any()
    def echo(obj=Any()): return Any()
    # (this 'defer' only shadows twisted's defer inside this class body)
    def defer(obj=Any()): return Any()
    def hang(): return Any()
    # test one of everything
    def megaschema(obj1=MegaSchema1, obj2=MegaSchema2): return None
    def mega3(obj1=MegaSchema3): return None
    # argument may be either ByteStringConstraint(2000) or an int
    def choice1(obj1=ChoiceOf(ByteStringConstraint(2000), int)): return None
105
@implementer(RIHelper)
class HelperTarget(Referenceable):
    """Referenceable implementing RIHelper for argument/response tests.

    Most remote methods record what they received on the instance
    (``obj``, ``obj1``/``obj2`` or ``calls``) so tests can inspect it.
    """

    # Deferred handed out by waitfor() / remote_hang(); remote_set() fires it.
    d = None

    def __init__(self, name="unnamed"):
        self.name = name
        # remote_append() records into this list, so it must exist before
        # the first call (otherwise append raises AttributeError).
        self.calls = []

    def __repr__(self):
        return "<HelperTarget %s>" % self.name

    def waitfor(self):
        """Return a fresh Deferred that the next remote_set() will fire."""
        self.d = defer.Deferred()
        return self.d

    def remote_set(self, obj):
        self.obj = obj
        if self.d:
            # Clear before firing: a later set() must not fire the same
            # Deferred again (that would raise AlreadyCalledError), and a
            # callback that calls waitfor() keeps its new Deferred.
            d, self.d = self.d, None
            d.callback(obj)
        return True

    def remote_set2(self, obj1, obj2):
        self.obj1 = obj1
        self.obj2 = obj2
        return True

    def remote_append(self, obj):
        self.calls.append(obj)

    def remote_get(self):
        # raises AttributeError if nothing has been stored in self.obj yet
        return self.obj

    def remote_echo(self, obj):
        self.obj = obj
        return obj

    def remote_defer(self, obj):
        # answer asynchronously: the result fires on a later turn
        return fireEventually(obj)

    def remote_hang(self):
        # the response waits until someone fires self.d
        self.d = defer.Deferred()
        return self.d

    def remote_megaschema(self, obj1, obj2):
        self.obj1 = obj1
        self.obj2 = obj2
        return None

    # NOTE(review): RIHelper names the mega3/choice1 argument 'obj1', but
    # these methods call it 'obj'. Positional calls work; a caller passing
    # obj1=... by keyword would presumably hit a TypeError here.
    def remote_mega3(self, obj):
        self.obj = obj
        return None

    def remote_choice1(self, obj):
        self.obj = obj
        return None
156
class TimeoutError(Exception):
    """Raised by PollMixin when the polled condition misses its deadline."""

class PollComplete(Exception):
    """Internal signal that PollMixin._poll uses to stop its LoopingCall."""

class PollMixin:

    def poll(self, check_f, pollinterval=0.01, timeout=None):
        """Call check_f() every pollinterval seconds until it returns true.

        Returns a Deferred that fires with None once check_f() returns a
        true value. If check_f() raises, the Deferred errbacks with that
        exception. If timeout is not None and check_f() has not succeeded
        within that many seconds, the Deferred errbacks with TimeoutError.
        With timeout=None no deadline is enforced, so polling continues
        forever (in practice, until Trial's own test timeout).
        """
        deadline = None if timeout is None else time.time() + timeout
        looper = task.LoopingCall(self._poll, check_f, deadline)

        def _swallow_completion(why):
            # PollComplete is how _poll ends the loop; it means success.
            why.trap(PollComplete)
            return None

        return looper.start(pollinterval).addErrback(_swallow_completion)

    def _poll(self, check_f, cutoff):
        timed_out = cutoff is not None and time.time() > cutoff
        if timed_out:
            raise TimeoutError()
        if check_f():
            raise PollComplete()
188
class StallMixin:
    def stall(self, res, timeout):
        """Return a Deferred that fires with *res* after *timeout* seconds.

        Meant to be dropped into a callback chain as a pause, e.g.
        ``d.addCallback(self.stall, 1.0)``.
        """
        delayed = defer.Deferred()
        reactor.callLater(timeout, delayed.callback, res)
        return delayed
194
class TargetMixin(PollMixin, StallMixin):
    """Test mixin that connects two Brokers through a pair of Loopbacks."""

    def setUp(self):
        self.loopbacks = []

    def setupBrokers(self):
        """Create targetBroker and callingBroker and connect them.

        Each broker gets a Loopback transport whose ``peer`` is the other
        broker; both Loopbacks are remembered so tearDown() can flush them.
        """
        self.targetBroker = broker.Broker(TubRef("targetBroker"))
        self.callingBroker = broker.Broker(TubRef("callingBroker"))

        t1 = Loopback()
        t1.peer = self.callingBroker
        t1.protocol = self.targetBroker
        self.targetBroker.transport = t1
        self.loopbacks.append(t1)

        t2 = Loopback()
        t2.peer = self.targetBroker
        t2.protocol = self.callingBroker
        self.callingBroker.transport = t2
        self.loopbacks.append(t2)

        self.targetBroker.connectionMade()
        self.callingBroker.connectionMade()

    def tearDown(self):
        """Return a Deferred which fires when the Loopbacks are drained."""
        dl = [lb.flush() for lb in self.loopbacks]
        d = defer.DeferredList(dl)
        d.addCallback(flushEventualQueue)
        return d

    def setupTarget(self, target, txInterfaces=False):
        """Export *target* from targetBroker; return (remote_ref, target).

        txInterfaces controls which interface name the calling side is told:
          false: the sender doesn't know about any interface (iname=None)
          true:  the sender gets the target's actual RemoteInterface name
        Passing a list does NOT supply an artificial interface list: any
        true value (including a non-empty list) behaves like True, and an
        empty list behaves like False.
        """
        puid = target.processUniqueID()
        tracker = self.targetBroker.getTrackerForMyReference(puid, target)
        tracker.send()
        clid = tracker.clid
        if txInterfaces:
            iname = getRemoteInterfaceName(target)
        else:
            iname = None
        rtracker = self.callingBroker.getTrackerForYourReference(clid, iname)
        rr = rtracker.getRef()
        return rr, target
243
244
245
# Remote interface implemented by Target; demonstrates both ways of
# declaring method constraints.
class RIMyTarget(RemoteInterface):
    # method constraints can be declared directly:
    add1 = RemoteMethodSchema(_response=int, a=int, b=int)
    free = UnconstrainedMethod()

    # or through their function definitions:
    def add(a=int, b=int): return int
    # An explicit "add = schema.callable(add)" is unnecessary because the
    # metaclass converts these definitions; it could still be used for
    # adding options.
    # NOTE(review): Target below defines no remote_join.
    def join(a=bytes, b=bytes, c=int): return bytes
    def getName(): return bytes
    # RIMyTarget3 declares a conflicting 'disputed' (a=str) for tests in
    # which the two ends disagree about the interface.
    disputed = RemoteMethodSchema(_response=int, a=int)
    def fail(): return str  # actually raises an exception
    # Target.remote_failstring raises a str; modern Python turns that into
    # a TypeError.
    def failstring(): return str # raises a string exception
260
# Second remote interface, with its wire name set explicitly through
# __remote_name__. Its single method, 'sub', takes two ints and returns an int.
class RIMyTarget2(RemoteInterface):
    __remote_name__ = "RIMyTargetInterface2"
    sub = RemoteMethodSchema(_response=int, a=int, b=int)
264
# Some tests need the two ends of a connection to disagree about the contents
# of the RemoteInterface they are using. That is remarkably hard to arrange
# within a single process, so we build an object that behaves just barely
# enough like a RemoteInterface to work: a dict of RemoteMethodSchema objects
# that carries the same __remote_name__ as RIMyTarget.
class FakeTarget(dict):
    # a dict subclass (not a plain dict) so attributes such as
    # __remote_name__ can be set on the instance
    pass
RIMyTarget3 = FakeTarget()
RIMyTarget3.__remote_name__ = RIMyTarget.__remote_name__

# Unlike RIMyTarget.disputed (a=int), this version takes a str argument.
RIMyTarget3['disputed'] = RemoteMethodSchema(_response=int, a=str)
RIMyTarget3['disputed'].name = "disputed"
RIMyTarget3['disputed'].interface = RIMyTarget3

# NOTE(review): 'disputed2' has .name "disputed" (presumably the wire method
# name, routing calls to Target.remote_disputed, which returns an int) while
# it expects a str response. This looks deliberate, to provoke a response
# constraint mismatch; confirm before "fixing" the name.
RIMyTarget3['disputed2'] = RemoteMethodSchema(_response=str, a=int)
RIMyTarget3['disputed2'].name = "disputed"
RIMyTarget3['disputed2'].interface = RIMyTarget3

# Same shape as RIMyTarget2.sub.
RIMyTarget3['sub'] = RemoteMethodSchema(_response=int, a=int, b=int)
RIMyTarget3['sub'].name = "sub"
RIMyTarget3['sub'].interface = RIMyTarget3
285
@implementer(RIMyTarget)
class Target(Referenceable):
    """Referenceable implementing (most of) RIMyTarget for call tests.

    remote_add, remote_add1 and remote_free record their arguments in
    ``calls``.
    """

    def __init__(self, name=None):
        self.name = name
        self.calls = []

    def getMethodSchema(self, methodname):
        # report "no schema" for every method name
        return None

    def remote_add(self, a, b):
        self.calls.append((a, b))
        total = a + b
        return total
    remote_add1 = remote_add

    def remote_free(self, *args, **kwargs):
        self.calls.append((args, kwargs))
        return "bird"

    def remote_getName(self):
        return self.name

    def remote_disputed(self, a):
        # the argument is ignored
        return 24

    def remote_fail(self):
        raise ValueError("you asked me to fail")

    def remote_fail_remotely(self, target):
        # ask another remote reference to fail and hand back its Deferred
        return target.callRemote("fail")

    def remote_failstring(self):
        # A legacy "string exception". Python 2.6+ and Python 3 reject
        # raising a str, so this actually raises TypeError.
        raise "string exceptions are annoying"

    def remote_with_f(self, f):
        # echo the argument back unchanged
        return f
314
@implementer_only(implementedBy(Referenceable))
class TargetWithoutInterfaces(Target):
    """Target whose declared interfaces are exactly Referenceable's.

    implementer_only() undeclares the RIMyTarget interface that Target
    declares.
    """
319
@implementer(RIMyTarget)
class BrokenTarget(Referenceable):
    """Claims RIMyTarget but answers add() with a str instead of an int."""

    def remote_add(self, a, b):
        # RIMyTarget declares add(...) -> int; this deliberately disagrees
        bogus = "error"
        return bogus
324
325
class IFoo(Interface):
    """A plain (non-remote) zope Interface, used as a constraint by RITypes."""
329
@implementer(IFoo)
class Foo(Referenceable):
    """Minimal Referenceable that provides IFoo."""
333
# Empty remote interface: DummyTarget provides it, and RITypes uses it as an
# argument and response constraint.
class RIDummy(RemoteInterface):
    pass
336
# Remote interface (implemented by TypesTarget) for constraints that name
# None, a RemoteInterface (RIDummy) or a plain zope Interface (IFoo). The
# work argument lets a test ask for a conforming or a violating result.
class RITypes(RemoteInterface):
    # TypesTarget returns None when work is true, a str otherwise
    def returns_none(work=bool): return None
    def takes_remoteinterface(a=RIDummy): return str
    # TypesTarget: work == 1 -> DummyTarget (provides RIDummy),
    # work == -1 -> TypesTarget (does not), anything else -> 15
    def returns_remoteinterface(work=int): return RIDummy
    def takes_interface(a=IFoo): return str
    # TypesTarget returns a Foo (provides IFoo) when work is true, else a str
    def returns_interface(work=bool): return IFoo
343
@implementer(RIDummy)
class DummyTarget(Referenceable):
    """Referenceable that provides the empty RIDummy interface."""
347
@implementer(RITypes)
class TypesTarget(Referenceable):
    """Implements RITypes; each method can be told to return a value that
    does not satisfy its declared response constraint."""

    def remote_returns_none(self, work):
        return None if work else "not None"

    def remote_takes_remoteinterface(self, a):
        # TODO: really, I want to just be able to say:
        #   if RIDummy.providedBy(a):
        iface = a.tracker.interface
        provides_dummy = bool(iface) and iface == RIDummy
        if not provides_dummy:
            raise RuntimeError("my argument (%s) should provide RIDummy, "
                               "but doesn't" % a)
        return "good"

    def remote_returns_remoteinterface(self, work):
        if work == 1:
            return DummyTarget()
        elif work == -1:
            # a Referenceable that does not provide RIDummy
            return TypesTarget()
        else:
            return 15

    def remote_takes_interface(self, a):
        if not IFoo.providedBy(a):
            raise RuntimeError("my argument (%s) should provide IFoo, but doesn't" % a)
        return "good"

    def remote_returns_interface(self, work):
        return Foo() if work else "not implementor of IFoo"
380
381
class ShouldFailMixin:

    def shouldFail(self, expected_failure, which, substring,
                   callable, *args, **kwargs):
        """Check that callable(*args, **kwargs) fails with expected_failure.

        :param which: label used in assertion messages.
        :param substring: if given, must appear in str() of the Failure.
        :returns: a Deferred that fires with ``[failure]`` (wrapped in a list
            so the Failure does not trigger an errback), or fails the test
            if the call succeeded or failed with a different exception.
        """
        assert substring is None or isinstance(substring, str)
        d = defer.maybeDeferred(callable, *args, **kwargs)

        def _check(outcome):
            if not isinstance(outcome, failure.Failure):
                # the call succeeded, which is itself a test failure
                return self.fail("%s was supposed to raise %s, not get '%s'" %
                                 (which, expected_failure, outcome))
            if not outcome.check(expected_failure):
                self.fail("got failure %s, was expecting %s"
                          % (outcome, expected_failure))
            if substring:
                text = str(outcome)
                self.assertTrue(substring in text,
                                "%s: substring '%s' not in '%s'"
                                % (which, substring, text))
            # make the Failure available to a subsequent callback, but
            # keep it from triggering an errback
            return [outcome]

        d.addBoth(_check)
        return d
405
# Fixed certificate + RSA private key pairs for two test Tubs, so tests get
# stable, known tubids. As strings, tubid_low sorts before tubid_high.
# Presumably each tubid is the one derived from the certificate that
# follows it -- verify before changing either value.
# Test-only key material: never use these keys outside the test suite.
tubid_low = "3hemthez7rvgvyhjx2n5kdj7mcyar3yt"
certData_low = \
"""-----BEGIN CERTIFICATE-----
MIIBnjCCAQcCAgCEMA0GCSqGSIb3DQEBBAUAMBcxFTATBgNVBAMUDG5ld3BiX3Ro
aW5neTAeFw0wNjExMjYxODUxMTBaFw0wNzExMjYxODUxMTBaMBcxFTATBgNVBAMU
DG5ld3BiX3RoaW5neTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA1DuK9NoF
fiSreA8rVqYPAjNiUqFelAAYPgnJR92Jry1J/dPA3ieNcCazbjVeKUFjd6+C30XR
APhajsAJFiJdnmgrtVILNrpZDC/vISKQoAmoT9hP/cMqFm8vmUG/+AXO76q63vfH
UmabBVDNTlM8FJpbm9M26cFMrH45G840gA0CAwEAATANBgkqhkiG9w0BAQQFAAOB
gQBCtjgBbF/s4w/16Y15lkTAO0xt8ZbtrvcsFPGTXeporonejnNaJ/aDbJt8Y6nY
ypJ4+LTT3UQwwvqX5xEuJmFhmXGsghRGypbU7Zxw6QZRppBRqz8xMS+y82mMZRQp
ezP+BiTvnoWXzDEP1233oYuELVgOVnHsj+rC017Ykfd7fw==
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDUO4r02gV+JKt4DytWpg8CM2JSoV6UABg+CclH3YmvLUn908De
J41wJrNuNV4pQWN3r4LfRdEA+FqOwAkWIl2eaCu1Ugs2ulkML+8hIpCgCahP2E/9
wyoWby+ZQb/4Bc7vqrre98dSZpsFUM1OUzwUmlub0zbpwUysfjkbzjSADQIDAQAB
AoGBAIvxTykw8dpBt8cMyZjzGoZq93Rg74pLnbCap1x52iXmiRmUHWLfVcYT3tDW
4+X0NfBfjL5IvQ4UtTHXsqYjtvJfXWazYYa4INv5wKDBCd5a7s1YQ8R7mnhlBbRd
nqZ6RpGuQbd3gTGZCkUdbHPSqdCPAjryH9mtWoQZIepcIcoJAkEA77gjO+MPID6v
K6lf8SuFXHDOpaNOAiMlxVnmyQYQoF0PRVSpKOQf83An7R0S/jN3C7eZ6fPbZcyK
SFVktHhYwwJBAOKlgndbSkVzkQCMcuErGZT1AxHNNHSaDo8X3C47UbP3nf60SkxI
boqmpuPvEPUB9iPQdiNZGDU04+FUhe5Vtu8CQHDQHXS/hIzOMy2/BfG/Y4F/bSCy
W7HRzKK1jlCoVAbEBL3B++HMieTMsV17Q0bx/WI8Q2jAZE3iFmm4Fi6APHUCQCMi
5Yb7cBg0QlaDb4vY0q51DXTFC0zIVVl5qXjBWXk8+hFygdIxqHF2RIkxlr9k/nOu
7aGtPkOBX5KfN+QrBaECQQCltPE9YjFoqPezfyvGZoWAKb8bWzo958U3uVBnCw2f
Fs8AQDgI/9gOUXxXno51xQSdCnJLQJ8lThRUa6M7/F1B
-----END RSA PRIVATE KEY-----
"""

tubid_high = "6cxxohyb5ysw6ftpwprbzffxrghbfopm"
certData_high = \
"""-----BEGIN CERTIFICATE-----
MIIBnjCCAQcCAgCEMA0GCSqGSIb3DQEBBAUAMBcxFTATBgNVBAMUDG5ld3BiX3Ro
aW5neTAeFw0wNjExMjYxODUxNDFaFw0wNzExMjYxODUxNDFaMBcxFTATBgNVBAMU
DG5ld3BiX3RoaW5neTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEArfrebvt3
8FE3kKoscY2J/8A4J6CUUUiM7/gl00UvGvvjfdaWbsj4w0o8W2tE0X8Zce3dScSl
D6qVXy6AEc4Flqs0q02w9uNzcdDY6LF3NiK0Lq+JP4OjJeImUBe8wUU0RQxqf/oA
GhgHEZhTp6aAdxBXZFOVDloiW6iqrKH/thcCAwEAATANBgkqhkiG9w0BAQQFAAOB
gQBXi+edp3iz07wxcRztvXtTAjY/9gUwlfa6qSTg/cGqbF0OPa+sISBOFRnnC8qM
ENexlkpiiD4Oyj+UtO5g2CMz0E62cTJTqz6PfexnmKIGwYjq5wZ2tzOrB9AmAzLv
TQQ9CdcKBXLd2GCToh8hBvjyyFwj+yTSbq+VKLMFkBY8Rg==
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIICXgIBAAKBgQCt+t5u+3fwUTeQqixxjYn/wDgnoJRRSIzv+CXTRS8a++N91pZu
yPjDSjxba0TRfxlx7d1JxKUPqpVfLoARzgWWqzSrTbD243Nx0NjosXc2IrQur4k/
g6Ml4iZQF7zBRTRFDGp/+gAaGAcRmFOnpoB3EFdkU5UOWiJbqKqsof+2FwIDAQAB
AoGBAKrU3Vp+Y2u+Y+ARqKgrQai1tq36eAhEQ9dRgtqrYTCOyvcCIR5RCirAFvnx
H1bSBUsgNBw+EZGLfzZBs5FICaUjBOQYBYzfxux6+jlGvdl7idfHs7zogyEYBqye
0VkwzZ0mVXM2ujOD/z/ANkdEn2fGj/VwAYDlfvlyNZMckHp5AkEA5sc1VG3snWmG
lz4967MMzJ7XNpZcTvLEspjpH7hFbnXUHIQ4wPYOP7dhnVvKX1FiOQ8+zXVYDDGB
SK1ABzpc+wJBAMD+imwAhHNBbOb3cPYzOz6XRZaetvep3GfE2wKr1HXP8wchNXWj
Ijq6fJinwPlDugHaeNnfb+Dydd+YEiDTSJUCQDGCk2Jlotmyhfl0lPw4EYrkmO9R
GsSlOKXIQFtZwSuNg9AKXdKn9y6cPQjxZF1GrHfpWWPixNz40e+xm4bxcnkCQQCs
+zkspqYQ/CJVPpHkSnUem83GvAl5IKmp5Nr8oPD0i+fjixN0ljyW8RG+bhXcFaVC
BgTuG4QW1ptqRs5w14+lAkEAuAisTPUDsoUczywyoBbcFo3SVpFPNeumEXrj4MD/
uP+TxgBi/hNYaR18mTbKD4mzVSjqyEeRC/emV3xUpUrdqg==
-----END RSA PRIVATE KEY-----
"""
465
466
class BaseMixin(ShouldFailMixin):
    """Test mixin that starts real Tubs / TCP servers and stops them again.

    Anything appended to self.connections, self.servers or self.services
    is cleaned up by tearDown().
    """

    def setUp(self):
        self.connections = []
        self.servers = []
        self.services = []

    def tearDown(self):
        """Drop connections, stop listeners and services, drain the queue.

        Returns a Deferred that fires after every stopListening/stopService
        result has arrived and the eventual-send queue has been flushed.
        """
        for c in self.connections:
            if c.transport:
                c.transport.loseConnection()
        dl = [defer.maybeDeferred(s.stopListening) for s in self.servers]
        dl.extend(defer.maybeDeferred(s.stopService) for s in self.services)
        d = defer.DeferredList(dl)
        d.addCallback(flushEventualQueue)
        return d

    def stall(self, res, timeout):
        """Return a Deferred that fires with *res* after *timeout* seconds."""
        d = defer.Deferred()
        reactor.callLater(timeout, d.callback, res)
        return d

    def insert_turns(self, res, count):
        """Return a Deferred firing with *res* after *count* eventual turns.

        At least one turn is always inserted, even when count < 1.
        """
        d = eventual.fireEventually(res)
        for _ in range(count-1):
            d.addCallback(eventual.fireEventually)
        return d

    def makeServer(self, options=None, listenerOptions=None):
        """Start a Tub on a fresh 127.0.0.1 port with a registered Target.

        Stores the Tub on self.tub and the Target on self.target; returns
        (registerReference result, portnum).
        """
        # Fresh dicts per call: a shared mutable default would let one
        # test's options leak into the next if a Tub ever mutates them.
        if options is None:
            options = {}
        if listenerOptions is None:
            listenerOptions = {}
        self.tub = tub = Tub(_test_options=options)
        tub.startService()
        self.services.append(tub)
        portnum = allocate_tcp_port()
        tub.listenOn("tcp:%d:interface=127.0.0.1" % portnum,
                     _test_options=listenerOptions)
        tub.setLocation("127.0.0.1:%d" % portnum)
        self.target = Target()
        return tub.registerReference(self.target), portnum

    def makeSpecificServer(self, certData,
                           negotiationClass=negotiate.Negotiation):
        """Like createSpecificServer(), but also stores self.tub/self.target.

        Returns (registerReference result, portnum).
        """
        tub, target, ref, portnum = self.createSpecificServer(
            certData, negotiationClass)
        self.tub = tub
        self.target = target
        return ref, portnum

    def createSpecificServer(self, certData,
                             negotiationClass=negotiate.Negotiation):
        """Start a Tub with the given certificate and a registered Target.

        Returns (tub, target, registerReference result, portnum).
        """
        tub = Tub(certData=certData)
        tub.negotiationClass = negotiationClass
        tub.startService()
        self.services.append(tub)
        portnum = allocate_tcp_port()
        tub.listenOn("tcp:%d:interface=127.0.0.1" % portnum)
        tub.setLocation("127.0.0.1:%d" % portnum)
        target = Target()
        return tub, target, tub.registerReference(target), portnum

    def makeNullServer(self):
        """Start a TCP server that discards everything; return its port."""
        f = protocol.Factory()
        f.protocol = protocol.Protocol # discards everything
        s = internet.TCPServer(0, f)
        s.startService()
        self.services.append(s)
        portnum = s._port.getHost().port
        return portnum

    def makeHTTPServer(self):
        """Start a twisted.web server serving "hello\\n" at /; return its port.

        Raises SkipTest if twisted.web is unavailable.
        """
        try:
            from twisted.web import server, resource, static
        except ImportError:
            raise unittest.SkipTest('this test needs twisted.web')
        root = resource.Resource()
        # twisted.web child paths and response bodies are bytes (str on
        # Python 2, where b"" == "").
        root.putChild(b"", static.Data(b"hello\n", "text/plain"))
        s = internet.TCPServer(0, server.Site(root))
        s.startService()
        self.services.append(s)
        portnum = s._port.getHost().port
        return portnum

    def connectClient(self, portnum):
        """Start a client Tub and ask for pb://127.0.0.1:<portnum>/hello."""
        tub = Tub()
        tub.startService()
        self.services.append(tub)
        d = tub.getReference("pb://127.0.0.1:%d/hello" % portnum)
        return d
561
class MakeTubsMixin:
    def makeTubs(self, numTubs, mangleLocation=None, start=True):
        """Create numTubs Tubs, each listening on its own 127.0.0.1 port.

        Resets self.services (the Tubs) and self.tub_ports (their ports).
        If start is true, each Tub is started before it listens. When
        mangleLocation is given, it is called with the port number and its
        result becomes the Tub's location instead of "tcp:127.0.0.1:<port>".
        Returns self.services.
        """
        self.services = []
        self.tub_ports = []
        for _ in range(numTubs):
            tub = Tub()
            if start:
                tub.startService()
            self.services.append(tub)
            port = allocate_tcp_port()
            self.tub_ports.append(port)
            tub.listenOn("tcp:%d:interface=127.0.0.1" % port)
            if mangleLocation:
                location = mangleLocation(port)
            else:
                location = "tcp:127.0.0.1:%d" % port
            tub.setLocation(location)
        return self.services
579