1 2 3# This module is responsible for the per-connection Broker object 4 5from __future__ import print_function 6import six 7import types, time 8from itertools import count 9 10from zope.interface import implementer 11from twisted.python import failure 12from twisted.internet import defer, error 13from twisted.internet import interfaces as twinterfaces 14from twisted.internet.protocol import connectionDone 15 16from foolscap import banana, tokens, ipb, vocab 17from foolscap import call, slicer, referenceable, copyable, remoteinterface 18from foolscap.constraint import Any 19from foolscap.tokens import Violation, BananaError 20from foolscap.ipb import DeadReferenceError, IBroker 21from foolscap.slicers.root import RootSlicer, RootUnslicer, ScopedRootSlicer 22from foolscap.eventual import eventually 23from foolscap.logging import log 24from functools import reduce 25 26LOST_CONNECTION_ERRORS = [error.ConnectionLost, error.ConnectionDone] 27try: 28 from OpenSSL import SSL 29 LOST_CONNECTION_ERRORS.append(SSL.Error) 30except ImportError: 31 pass 32 33PBTopRegistry = { 34 ("call",): call.CallUnslicer, 35 ("answer",): call.AnswerUnslicer, 36 ("error",): call.ErrorUnslicer, 37 } 38 39PBOpenRegistry = { 40 ('arguments',): call.ArgumentUnslicer, 41 ('my-reference',): referenceable.ReferenceUnslicer, 42 ('your-reference',): referenceable.YourReferenceUnslicer, 43 ('their-reference',): referenceable.TheirReferenceUnslicer, 44 # ('copyable', classname) is handled inline, through the CopyableRegistry 45 } 46 47class PBRootUnslicer(RootUnslicer): 48 # topRegistries defines what objects are allowed at the top-level 49 topRegistries = [PBTopRegistry] 50 # openRegistries defines what objects are allowed at the second level and 51 # below 52 openRegistries = [slicer.UnslicerRegistry, PBOpenRegistry] 53 logViolations = False 54 55 def checkToken(self, typebyte, size): 56 if typebyte != tokens.OPEN: 57 raise BananaError("top-level must be OPEN") 58 59 def openerCheckToken(self, 
typebyte, size, opentype): 60 if typebyte == tokens.STRING: 61 if len(opentype) == 0: 62 if size > self.maxIndexLength: 63 why = "first opentype STRING token is too long, %d>%d" % \ 64 (size, self.maxIndexLength) 65 raise Violation(why) 66 if opentype == ("copyable",): 67 # TODO: this is silly, of course (should pre-compute maxlen) 68 maxlen = reduce(max, 69 [len(cname) \ 70 for cname in list(copyable.CopyableRegistry.keys())] 71 ) 72 if size > maxlen: 73 why = "copyable-classname token is too long, %d>%d" % \ 74 (size, maxlen) 75 raise Violation(why) 76 elif typebyte == tokens.VOCAB: 77 return 78 else: 79 # TODO: hack for testing 80 raise Violation("index token 0x%02x not STRING or VOCAB" % \ 81 six.byte2int(typebyte)) 82 raise BananaError("index token 0x%02x not STRING or VOCAB" % \ 83 six.byte2int(typebyte)) 84 85 def open(self, opentype): 86 # used for lower-level objects, delegated up from childunslicer.open 87 child = RootUnslicer.open(self, opentype) 88 if child: 89 child.broker = self.broker 90 return child 91 92 def doOpen(self, opentype): 93 child = RootUnslicer.doOpen(self, opentype) 94 if child: 95 child.broker = self.broker 96 return child 97 98 def reportViolation(self, f): 99 if self.logViolations: 100 print("hey, something failed:", f) 101 return None # absorb the failure 102 103 def receiveChild(self, token, ready_deferred): 104 if isinstance(token, call.InboundDelivery): 105 self.broker.scheduleCall(token, ready_deferred) 106 107 108 109class PBRootSlicer(RootSlicer): 110 slicerTable = {types.MethodType: referenceable.CallableSlicer, 111 types.FunctionType: referenceable.CallableSlicer, 112 } 113 def registerRefID(self, refid, obj): 114 # references are never Broker-scoped: they're always scoped more 115 # narrowly, by the CallSlicer or the AnswerSlicer. 
116 assert 0 117 118 119class RIBroker(remoteinterface.RemoteInterface): 120 def getReferenceByName(name=bytes): 121 """If I have published an object by that name, return a reference to 122 it.""" 123 # return Remote(interface=any) 124 return Any() 125 def decref(clid=int, count=int): 126 """Release some references to my-reference 'clid'. I will return an 127 ack when the operation has completed.""" 128 return None 129 def decgift(giftID=int, count=int): 130 """Release some reference to a their-reference 'giftID' that was 131 sent earlier.""" 132 return None 133 134 135@implementer(RIBroker, IBroker) 136class Broker(banana.Banana, referenceable.Referenceable): 137 """I manage a connection to a remote Broker. 138 139 @ivar tub: the L{Tub} which contains us 140 @ivar yourReferenceByCLID: maps your CLID to a RemoteReferenceData 141 #@ivar yourReferenceByName: maps a per-Tub name to a RemoteReferenceData 142 @ivar yourReferenceByURL: maps a global URL to a RemoteReferenceData 143 144 """ 145 146 slicerClass = PBRootSlicer 147 unslicerClass = PBRootUnslicer 148 unsafeTracebacks = True 149 requireSchema = False 150 disconnected = False 151 factory = None 152 tub = None 153 remote_broker = None 154 startingTLS = False 155 startedTLS = False 156 use_remote_broker = True 157 158 def __init__(self, remote_tubref, params={}, 159 keepaliveTimeout=None, disconnectTimeout=None, 160 connectionInfo=None): 161 banana.Banana.__init__(self, params) 162 self._expose_remote_exception_types = True 163 self.remote_tubref = remote_tubref 164 self.keepaliveTimeout = keepaliveTimeout 165 self.disconnectTimeout = disconnectTimeout 166 self._banana_decision_version = params.get("banana-decision-version") # native str 167 vocab_table_index = params.get('initial-vocab-table-index') # native str 168 if vocab_table_index: 169 table = vocab.INITIAL_VOCAB_TABLES[vocab_table_index] 170 self.populateVocabTable(table) 171 self.initBroker() 172 self.current_slave_IR = params.get('current-slave-IR') 173 
self.current_seqnum = params.get('current-seqnum') 174 self.creation_timestamp = time.time() 175 self._connectionInfo = connectionInfo 176 177 def initBroker(self): 178 179 # tracking Referenceables 180 # sending side uses these 181 self.nextCLID = count(1) # 0 is for the broker 182 self.myReferenceByPUID = {} # maps ref.processUniqueID to a tracker 183 self.myReferenceByCLID = {} # maps CLID to a tracker 184 # receiving side uses these 185 self.yourReferenceByCLID = {} 186 self.yourReferenceByURL = {} 187 188 # tracking Gifts 189 self.nextGiftID = count(1) 190 self.myGifts = {} # maps (broker,clid) to (rref, giftID, count) 191 self.myGiftsByGiftID = {} # maps giftID to (broker,clid) 192 193 # remote calls 194 # sending side uses these 195 self.nextReqID = count(1) # 0 means "we don't want a response" 196 self.waitingForAnswers = {} # we wait for the other side to answer 197 self.disconnectWatchers = [] 198 199 # Callables waiting to hear about connectionLost. 200 self._connectionLostWatchers = [] 201 202 # receiving side uses these 203 self.inboundDeliveryQueue = [] 204 self._waiting_for_call_to_be_ready = False 205 self.activeLocalCalls = {} # the other side wants an answer from us 206 207 def setTub(self, tub): 208 assert ipb.ITub.providedBy(tub) 209 self.tub = tub 210 self.unsafeTracebacks = tub.unsafeTracebacks 211 self._expose_remote_exception_types = tub._expose_remote_exception_types 212 if tub.debugBanana: 213 self.debugSend = True 214 self.debugReceive = True 215 216 def connectionMade(self): 217 banana.Banana.connectionMade(self) 218 self.rootSlicer.broker = self 219 self.rootUnslicer.broker = self 220 if self.use_remote_broker: 221 self._create_remote_broker() 222 223 def _create_remote_broker(self): 224 # create the remote_broker object. We don't use the usual 225 # reference-counting mechanism here, because this is a synthetic 226 # object that lives forever. 
227 tracker = referenceable.RemoteReferenceTracker(self, 0, None, 228 "RIBroker") 229 self.remote_broker = referenceable.RemoteReference(tracker) 230 231 # connectionTimedOut is called in response to the Banana layer detecting 232 # the lack of connection activity 233 234 def connectionTimedOut(self): 235 err = error.ConnectionLost("banana timeout: connection dropped") 236 why = failure.Failure(err) 237 self.shutdown(why) 238 239 def shutdown(self, why, fireDisconnectWatchers=True): 240 """Stop using this connection. If fireDisconnectWatchers is False, 241 all disconnect watchers are removed before shutdown, so they will not 242 be called (this is appropriate when the Broker is shutting down 243 because the whole Tub is being shut down). We terminate the 244 connection quickly, rather than waiting for the transmit queue to 245 drain. 246 """ 247 assert isinstance(why, failure.Failure) 248 if not fireDisconnectWatchers: 249 self.disconnectWatchers = [] 250 self.finish(why) 251 # loseConnection eventually provokes connectionLost() 252 self.transport.loseConnection() 253 254 def connectionLost(self, why): 255 tubid = "?" 256 if self.remote_tubref: 257 tubid = self.remote_tubref.getShortTubID() 258 log.msg("connection to %s lost" % tubid, facility="foolscap.connection") 259 banana.Banana.connectionLost(self, why) 260 self.finish(why) 261 self._notifyConnectionLostWatchers() 262 263 def _notifyConnectionLostWatchers(self): 264 """ 265 Call all functions waiting to learn about the loss of the connection of 266 this broker. 267 """ 268 watchers = self._connectionLostWatchers 269 self._connectionLostWatchers = None 270 271 for w in watchers: 272 eventually(w) 273 274 def finish(self, why): 275 if self.disconnected: 276 return 277 assert isinstance(why, failure.Failure), why 278 self.disconnected = True 279 self.remote_broker = None 280 self.abandonAllRequests(why) 281 # TODO: why reset all the tables to something useable? 
There may be 282 # outstanding RemoteReferences that point to us, but I don't see why 283 # that requires all these empty dictionaries. 284 self.myReferenceByPUID = {} 285 self.myReferenceByCLID = {} 286 self.yourReferenceByCLID = {} 287 self.yourReferenceByURL = {} 288 self.myGifts = {} 289 self.myGiftsByGiftID = {} 290 for (cb,args,kwargs) in self.disconnectWatchers: 291 eventually(cb, *args, **kwargs) 292 self.disconnectWatchers = [] 293 if self.tub: 294 # TODO: remove the conditional. It is only here to accomodate 295 # some tests: test_pb.TestCall.testDisconnect[123] 296 self.tub.brokerDetached(self, why) 297 298 def _notifyOnConnectionLost(self, callback): 299 """ 300 Arrange to have C{callback} called when this broker loses its connection. 301 """ 302 self._connectionLostWatchers.append(callback) 303 304 def notifyOnDisconnect(self, callback, *args, **kwargs): 305 marker = (callback, args, kwargs) 306 if self.disconnected: 307 eventually(callback, *args, **kwargs) 308 else: 309 self.disconnectWatchers.append(marker) 310 return marker 311 def dontNotifyOnDisconnect(self, marker): 312 if self.disconnected: 313 return 314 # be tolerant of attempts to unregister a callback that has already 315 # fired. I think it is hard to write safe code without this 316 # tolerance. 317 318 # TODO: on the other hand, I'm not sure this is the best policy, 319 # since you lose the feedback that tells you about 320 # unregistering-the-wrong-thing bugs. We need to look at the way that 321 # register/unregister gets used and see if there is a way to retain 322 # the typechecking that results from insisting that you can only 323 # remove something that was stil in the list. 
324 if marker in self.disconnectWatchers: 325 self.disconnectWatchers.remove(marker) 326 327 def getConnectionInfo(self): 328 return self._connectionInfo 329 330 # methods to send my Referenceables to the other side 331 332 def getTrackerForMyReference(self, puid, obj): 333 tracker = self.myReferenceByPUID.get(puid) 334 if not tracker: 335 # need to add one 336 clid = next(self.nextCLID) 337 tracker = referenceable.ReferenceableTracker(self.tub, 338 obj, puid, clid) 339 self.myReferenceByPUID[puid] = tracker 340 self.myReferenceByCLID[clid] = tracker 341 return tracker 342 343 def getTrackerForMyCall(self, puid, obj): 344 # just like getTrackerForMyReference, but with a negative clid 345 tracker = self.myReferenceByPUID.get(puid) 346 if not tracker: 347 # need to add one 348 clid = next(self.nextCLID) 349 clid = -clid 350 tracker = referenceable.ReferenceableTracker(self.tub, 351 obj, puid, clid) 352 self.myReferenceByPUID[puid] = tracker 353 self.myReferenceByCLID[clid] = tracker 354 return tracker 355 356 # methods to handle inbound 'my-reference' sequences 357 358 def getTrackerForYourReference(self, clid, interfaceName=None, url=None): 359 """The far end holds a Referenceable and has just sent us a reference 360 to it (expressed as a small integer). If this is a new reference, 361 they will give us an interface name too, and possibly a global URL 362 for it. Obtain a RemoteReference object (creating it if necessary) to 363 give to the local recipient. 364 365 The sender remembers that we hold a reference to their object. When 366 our RemoteReference goes away, we send a decref message to them, so 367 they can possibly free their object. 
""" 368 369 assert type(interfaceName) is str or interfaceName is None 370 if url is not None: 371 assert type(url) is str 372 tracker = self.yourReferenceByCLID.get(clid) 373 if not tracker: 374 # TODO: translate interfaceNames to RemoteInterfaces 375 if clid >= 0: 376 trackerclass = referenceable.RemoteReferenceTracker 377 else: 378 trackerclass = referenceable.RemoteMethodReferenceTracker 379 tracker = trackerclass(self, clid, url, interfaceName) 380 self.yourReferenceByCLID[clid] = tracker 381 if url: 382 self.yourReferenceByURL[url] = tracker 383 return tracker 384 385 def freeYourReference(self, tracker, count): 386 # this is called when the RemoteReference is deleted 387 if not self.remote_broker: # tests do not set this up 388 self.freeYourReferenceTracker(None, tracker) 389 return 390 try: 391 rb = self.remote_broker 392 # TODO: do we want callRemoteOnly here? is there a way we can 393 # avoid wanting to know when the decref has completed? Only if we 394 # send the interface list and URL on every occurrence of the 395 # my-reference sequence. Either A) we use callRemote("decref") 396 # and wait until the ack to free the tracker, or B) we use 397 # callRemoteOnly("decref") and free the tracker right away. In 398 # case B, the far end has no way to know that we've just freed 399 # the tracker and will therefore forget about everything they 400 # told us (including the interface list), so they cannot 401 # accurately do anything special on the "first" send of this 402 # reference. Which means that if we do B, we must either send 403 # that extra information on every my-reference sequence, or do 404 # without it, or make it optional, or retrieve it separately, or 405 # something. 
406 407 # rb.callRemoteOnly("decref", clid=tracker.clid, count=count) 408 # self.freeYourReferenceTracker('bogus', tracker) 409 # return 410 411 d = rb.callRemote("decref", clid=tracker.clid, count=count) 412 # if the connection was lost before we can get an ack, we're 413 # tearing this down anyway 414 def _ignore_loss(f): 415 f.trap(DeadReferenceError, *LOST_CONNECTION_ERRORS) 416 return None 417 d.addErrback(_ignore_loss) 418 # once the ack comes back, or if we know we'll never get one, 419 # release the tracker 420 d.addCallback(self.freeYourReferenceTracker, tracker) 421 except: 422 f = failure.Failure() 423 log.msg("failure during freeRemoteReference", facility="foolscap", 424 level=log.UNUSUAL, failure=f) 425 426 def freeYourReferenceTracker(self, res, tracker): 427 if tracker.received_count != 0: 428 return 429 if tracker.clid in self.yourReferenceByCLID: 430 del self.yourReferenceByCLID[tracker.clid] 431 if tracker.url and tracker.url in self.yourReferenceByURL: 432 del self.yourReferenceByURL[tracker.url] 433 434 435 # methods to handle inbound 'your-reference' sequences 436 437 def getMyReferenceByCLID(self, clid): 438 """clid is the connection-local ID of the Referenceable the other 439 end is trying to invoke or point to. If it is a number, they want an 440 implicitly-created per-connection object that we sent to them at 441 some point in the past. If it is a string, they want an object that 442 was registered with our Factory. 
443 """ 444 445 assert isinstance(clid, six.integer_types) 446 if clid == 0: 447 return self 448 return self.myReferenceByCLID[clid].obj 449 # obj = IReferenceable(obj) 450 # assert isinstance(obj, pb.Referenceable) 451 # obj needs .getMethodSchema, which needs .getArgConstraint 452 453 def remote_decref(self, clid, count): 454 # invoked when the other side sends us a decref message 455 assert isinstance(clid, six.integer_types) 456 assert clid != 0 457 tracker = self.myReferenceByCLID.get(clid, None) 458 if not tracker: 459 return # already gone, probably because we're shutting down 460 done = tracker.decref(count) 461 if done: 462 del self.myReferenceByPUID[tracker.puid] 463 del self.myReferenceByCLID[clid] 464 465 # methods to send RemoteReference 'gifts' to third-parties 466 467 def makeGift(self, rref): 468 # return the giftid 469 broker, clid = rref.tracker.broker, rref.tracker.clid 470 i = (broker, clid) 471 old = self.myGifts.get(i) 472 if old: 473 rref, giftID, count = old 474 self.myGifts[i] = (rref, giftID, count+1) 475 else: 476 giftID = next(self.nextGiftID) 477 self.myGiftsByGiftID[giftID] = i 478 self.myGifts[i] = (rref, giftID, 1) 479 return giftID 480 481 def remote_decgift(self, giftID, count): 482 broker, clid = self.myGiftsByGiftID[giftID] 483 rref, giftID, gift_count = self.myGifts[(broker, clid)] 484 gift_count -= count 485 if gift_count == 0: 486 del self.myGiftsByGiftID[giftID] 487 del self.myGifts[(broker, clid)] 488 else: 489 self.myGifts[(broker, clid)] = (rref, giftID, gift_count) 490 491 # methods to deal with URLs 492 493 def getYourReferenceByName(self, name): 494 # remain compatible with remotes running py2 495 name = six.ensure_binary(name) 496 d = self.remote_broker.callRemote("getReferenceByName", name=name) 497 return d 498 499 def remote_getReferenceByName(self, name): 500 return self.tub.getReferenceForName(six.ensure_str(name)) 501 502 # remote-method-invocation methods, calling side, invoked by 503 # 
RemoteReference.callRemote and CallSlicer 504 505 def newRequestID(self): 506 if self.disconnected: 507 raise DeadReferenceError("Calling Stale Broker") 508 return next(self.nextReqID) 509 510 def addRequest(self, req): 511 req.broker = self 512 self.waitingForAnswers[req.reqID] = req 513 514 def removeRequest(self, req): 515 del self.waitingForAnswers[req.reqID] 516 517 def getRequest(self, reqID): 518 # invoked by AnswerUnslicer and ErrorUnslicer 519 try: 520 return self.waitingForAnswers[reqID] 521 except KeyError: 522 raise Violation("non-existent reqID '%d'" % reqID) 523 524 def abandonAllRequests(self, why): 525 for req in list(self.waitingForAnswers.values()): 526 if why.check(*LOST_CONNECTION_ERRORS): 527 # map all connection-lost errors to DeadReferenceError, so 528 # application code only needs to check for one exception type 529 tubid = None 530 # since we're creating a new exception object for each call, 531 # let's add more information to it 532 if self.remote_tubref: 533 tubid = self.remote_tubref.getShortTubID() 534 e = DeadReferenceError("Connection was lost", tubid, req) 535 why = failure.Failure(e) 536 eventually(req.fail, why) 537 538 # target-side, invoked by CallUnslicer 539 540 def getRemoteInterfaceByName(self, riname): 541 # this lives in the broker because it ought to be per-connection 542 return remoteinterface.RemoteInterfaceRegistry[riname] 543 544 def getSchemaForMethod(self, rifaces, methodname): 545 # this lives in the Broker so it can override the resolution order, 546 # not that overlapping RemoteInterfaces should be allowed to happen 547 # all that often 548 for ri in rifaces: 549 m = ri.get(methodname) 550 if m: 551 return m 552 return None 553 554 def scheduleCall(self, delivery, ready_deferred): 555 self.inboundDeliveryQueue.append( (delivery,ready_deferred) ) 556 eventually(self.doNextCall) 557 558 def doNextCall(self): 559 if self.disconnected: 560 return 561 if self._waiting_for_call_to_be_ready: 562 return 563 if not 
self.inboundDeliveryQueue: 564 return 565 delivery, ready_deferred = self.inboundDeliveryQueue.pop(0) 566 self._waiting_for_call_to_be_ready = True 567 if not ready_deferred: 568 ready_deferred = defer.succeed(None) 569 d = ready_deferred 570 571 def _ready(res): 572 self._waiting_for_call_to_be_ready = False 573 eventually(self.doNextCall) 574 return res 575 d.addBoth(_ready) 576 577 # at this point, the Deferred chain for this one delivery runs 578 # independently of any other, and methods which take a long time to 579 # complete will not hold up other methods. We must call _doCall and 580 # let the remote_ method get control before we process any other 581 # message, but the eventually() above insures we'll have a chance to 582 # do that before we give up control. 583 584 d.addCallback(lambda res: self._doCall(delivery)) 585 d.addCallback(self._callFinished, delivery) 586 d.addErrback(self.callFailed, delivery.reqID, delivery) 587 d.addErrback(log.err) 588 return None 589 590 def _doCall(self, delivery): 591 # our ordering rules require that the order in which each 592 # remote_foo() method gets control is exactly the same as the order 593 # in which the original caller invoked callRemote(). To insure this, 594 # _startCall() is not allowed to insert additional delays before it 595 # runs doRemoteCall() on the target object. 596 obj = delivery.obj 597 args = delivery.allargs.args 598 kwargs = delivery.allargs.kwargs 599 for i in args + list(kwargs.values()): 600 assert not isinstance(i, defer.Deferred) 601 602 if delivery.methodSchema: 603 # we asked about each argument on the way in, but ask again so 604 # they can look for missing arguments. TODO: see if we can remove 605 # the redundant per-argument checks. 
606 delivery.methodSchema.checkAllArgs(args, kwargs, True) 607 608 # interesting case: if the method completes successfully, but 609 # our schema prohibits us from sending the result (perhaps the 610 # method returned an int but the schema insists upon a string). 611 # TODO: move the return-value schema check into 612 # Referenceable.doRemoteCall, so the exception's traceback will be 613 # attached to the object that caused it 614 if delivery.methodname is None: 615 assert callable(obj) 616 return obj(*args, **kwargs) 617 else: 618 obj = ipb.IRemotelyCallable(obj) 619 return obj.doRemoteCall(delivery.methodname, args, kwargs) 620 621 622 def _callFinished(self, res, delivery): 623 reqID = delivery.reqID 624 if reqID == 0: 625 return 626 methodSchema = delivery.methodSchema 627 assert self.activeLocalCalls[reqID] 628 methodName = None 629 if methodSchema: 630 methodName = methodSchema.name 631 try: 632 methodSchema.checkResults(res, False) # may raise Violation 633 except Violation as v: 634 v.prependLocation("in return value of %s.%s" % 635 (delivery.obj, methodSchema.name)) 636 raise 637 638 answer = call.AnswerSlicer(reqID, res, methodName) 639 # once the answer has started transmitting, any exceptions must be 640 # logged and dropped, and not turned into an Error to be sent. 641 try: 642 self.send(answer) 643 # TODO: .send should return a Deferred that fires when the last 644 # byte has been queued, and we should delete the local note then 645 except: 646 f = failure.Failure() 647 log.msg("Broker._callfinished unable to send", 648 facility="foolscap", level=log.UNUSUAL, failure=f) 649 del self.activeLocalCalls[reqID] 650 651 def callFailed(self, f, reqID, delivery=None): 652 # this may be called either when an inbound schema is violated, or 653 # when the method is run and raises an exception. 
If a Violation is 654 # raised after we receive the reqID but before we've actually invoked 655 # the method, we are called by CallUnslicer.reportViolation and don't 656 # get a delivery= argument. 657 if delivery: 658 if (self.tub and self.tub.logLocalFailures) or not self.tub: 659 # the 'not self.tub' case is for unit tests 660 delivery.logFailure(f) 661 if reqID != 0: 662 assert self.activeLocalCalls[reqID] 663 self.send(call.ErrorSlicer(reqID, f)) 664 del self.activeLocalCalls[reqID] 665 666class StorageBrokerRootSlicer(ScopedRootSlicer): 667 # each StorageBroker is a single serialization domain, so we inherit from 668 # ScopedRootSlicer 669 slicerTable = {types.MethodType: referenceable.CallableSlicer, 670 types.FunctionType: referenceable.CallableSlicer, 671 } 672 673PBStorageOpenRegistry = { 674 ('their-reference',): referenceable.TheirReferenceUnslicer, 675 } 676 677class StorageBrokerRootUnslicer(PBRootUnslicer): 678 # we want all the behavior of PBRootUnslicer, plus the scopedness of a 679 # ScopedRootUnslicer. TODO: find some way to refactor all of this, 680 # probably by making the scopedness a mixin. 681 682 openRegistries = [slicer.UnslicerRegistry, PBStorageOpenRegistry] 683 topRegistries = openRegistries 684 685 def __init__(self, protocol): 686 PBRootUnslicer.__init__(self, protocol) 687 self.references = {} 688 689 def setObject(self, counter, obj): 690 self.references[counter] = obj 691 692 def getObject(self, counter): 693 obj = self.references.get(counter) 694 return obj 695 696 def receiveChild(self, obj, ready_deferred): 697 self.protocol.receiveChild(obj, ready_deferred) 698 699 def reportViolation(self, why): 700 # unlike PBRootUnslicer, we do *not* absorb the failure. Any error 701 # during deserialization is fatal to the process. We give it to the 702 # StorageBroker, which will eventually fire the unserialization 703 # Deferred. 
704 self.protocol.reportViolation(why) 705 706class StorageBroker(Broker): 707 # like Broker, but used to serialize data for storage rather than for 708 # transmission over a specific connection. 709 slicerClass = StorageBrokerRootSlicer 710 unslicerClass = StorageBrokerRootUnslicer 711 object = None 712 violation = None 713 disconnectReason = None 714 use_remote_broker = False 715 716 def prepare(self): 717 self.d = defer.Deferred() 718 return self.d 719 720 def receiveChild(self, obj, ready_deferred): 721 if ready_deferred: 722 ready_deferred.addBoth(self.d.callback) 723 self.d.addCallback(lambda res: obj) 724 else: 725 self.d.callback(obj) 726 del self.d 727 728 def reportViolation(self, why): 729 self.violation = why 730 eventually(self.d.callback, None) 731 return None 732 733 def reportReceiveError(self, f): 734 self.disconnectReason = f 735 f.raiseException() 736 737# this loopback stuff is based upon twisted.protocols.loopback, except that 738# we use it for real, not just for testing. The IConsumer stuff hasn't been 739# tested at all. 
@implementer(twinterfaces.IAddress)
class LoopbackAddress(object):
    """Placeholder address returned by LoopbackTransport.getPeer/getHost."""
    pass

@implementer(twinterfaces.ITransport, twinterfaces.IConsumer)
class LoopbackTransport(object):
    """One half of an in-process transport pair.

    Data written here is delivered, on a later turn, to the peer
    transport's dataReceived(), which passes it to the peer's protocol.
    The 'peer' attribute is set with setPeer(); the 'protocol' attribute is
    not set anywhere in this class, so whoever creates the pair is expected
    to assign it.
    """
    # we always create these in pairs, with .peer pointing at each other

    producer = None

    def __init__(self):
        self.connected = True
    def setPeer(self, peer):
        self.peer = peer

    def write(self, bytes):
        """Schedule delivery of the data to the peer transport."""
        eventually(self.peer.dataReceived, bytes)
    def writeSequence(self, iovec):
        """Write an iterable of byte strings as one chunk."""
        # join with a bytes separator: a text '' separator raises TypeError
        # on Python 3 when the sequence contains bytes
        self.write(b''.join(iovec))

    def dataReceived(self, data):
        """Pass data from the peer to our protocol, unless we have already
        been disconnected."""
        if self.connected:
            self.protocol.dataReceived(data)

    def loseConnection(self, _connDone=connectionDone):
        """Drop the connection: on a later turn, tell the peer transport and
        then our own protocol that the connection was lost. Calls after the
        first are ignored."""
        if not self.connected:
            return
        self.connected = False
        eventually(self.peer.connectionLost, _connDone)
        eventually(self.protocol.connectionLost, _connDone)
    def connectionLost(self, reason):
        """Called by the peer when it drops the connection; notifies our
        protocol unless we are already disconnected."""
        if not self.connected:
            return
        self.connected = False
        self.protocol.connectionLost(reason)

    def getPeer(self):
        return LoopbackAddress()
    def getHost(self):
        return LoopbackAddress()

    # IConsumer
    def registerProducer(self, producer, streaming):
        """Remember the producer and, if it is not a streaming producer,
        immediately ask it to resume producing."""
        assert self.producer is None
        self.producer = producer
        self.streamingProducer = streaming
        self._pollProducer()

    def unregisterProducer(self):
        assert self.producer is not None
        self.producer = None

    def _pollProducer(self):
        # only non-streaming producers are polled here
        if self.producer is not None and not self.streamingProducer:
            self.producer.resumeProducing()