1
2
3# This module is responsible for the per-connection Broker object
4
5from __future__ import print_function
6import six
7import types, time
8from itertools import count
9
10from zope.interface import implementer
11from twisted.python import failure
12from twisted.internet import defer, error
13from twisted.internet import interfaces as twinterfaces
14from twisted.internet.protocol import connectionDone
15
16from foolscap import banana, tokens, ipb, vocab
17from foolscap import call, slicer, referenceable, copyable, remoteinterface
18from foolscap.constraint import Any
19from foolscap.tokens import Violation, BananaError
20from foolscap.ipb import DeadReferenceError, IBroker
21from foolscap.slicers.root import RootSlicer, RootUnslicer, ScopedRootSlicer
22from foolscap.eventual import eventually
23from foolscap.logging import log
24from functools import reduce
25
# Failure types that mean "the transport went away". The Broker maps these to
# DeadReferenceError for outstanding requests, and ignores them when a decref
# can no longer be acknowledged. pyOpenSSL is optional; when it is installed,
# its errors are treated as lost connections too.
LOST_CONNECTION_ERRORS = [error.ConnectionLost, error.ConnectionDone]
try:
    from OpenSSL import SSL
except ImportError:
    pass
else:
    LOST_CONNECTION_ERRORS.append(SSL.Error)

# Opentypes accepted at the top level of a PB connection.
PBTopRegistry = {
    ("call",): call.CallUnslicer,
    ("answer",): call.AnswerUnslicer,
    ("error",): call.ErrorUnslicer,
}

# Opentypes accepted at the second level and below (in addition to the
# generic slicer.UnslicerRegistry). ('copyable', classname) is handled
# inline, through the CopyableRegistry.
PBOpenRegistry = {
    ("arguments",): call.ArgumentUnslicer,
    ("my-reference",): referenceable.ReferenceUnslicer,
    ("your-reference",): referenceable.YourReferenceUnslicer,
    ("their-reference",): referenceable.TheirReferenceUnslicer,
}
46
class PBRootUnslicer(RootUnslicer):
    """Root unslicer for a live PB connection.

    Only OPEN tokens are allowed at the top level. The index tokens of each
    OPEN sequence are size-checked, and every child unslicer is handed a
    reference to our broker. Completed inbound calls (InboundDelivery
    tokens) are passed to the broker's scheduleCall().
    """
    # topRegistries defines what objects are allowed at the top-level
    topRegistries = [PBTopRegistry]
    # openRegistries defines what objects are allowed at the second level and
    # below
    openRegistries = [slicer.UnslicerRegistry, PBOpenRegistry]
    logViolations = False

    def checkToken(self, typebyte, size):
        # Anything other than OPEN at the top level is a protocol error.
        if typebyte != tokens.OPEN:
            raise BananaError("top-level must be OPEN")

    def openerCheckToken(self, typebyte, size, opentype):
        """Validate one index token of an OPEN sequence.

        @param typebyte: the type byte of the incoming token
        @param size: the declared length of the token (for STRING tokens)
        @param opentype: tuple of the index tokens received so far
        @raise Violation: if the token is neither STRING nor VOCAB, or if a
            STRING token is longer than allowed
        """
        if typebyte == tokens.VOCAB:
            return
        if typebyte != tokens.STRING:
            # TODO: hack for testing. The original intent was a BananaError
            # here; a Violation is raised instead.
            raise Violation("index token 0x%02x not STRING or VOCAB" % \
                              six.byte2int(typebyte))
        if len(opentype) == 0:
            if size > self.maxIndexLength:
                why = "first opentype STRING token is too long, %d>%d" % \
                      (size, self.maxIndexLength)
                raise Violation(why)
        if opentype == ("copyable",):
            # TODO: this is silly, of course (should pre-compute maxlen).
            # The initializer 0 keeps this from raising TypeError when the
            # CopyableRegistry is empty; any non-empty classname is then
            # rejected as a Violation.
            maxlen = reduce(max,
                            [len(cname)
                             for cname in copyable.CopyableRegistry.keys()],
                            0)
            if size > maxlen:
                why = "copyable-classname token is too long, %d>%d" % \
                      (size, maxlen)
                raise Violation(why)

    def open(self, opentype):
        # used for lower-level objects, delegated up from childunslicer.open
        child = RootUnslicer.open(self, opentype)
        if child:
            child.broker = self.broker
        return child

    def doOpen(self, opentype):
        # Top-level opens: give the new child a reference to our broker.
        child = RootUnslicer.doOpen(self, opentype)
        if child:
            child.broker = self.broker
        return child

    def reportViolation(self, f):
        if self.logViolations:
            print("hey, something failed:", f)
        return None # absorb the failure

    def receiveChild(self, token, ready_deferred):
        # Only inbound calls are scheduled; other completed top-level
        # objects are ignored here.
        if isinstance(token, call.InboundDelivery):
            self.broker.scheduleCall(token, ready_deferred)
107
108
class PBRootSlicer(RootSlicer):
    """Root slicer for a live PB connection.

    Bound methods and plain functions are both serialized with
    referenceable.CallableSlicer.
    """
    slicerTable = dict.fromkeys((types.MethodType, types.FunctionType),
                                referenceable.CallableSlicer)

    def registerRefID(self, refid, obj):
        """Must never be called.

        References are never Broker-scoped: they're always scoped more
        narrowly, by the CallSlicer or the AnswerSlicer.
        """
        assert False
117
118
class RIBroker(remoteinterface.RemoteInterface):
    # The remote interface spoken between two Brokers. Broker declares that
    # it implements RIBroker, and the far end's implementation is reached
    # through a RemoteReference created with CLID 0 and interface name
    # "RIBroker".
    # NOTE(review): in a RemoteInterface, the keyword "defaults" below
    # (bytes, int) appear to act as argument type constraints, and the
    # return expression as the result constraint, rather than as ordinary
    # Python defaults/returns. Confirm against foolscap.remoteinterface
    # before changing them.
    def getReferenceByName(name=bytes):
        """If I have published an object by that name, return a reference to
        it."""
        # 'name' travels as bytes so that peers running Python 2 can read it.
        # return Remote(interface=any)
        return Any()
    def decref(clid=int, count=int):
        """Release some references to my-reference 'clid'. I will return an
        ack when the operation has completed."""
        return None
    def decgift(giftID=int, count=int):
        """Release some reference to a their-reference 'giftID' that was
        sent earlier."""
        return None
133
134
@implementer(RIBroker, IBroker)
class Broker(banana.Banana, referenceable.Referenceable):
    """I manage a connection to a remote Broker.

    A Broker is the per-connection object. It sends outbound calls and
    answers, dispatches inbound calls to local objects, and tracks the
    references and gifts exchanged over this connection.

    @ivar tub: the L{Tub} which contains us
    @ivar myReferenceByPUID: maps a local processUniqueID to the
        ReferenceableTracker for an object we have sent to the far end
    @ivar myReferenceByCLID: maps a connection-local ID to those same trackers
    @ivar yourReferenceByCLID: maps your CLID to a remote reference tracker
    @ivar yourReferenceByURL: maps a global URL to a remote reference tracker
    """

    slicerClass = PBRootSlicer
    unslicerClass = PBRootUnslicer
    unsafeTracebacks = True
    requireSchema = False
    disconnected = False
    factory = None
    tub = None
    remote_broker = None
    startingTLS = False
    startedTLS = False
    use_remote_broker = True

    def __init__(self, remote_tubref, params=None,
                 keepaliveTimeout=None, disconnectTimeout=None,
                 connectionInfo=None):
        """
        @param remote_tubref: the far end's tub reference, or None. It is
            used here only for log messages and DeadReferenceError details.
        @param params: negotiation parameters (dict). Keys read here:
            'banana-decision-version', 'initial-vocab-table-index',
            'current-slave-IR', 'current-seqnum'. Defaults to an empty dict.
        @param connectionInfo: returned verbatim by getConnectionInfo()
        """
        # avoid a shared mutable default argument
        if params is None:
            params = {}
        banana.Banana.__init__(self, params)
        self._expose_remote_exception_types = True
        self.remote_tubref = remote_tubref
        self.keepaliveTimeout = keepaliveTimeout
        self.disconnectTimeout = disconnectTimeout
        self._banana_decision_version = params.get("banana-decision-version") # native str
        vocab_table_index = params.get('initial-vocab-table-index') # native str
        if vocab_table_index:
            table = vocab.INITIAL_VOCAB_TABLES[vocab_table_index]
            self.populateVocabTable(table)
        self.initBroker()
        self.current_slave_IR = params.get('current-slave-IR')
        self.current_seqnum = params.get('current-seqnum')
        self.creation_timestamp = time.time()
        self._connectionInfo = connectionInfo

    def initBroker(self):
        """Create the empty per-connection bookkeeping tables."""

        # tracking Referenceables
        # sending side uses these
        self.nextCLID = count(1) # 0 is for the broker
        self.myReferenceByPUID = {} # maps ref.processUniqueID to a tracker
        self.myReferenceByCLID = {} # maps CLID to a tracker
        # receiving side uses these
        self.yourReferenceByCLID = {}
        self.yourReferenceByURL = {}

        # tracking Gifts
        self.nextGiftID = count(1)
        self.myGifts = {} # maps (broker,clid) to (rref, giftID, count)
        self.myGiftsByGiftID = {} # maps giftID to (broker,clid)

        # remote calls
        # sending side uses these
        self.nextReqID = count(1) # 0 means "we don't want a response"
        self.waitingForAnswers = {} # we wait for the other side to answer
        self.disconnectWatchers = []

        # Callables waiting to hear about connectionLost. Set to None once
        # they have been notified.
        self._connectionLostWatchers = []

        # receiving side uses these
        self.inboundDeliveryQueue = []
        self._waiting_for_call_to_be_ready = False
        self.activeLocalCalls = {} # the other side wants an answer from us

    def setTub(self, tub):
        """Attach us to 'tub' and inherit its traceback/debug settings."""
        assert ipb.ITub.providedBy(tub)
        self.tub = tub
        self.unsafeTracebacks = tub.unsafeTracebacks
        self._expose_remote_exception_types = tub._expose_remote_exception_types
        if tub.debugBanana:
            self.debugSend = True
            self.debugReceive = True

    def connectionMade(self):
        banana.Banana.connectionMade(self)
        self.rootSlicer.broker = self
        self.rootUnslicer.broker = self
        if self.use_remote_broker:
            self._create_remote_broker()

    def _create_remote_broker(self):
        # create the remote_broker object. We don't use the usual
        # reference-counting mechanism here, because this is a synthetic
        # object that lives forever.
        tracker = referenceable.RemoteReferenceTracker(self, 0, None,
                                                       "RIBroker")
        self.remote_broker = referenceable.RemoteReference(tracker)

    # connectionTimedOut is called in response to the Banana layer detecting
    # the lack of connection activity

    def connectionTimedOut(self):
        err = error.ConnectionLost("banana timeout: connection dropped")
        why = failure.Failure(err)
        self.shutdown(why)

    def shutdown(self, why, fireDisconnectWatchers=True):
        """Stop using this connection. If fireDisconnectWatchers is False,
        all disconnect watchers are removed before shutdown, so they will not
        be called (this is appropriate when the Broker is shutting down
        because the whole Tub is being shut down). We terminate the
        connection quickly, rather than waiting for the transmit queue to
        drain.
        """
        assert isinstance(why, failure.Failure)
        if not fireDisconnectWatchers:
            self.disconnectWatchers = []
        self.finish(why)
        # loseConnection eventually provokes connectionLost()
        self.transport.loseConnection()

    def connectionLost(self, why):
        tubid = "?"
        if self.remote_tubref:
            tubid = self.remote_tubref.getShortTubID()
        log.msg("connection to %s lost" % tubid, facility="foolscap.connection")
        banana.Banana.connectionLost(self, why)
        self.finish(why)
        self._notifyConnectionLostWatchers()

    def _notifyConnectionLostWatchers(self):
        """
        Call all functions waiting to learn about the loss of the connection of
        this broker. Safe to call more than once; later calls do nothing.
        """
        watchers = self._connectionLostWatchers
        if watchers is None:
            # already notified (connectionLost was delivered twice)
            return
        self._connectionLostWatchers = None

        for w in watchers:
            eventually(w)

    def finish(self, why):
        """Mark us disconnected (once), fail outstanding requests, clear
        the reference/gift tables, fire disconnect watchers and tell the
        tub."""
        if self.disconnected:
            return
        assert isinstance(why, failure.Failure), why
        self.disconnected = True
        self.remote_broker = None
        self.abandonAllRequests(why)
        # TODO: why reset all the tables to something useable? There may be
        # outstanding RemoteReferences that point to us, but I don't see why
        # that requires all these empty dictionaries.
        self.myReferenceByPUID = {}
        self.myReferenceByCLID = {}
        self.yourReferenceByCLID = {}
        self.yourReferenceByURL = {}
        self.myGifts = {}
        self.myGiftsByGiftID = {}
        for (cb,args,kwargs) in self.disconnectWatchers:
            eventually(cb, *args, **kwargs)
        self.disconnectWatchers = []
        if self.tub:
            # TODO: remove the conditional. It is only here to accommodate
            # some tests: test_pb.TestCall.testDisconnect[123]
            self.tub.brokerDetached(self, why)

    def _notifyOnConnectionLost(self, callback):
        """
        Arrange to have C{callback} called when this broker loses its
        connection. If that has already happened, C{callback} is called
        (eventually) right away, mirroring notifyOnDisconnect.
        """
        if self._connectionLostWatchers is None:
            eventually(callback)
        else:
            self._connectionLostWatchers.append(callback)

    def notifyOnDisconnect(self, callback, *args, **kwargs):
        """Call callback(*args, **kwargs) (eventually) once we disconnect,
        or right away if we already have. Returns a marker for
        dontNotifyOnDisconnect."""
        marker = (callback, args, kwargs)
        if self.disconnected:
            eventually(callback, *args, **kwargs)
        else:
            self.disconnectWatchers.append(marker)
        return marker

    def dontNotifyOnDisconnect(self, marker):
        if self.disconnected:
            return
        # be tolerant of attempts to unregister a callback that has already
        # fired. I think it is hard to write safe code without this
        # tolerance.

        # TODO: on the other hand, I'm not sure this is the best policy,
        # since you lose the feedback that tells you about
        # unregistering-the-wrong-thing bugs. We need to look at the way that
        # register/unregister gets used and see if there is a way to retain
        # the typechecking that results from insisting that you can only
        # remove something that was still in the list.
        if marker in self.disconnectWatchers:
            self.disconnectWatchers.remove(marker)

    def getConnectionInfo(self):
        return self._connectionInfo

    # methods to send my Referenceables to the other side

    def _getOrCreateMyTracker(self, puid, obj, negative):
        # Shared by getTrackerForMyReference and getTrackerForMyCall: reuse
        # the tracker already registered for puid, otherwise allocate the
        # next CLID (negated when 'negative' is set) and index the new
        # tracker by both PUID and CLID.
        tracker = self.myReferenceByPUID.get(puid)
        if not tracker:
            # need to add one
            clid = next(self.nextCLID)
            if negative:
                clid = -clid
            tracker = referenceable.ReferenceableTracker(self.tub,
                                                         obj, puid, clid)
            self.myReferenceByPUID[puid] = tracker
            self.myReferenceByCLID[clid] = tracker
        return tracker

    def getTrackerForMyReference(self, puid, obj):
        return self._getOrCreateMyTracker(puid, obj, False)

    def getTrackerForMyCall(self, puid, obj):
        # just like getTrackerForMyReference, but with a negative clid
        return self._getOrCreateMyTracker(puid, obj, True)

    # methods to handle inbound 'my-reference' sequences

    def getTrackerForYourReference(self, clid, interfaceName=None, url=None):
        """The far end holds a Referenceable and has just sent us a reference
        to it (expressed as a small integer). If this is a new reference,
        they will give us an interface name too, and possibly a global URL
        for it. Obtain a RemoteReference object (creating it if necessary) to
        give to the local recipient.

        The sender remembers that we hold a reference to their object. When
        our RemoteReference goes away, we send a decref message to them, so
        they can possibly free their object. """

        assert type(interfaceName) is str or interfaceName is None
        if url is not None:
            assert type(url) is str
        tracker = self.yourReferenceByCLID.get(clid)
        if not tracker:
            # TODO: translate interfaceNames to RemoteInterfaces
            # negative CLIDs denote references to callables
            if clid >= 0:
                trackerclass = referenceable.RemoteReferenceTracker
            else:
                trackerclass = referenceable.RemoteMethodReferenceTracker
            tracker = trackerclass(self, clid, url, interfaceName)
            self.yourReferenceByCLID[clid] = tracker
            if url:
                self.yourReferenceByURL[url] = tracker
        return tracker

    def freeYourReference(self, tracker, count):
        """Send a decref for 'count' references to tracker.clid, and drop
        our tracker once the ack arrives (or the connection is gone)."""
        # this is called when the RemoteReference is deleted
        if not self.remote_broker: # tests do not set this up
            self.freeYourReferenceTracker(None, tracker)
            return
        try:
            rb = self.remote_broker
            # TODO: do we want callRemoteOnly here? is there a way we can
            # avoid wanting to know when the decref has completed? Only if we
            # send the interface list and URL on every occurrence of the
            # my-reference sequence. Either A) we use callRemote("decref")
            # and wait until the ack to free the tracker, or B) we use
            # callRemoteOnly("decref") and free the tracker right away. In
            # case B, the far end has no way to know that we've just freed
            # the tracker and will therefore forget about everything they
            # told us (including the interface list), so they cannot
            # accurately do anything special on the "first" send of this
            # reference. Which means that if we do B, we must either send
            # that extra information on every my-reference sequence, or do
            # without it, or make it optional, or retrieve it separately, or
            # something.
            d = rb.callRemote("decref", clid=tracker.clid, count=count)
            # if the connection was lost before we can get an ack, we're
            # tearing this down anyway
            def _ignore_loss(f):
                f.trap(DeadReferenceError, *LOST_CONNECTION_ERRORS)
                return None
            d.addErrback(_ignore_loss)
            # once the ack comes back, or if we know we'll never get one,
            # release the tracker
            d.addCallback(self.freeYourReferenceTracker, tracker)
        except Exception:
            # best-effort: log and carry on (don't swallow KeyboardInterrupt
            # or SystemExit)
            f = failure.Failure()
            log.msg("failure during freeRemoteReference", facility="foolscap",
                    level=log.UNUSUAL, failure=f)

    def freeYourReferenceTracker(self, res, tracker):
        # Only forget the tracker if no new references arrived meanwhile.
        if tracker.received_count != 0:
            return
        if tracker.clid in self.yourReferenceByCLID:
            del self.yourReferenceByCLID[tracker.clid]
        if tracker.url and tracker.url in self.yourReferenceByURL:
            del self.yourReferenceByURL[tracker.url]


    # methods to handle inbound 'your-reference' sequences

    def getMyReferenceByCLID(self, clid):
        """clid is the integer connection-local ID of the Referenceable the
        other end is trying to invoke or point to. CLID 0 refers to this
        Broker itself; any other value must be one we allocated and sent to
        them earlier (a KeyError is raised otherwise).
        """

        assert isinstance(clid, six.integer_types)
        if clid == 0:
            return self
        return self.myReferenceByCLID[clid].obj

    def remote_decref(self, clid, count):
        # invoked when the other side sends us a decref message
        assert isinstance(clid, six.integer_types)
        assert clid != 0
        tracker = self.myReferenceByCLID.get(clid, None)
        if not tracker:
            return # already gone, probably because we're shutting down
        done = tracker.decref(count)
        if done:
            del self.myReferenceByPUID[tracker.puid]
            del self.myReferenceByCLID[clid]

    # methods to send RemoteReference 'gifts' to third-parties

    def makeGift(self, rref):
        """Record one more outstanding gift of 'rref' and return its giftID
        (the same ID is reused while earlier gifts are outstanding)."""
        broker, clid = rref.tracker.broker, rref.tracker.clid
        i = (broker, clid)
        old = self.myGifts.get(i)
        if old:
            rref, giftID, count = old
            self.myGifts[i] = (rref, giftID, count+1)
        else:
            giftID = next(self.nextGiftID)
            self.myGiftsByGiftID[giftID] = i
            self.myGifts[i] = (rref, giftID, 1)
        return giftID

    def remote_decgift(self, giftID, count):
        """Release 'count' outstanding gifts of 'giftID'. Unknown giftIDs are
        ignored, like remote_decref does for unknown CLIDs."""
        key = self.myGiftsByGiftID.get(giftID)
        if key is None:
            return # already gone, probably because we're shutting down
        rref, giftID, gift_count = self.myGifts[key]
        gift_count -= count
        if gift_count <= 0:
            # also drop the entry if the peer over-released it
            del self.myGiftsByGiftID[giftID]
            del self.myGifts[key]
        else:
            self.myGifts[key] = (rref, giftID, gift_count)

    # methods to deal with URLs

    def getYourReferenceByName(self, name):
        # remain compatible with remotes running py2
        name = six.ensure_binary(name)
        d = self.remote_broker.callRemote("getReferenceByName", name=name)
        return d

    def remote_getReferenceByName(self, name):
        return self.tub.getReferenceForName(six.ensure_str(name))

    # remote-method-invocation methods, calling side, invoked by
    # RemoteReference.callRemote and CallSlicer

    def newRequestID(self):
        if self.disconnected:
            raise DeadReferenceError("Calling Stale Broker")
        return next(self.nextReqID)

    def addRequest(self, req):
        req.broker = self
        self.waitingForAnswers[req.reqID] = req

    def removeRequest(self, req):
        del self.waitingForAnswers[req.reqID]

    def getRequest(self, reqID):
        # invoked by AnswerUnslicer and ErrorUnslicer
        try:
            return self.waitingForAnswers[reqID]
        except KeyError:
            raise Violation("non-existent reqID '%d'" % reqID)

    def abandonAllRequests(self, why):
        """Fail (eventually) every outstanding outbound request.

        Connection-lost failures are mapped to a fresh DeadReferenceError for
        each request, so application code only needs to check for one
        exception type; any other failure is delivered unchanged.
        """
        lost = why.check(*LOST_CONNECTION_ERRORS)
        tubid = None
        if lost and self.remote_tubref:
            tubid = self.remote_tubref.getShortTubID()
        for req in list(self.waitingForAnswers.values()):
            # Use a per-request local: rebinding 'why' here would make every
            # later request reuse the first request's DeadReferenceError.
            f = why
            if lost:
                # since we're creating a new exception object for each call,
                # let's add more information to it
                e = DeadReferenceError("Connection was lost", tubid, req)
                f = failure.Failure(e)
            eventually(req.fail, f)

    # target-side, invoked by CallUnslicer

    def getRemoteInterfaceByName(self, riname):
        # this lives in the broker because it ought to be per-connection
        return remoteinterface.RemoteInterfaceRegistry[riname]

    def getSchemaForMethod(self, rifaces, methodname):
        # this lives in the Broker so it can override the resolution order,
        # not that overlapping RemoteInterfaces should be allowed to happen
        # all that often
        for ri in rifaces:
            m = ri.get(methodname)
            if m:
                return m
        return None

    def scheduleCall(self, delivery, ready_deferred):
        """Queue an inbound call; calls are started one at a time, in order."""
        self.inboundDeliveryQueue.append( (delivery,ready_deferred) )
        eventually(self.doNextCall)

    def doNextCall(self):
        if self.disconnected:
            return
        if self._waiting_for_call_to_be_ready:
            return
        if not self.inboundDeliveryQueue:
            return
        delivery, ready_deferred = self.inboundDeliveryQueue.pop(0)
        self._waiting_for_call_to_be_ready = True
        if not ready_deferred:
            ready_deferred = defer.succeed(None)
        d = ready_deferred

        def _ready(res):
            self._waiting_for_call_to_be_ready = False
            eventually(self.doNextCall)
            return res
        d.addBoth(_ready)

        # at this point, the Deferred chain for this one delivery runs
        # independently of any other, and methods which take a long time to
        # complete will not hold up other methods. We must call _doCall and
        # let the remote_ method get control before we process any other
        # message, but the eventually() above ensures we'll have a chance to
        # do that before we give up control.

        d.addCallback(lambda res: self._doCall(delivery))
        d.addCallback(self._callFinished, delivery)
        d.addErrback(self.callFailed, delivery.reqID, delivery)
        d.addErrback(log.err)
        return None

    def _doCall(self, delivery):
        # our ordering rules require that the order in which each
        # remote_foo() method gets control is exactly the same as the order
        # in which the original caller invoked callRemote(). To ensure this,
        # _startCall() is not allowed to insert additional delays before it
        # runs doRemoteCall() on the target object.
        obj = delivery.obj
        args = delivery.allargs.args
        kwargs = delivery.allargs.kwargs
        # list() so this works whether args is a list or a tuple
        for i in list(args) + list(kwargs.values()):
            assert not isinstance(i, defer.Deferred)

        if delivery.methodSchema:
            # we asked about each argument on the way in, but ask again so
            # they can look for missing arguments. TODO: see if we can remove
            # the redundant per-argument checks.
            delivery.methodSchema.checkAllArgs(args, kwargs, True)

        # interesting case: if the method completes successfully, but
        # our schema prohibits us from sending the result (perhaps the
        # method returned an int but the schema insists upon a string).
        # TODO: move the return-value schema check into
        # Referenceable.doRemoteCall, so the exception's traceback will be
        # attached to the object that caused it
        if delivery.methodname is None:
            # a call on a callable reference (no method name)
            assert callable(obj)
            return obj(*args, **kwargs)
        else:
            obj = ipb.IRemotelyCallable(obj)
            return obj.doRemoteCall(delivery.methodname, args, kwargs)


    def _callFinished(self, res, delivery):
        """Check 'res' against the return schema and send the answer.
        reqID 0 means the caller does not want an answer."""
        reqID = delivery.reqID
        if reqID == 0:
            return
        methodSchema = delivery.methodSchema
        assert self.activeLocalCalls[reqID]
        methodName = None
        if methodSchema:
            methodName = methodSchema.name
            try:
                methodSchema.checkResults(res, False) # may raise Violation
            except Violation as v:
                v.prependLocation("in return value of %s.%s" %
                                  (delivery.obj, methodSchema.name))
                raise

        answer = call.AnswerSlicer(reqID, res, methodName)
        # once the answer has started transmitting, any exceptions must be
        # logged and dropped, and not turned into an Error to be sent.
        try:
            self.send(answer)
            # TODO: .send should return a Deferred that fires when the last
            # byte has been queued, and we should delete the local note then
        except Exception:
            f = failure.Failure()
            log.msg("Broker._callfinished unable to send",
                    facility="foolscap", level=log.UNUSUAL, failure=f)
        del self.activeLocalCalls[reqID]

    def callFailed(self, f, reqID, delivery=None):
        # this may be called either when an inbound schema is violated, or
        # when the method is run and raises an exception. If a Violation is
        # raised after we receive the reqID but before we've actually invoked
        # the method, we are called by CallUnslicer.reportViolation and don't
        # get a delivery= argument.
        if delivery:
            # the 'not self.tub' case is for unit tests
            if not self.tub or self.tub.logLocalFailures:
                delivery.logFailure(f)
        if reqID != 0:
            assert self.activeLocalCalls[reqID]
            self.send(call.ErrorSlicer(reqID, f))
            del self.activeLocalCalls[reqID]
665
class StorageBrokerRootSlicer(ScopedRootSlicer):
    """Root slicer for a StorageBroker.

    Each StorageBroker is a single serialization domain, which is why this
    inherits from ScopedRootSlicer. Bound methods and plain functions are
    serialized with referenceable.CallableSlicer.
    """
    slicerTable = dict.fromkeys((types.MethodType, types.FunctionType),
                                referenceable.CallableSlicer)
672
# The PB-specific opentypes a StorageBroker accepts, in addition to the
# generic slicer.UnslicerRegistry: only third-party references.
PBStorageOpenRegistry = {
    ("their-reference",): referenceable.TheirReferenceUnslicer,
}
676
class StorageBrokerRootUnslicer(PBRootUnslicer):
    """Root unslicer for a StorageBroker.

    It has all the behavior of PBRootUnslicer, plus the scopedness of a
    ScopedRootUnslicer (a per-instance reference table).
    TODO: find some way to refactor all of this, probably by making the
    scopedness a mixin.
    """

    openRegistries = [slicer.UnslicerRegistry, PBStorageOpenRegistry]
    # the same registries are accepted at the top level and below
    topRegistries = openRegistries

    def __init__(self, protocol):
        PBRootUnslicer.__init__(self, protocol)
        # counter -> object, for references within this serialization
        self.references = {}

    def setObject(self, counter, obj):
        self.references[counter] = obj

    def getObject(self, counter):
        # None when the counter has not been registered
        return self.references.get(counter)

    def receiveChild(self, obj, ready_deferred):
        # Hand every completed top-level object to the StorageBroker.
        self.protocol.receiveChild(obj, ready_deferred)

    def reportViolation(self, why):
        """Forward the violation to the StorageBroker instead of absorbing it.

        Unlike PBRootUnslicer, any error during deserialization is fatal. The
        StorageBroker will eventually fire the unserialization Deferred.
        """
        self.protocol.reportViolation(why)
705
class StorageBroker(Broker):
    """A Broker used to serialize data for storage, rather than for
    transmission over a specific connection.

    Call prepare() to get a Deferred; it fires with the deserialized
    top-level object, or with None after a violation (stored in
    self.violation).
    """
    slicerClass = StorageBrokerRootSlicer
    unslicerClass = StorageBrokerRootUnslicer
    object = None
    violation = None
    disconnectReason = None
    # no peer, so no synthetic remote_broker is created on connectionMade
    use_remote_broker = False

    def prepare(self):
        result = defer.Deferred()
        self.d = result
        return result

    def receiveChild(self, obj, ready_deferred):
        pending = self.d
        if not ready_deferred:
            pending.callback(obj)
        else:
            # wait for ready_deferred, then deliver obj (its failure, if
            # any, propagates instead)
            ready_deferred.addBoth(pending.callback)
            pending.addCallback(lambda _ignored: obj)
        del self.d

    def reportViolation(self, why):
        # remember the violation; the prepare() Deferred fires with None
        self.violation = why
        fire = self.d.callback
        eventually(fire, None)
        return None

    def reportReceiveError(self, f):
        self.disconnectReason = f
        f.raiseException()
736
737# this loopback stuff is based upon twisted.protocols.loopback, except that
738# we use it for real, not just for testing. The IConsumer stuff hasn't been
739# tested at all.
740
@implementer(twinterfaces.IAddress)
class LoopbackAddress(object):
    """Placeholder address returned by LoopbackTransport.getPeer/getHost."""
744
@implementer(twinterfaces.ITransport, twinterfaces.IConsumer)
class LoopbackTransport(object):
    """In-process transport: data written here is delivered (eventually) to
    the peer transport's protocol.

    We always create these in pairs, with .peer pointing at each other. The
    caller is expected to set .protocol and to call setPeer(); neither is
    done in __init__.
    """

    producer = None

    def __init__(self):
        self.connected = True

    def setPeer(self, peer):
        self.peer = peer

    def write(self, bytes):
        # delivered on a later turn, never synchronously
        eventually(self.peer.dataReceived, bytes)

    def writeSequence(self, iovec):
        # Transport data is bytes; joining with a str separator would raise
        # TypeError on Python 3.
        self.write(b''.join(iovec))

    def dataReceived(self, data):
        # data arriving from our peer; dropped once we are disconnected
        if self.connected:
            self.protocol.dataReceived(data)

    def loseConnection(self, _connDone=connectionDone):
        """Disconnect both ends (eventually). Repeated calls are no-ops."""
        if not self.connected:
            return
        self.connected = False
        eventually(self.peer.connectionLost, _connDone)
        eventually(self.protocol.connectionLost, _connDone)

    def connectionLost(self, reason):
        # invoked by our peer's loseConnection()
        if not self.connected:
            return
        self.connected = False
        self.protocol.connectionLost(reason)

    def getPeer(self):
        return LoopbackAddress()

    def getHost(self):
        return LoopbackAddress()

    # IConsumer
    def registerProducer(self, producer, streaming):
        assert self.producer is None
        self.producer = producer
        self.streamingProducer = streaming
        self._pollProducer()

    def unregisterProducer(self):
        assert self.producer is not None
        self.producer = None

    def _pollProducer(self):
        # non-streaming (pull) producers must be asked for more data
        if self.producer is not None and not self.streamingProducer:
            self.producer.resumeProducing()
796