# -*- test-case-name: foolscap.test.test_copyable -*-

# this module is responsible for all copy-by-value objects
import six
from zope.interface import interface, implementer
from twisted.python import reflect, log
from twisted.python.components import registerAdapter
from twisted.internet import defer

from . import slicer, tokens
from .tokens import BananaError, Violation
from foolscap.constraint import OpenerConstraint, IConstraint, Optional

Interface = interface.Interface

############################################################
# the first half of this file is sending/serialization

class ICopyable(Interface):
    """I represent an object which is passed-by-value across PB connections.
    """

    def getTypeToCopy():
        """Return a string which names the class. This string must match the
        one that gets registered at the receiving end. This is typically a
        URL of some sort, in a namespace which you control."""

    def getStateToCopy():
        """Return a state dictionary (with plain-string keys) which will be
        serialized and sent to the remote end. This state object will be
        given to the receiving object's setCopyableState method."""


@implementer(ICopyable)
class Copyable(object):
    """Base class for objects that are sent by value.

    Subclasses name themselves through a 'typeToCopy' attribute; the state
    that gets sent is the instance's own attribute dictionary.
    """
    # you *must* set 'typeToCopy'

    def getTypeToCopy(self):
        """Return self.typeToCopy, or raise RuntimeError if it was never set."""
        try:
            copytype = self.typeToCopy
        except AttributeError:
            raise RuntimeError("Copyable subclasses must specify 'typeToCopy'")
        return copytype

    def getStateToCopy(self):
        """Return the instance __dict__ itself (not a copy)."""
        return self.__dict__


class CopyableSlicer(slicer.BaseSlicer):
    """I handle ICopyable objects (things which are copied by value)."""

    def slice(self, streamable, banana):
        """Yield the b'copyable' marker, then the type name as bytes, then
        each attribute name (as bytes) followed by its value."""
        self.streamable = streamable
        yield b'copyable'
        copytype = self.obj.getTypeToCopy()
        assert isinstance(copytype, str)
        yield six.ensure_binary(copytype)
        state = self.obj.getStateToCopy()
        for k,v in state.items():
            yield six.ensure_binary(k)
            yield v

    def describe(self):
        """Return '<typename>' for the object being sliced."""
        return "<%s>" % self.obj.getTypeToCopy()
registerAdapter(CopyableSlicer, ICopyable, tokens.ISlicer)


class Copyable2(slicer.BaseSlicer):
    # I am my own Slicer. This has more methods than you'd usually want in a
    # base class, but if you can't register an Adapter for a whole class
    # hierarchy then you may have to use it.
    def getTypeToCopy(self):
        """Return the fully-qualified name of this instance's class."""
        return reflect.qual(self.__class__)

    def getStateToCopy(self):
        """Return the instance __dict__ itself (not a copy)."""
        return self.__dict__

    def slice(self, streamable, banana):
        """Yield the b'instance' marker, the type name as bytes, and then
        the whole state dictionary as a single child object."""
        self.streamable = streamable
        yield b'instance'
        yield six.ensure_binary(self.getTypeToCopy())
        yield self.getStateToCopy()

    def describe(self):
        """Return '<typename>' for this object."""
        return "<%s>" % self.getTypeToCopy()

#registerRemoteCopy(typename, factory)
#registerUnslicer(typename, factory)

def registerCopier(klass, copier):
    """This is a shortcut for arranging to serialize third-party classes.
    'copier' must be a callable which accepts an instance of the class you
    want to serialize, and returns a tuple of (typename, state_dictionary).
    If it returns a typename of None, the original class's fully-qualified
    classname is used.
    """
    klassname = reflect.qual(klass)

    @implementer(ICopyable)
    class _CopierAdapter:
        # 'copier' runs once, when the adapter is built, so the name and
        # state are captured at adaptation time.
        def __init__(self, original):
            self.nameToCopy, self.state = copier(original)
            if self.nameToCopy is None:
                self.nameToCopy = klassname

        def getTypeToCopy(self):
            return self.nameToCopy

        def getStateToCopy(self):
            return self.state

    registerAdapter(_CopierAdapter, klass, ICopyable)

############################################################
# beyond here is the receiving/deserialization side

class RemoteCopyUnslicer(slicer.BaseUnslicer):
    """Collect alternating attribute-name / attribute-value children into a
    state dictionary and hand that dictionary to 'factory' at close time.

    'stateSchema', when truthy, supplies a per-attribute constraint through
    its getAttrConstraint() method.
    """
    # name of the attribute whose value is expected next; None means the
    # next child must be an attribute name
    attrname = None
    # constraint applied to the value of the pending attribute, if any
    attrConstraint = None

    def __init__(self, factory, stateSchema):
        self.factory = factory
        self.schema = stateSchema

    def start(self, count):
        """Reset the state dict and register a Deferred with the protocol
        as the placeholder for object number 'count'."""
        self.d = {}
        self.count = count
        self.deferred = defer.Deferred()
        self.protocol.setObject(count, self.deferred)

    def checkToken(self, typebyte, size):
        """Reject non-string tokens where an attribute name is expected;
        otherwise defer to the pending attribute's constraint, if any."""
        if self.attrname is None:
            if typebyte not in (tokens.STRING, tokens.VOCAB):
                raise BananaError("RemoteCopyUnslicer keys must be STRINGs")
        else:
            if self.attrConstraint:
                self.attrConstraint.checkToken(typebyte, size)

    def doOpen(self, opentype):
        """Open a child unslicer, checking the opentype against the pending
        attribute constraint and passing that constraint on to the child."""
        if self.attrConstraint:
            self.attrConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            if self.attrConstraint:
                unslicer.setConstraint(self.attrConstraint)
        return unslicer

    def receiveChild(self, obj, ready_deferred=None):
        """Accept the next child: an attribute name when none is pending,
        otherwise the value for the pending attribute name.

        Raises BananaError on a duplicate attribute name or on a Deferred
        value.
        """
        assert not isinstance(obj, defer.Deferred)
        assert ready_deferred is None
        if self.attrname is None:
            attrname = six.ensure_str(obj)
            if attrname in self.d:
                raise BananaError("duplicate attribute name '%s'" % attrname)
            s = self.schema
            if s:
                accept, self.attrConstraint = s.getAttrConstraint(attrname)
                assert accept
            self.attrname = attrname
        else:
            # still reachable when asserts are stripped (python -O)
            if isinstance(obj, defer.Deferred):
                # TODO: this is an artificial restriction, and it might
                # be possible to remove it, but I need to think through
                # it carefully first
                raise BananaError("unreferenceable object in attribute")
            self.setAttribute(self.attrname, obj)
            self.attrname = None
            self.attrConstraint = None

    def setAttribute(self, name, value):
        """Record one attribute in the state dictionary."""
        self.d[name] = value

    def receiveClose(self):
        """Build the object from the collected state, replace the
        placeholder Deferred in the protocol with it, fire the Deferred,
        and return (obj, None)."""
        try:
            obj = self.factory(self.d)
        except:
            # log which factory failed, then let the original exception
            # propagate unchanged
            log.msg("%s.receiveClose: problem in factory %s" %
                    (self.__class__.__name__, self.factory))
            log.err()
            raise
        self.protocol.setObject(self.count, obj)
        self.deferred.callback(obj)
        return obj, None

    def describe(self):
        """Return a short description of the current parse position."""
        # Nothing in this class assigns 'classname'; look it up defensively
        # so that describe() cannot itself raise AttributeError.
        # NOTE(review): a base class may define it -- not visible here.
        classname = getattr(self, "classname", None)
        if classname is None:
            return "<??>"
        me = "<%s>" % classname
        if self.attrname is None:
            return "%s.attrname??" % me
        else:
            return "%s.%s" % (me, self.attrname)


class NonCyclicRemoteCopyUnslicer(RemoteCopyUnslicer):
    # The Deferred used in RemoteCopyUnslicer (used in case the RemoteCopy
    # is participating in a reference cycle, say 'obj.foo = obj') makes it
    # unsuitable for holding Failures (which cannot be passed through
    # Deferred.callback). Use this class for Failures. It cannot handle
    # reference cycles (they will cause a KeyError when the reference is
    # followed).

    def start(self, count):
        """Reset the state dict without registering any placeholder."""
        self.d = {}
        self.count = count
        self.gettingAttrname = True

    def receiveClose(self):
        """Build the object from the collected state and return
        (obj, None); no Deferred is fired and the protocol is not told."""
        obj = self.factory(self.d)
        return obj, None


class IRemoteCopy(Interface):
    """This interface defines what a RemoteCopy class must do. RemoteCopy
    subclasses are used as factories to create objects that correspond to
    Copyables sent over the wire.

    Note that the constructor of an IRemoteCopy class will be called without
    any arguments.
    """

    def setCopyableState(statedict):
        """I accept an attribute dictionary name/value pairs and use it to
        set my internal state.

        Some of the values may be Deferreds, which are placeholders for the
        as-yet-unreferenceable object which will eventually go there. If you
        receive a Deferred, you are responsible for adding a callback to
        update the attribute when it fires. [note:
        RemoteCopyUnslicer.receiveChild currently has a restriction which
        prevents this from happening, but that may go away in the future]

        Some of the objects referenced by the attribute values may have
        Deferreds in them (e.g. containers which reference recursive tuples).
        Such containers are responsible for updating their own state when
        those Deferreds fire, but until that point their state is still
        subject to change. Therefore you must be careful about how much state
        inspection you perform within this method."""

    stateSchema = interface.Attribute("""I return an AttributeDictConstraint
    object which places restrictions on incoming attribute values. These
    restrictions are enforced as the tokens are received, before the state is
    passed to setCopyableState.""")


# This maps typename to an Unslicer factory
CopyableRegistry = {}
def registerRemoteCopyUnslicerFactory(typename, unslicerfactory,
                                      registry=None):
    """Tell PB that unslicerfactory can be used to handle Copyable objects
    that provide a getTypeToCopy name of 'typename'. 'unslicerfactory' must
    be a callable which takes no arguments and returns an object which
    provides IUnslicer.

    'registry' defaults to the module-level CopyableRegistry. Registering
    the same typename twice in one registry is an assertion failure.
    """
    assert callable(unslicerfactory)
    # in addition, it must produce a tokens.IUnslicer . This is safe to do
    # because Unslicers don't do anything significant when they are created.
    test_unslicer = unslicerfactory()
    assert tokens.IUnslicer.providedBy(test_unslicer)
    assert type(typename) is str

    if registry is None:
        registry = CopyableRegistry
    assert typename not in registry
    registry[typename] = unslicerfactory

# this keeps track of everything submitted to registerRemoteCopyFactory
debug_CopyableFactories = {}
def registerRemoteCopyFactory(typename, factory, stateSchema=None,
                              cyclic=True, registry=None):
    """Tell PB that 'factory' can be used to handle Copyable objects that
    provide a getTypeToCopy name of 'typename'. 'factory' must be a callable
    which accepts a state dictionary and returns a fully-formed instance.

    'cyclic' is a boolean, which should be set to False to avoid using a
    Deferred to provide the resulting RemoteCopy instance. This is needed to
    deserialize Failures (or instances which inherit from one, like
    CopiedFailure). In exchange for this, it cannot handle reference cycles.
    """
    assert callable(factory)
    debug_CopyableFactories[typename] = (factory, stateSchema, cyclic)
    if cyclic:
        def _RemoteCopyUnslicerFactory():
            return RemoteCopyUnslicer(factory, stateSchema)
        registerRemoteCopyUnslicerFactory(typename,
                                          _RemoteCopyUnslicerFactory,
                                          registry)
    else:
        def _RemoteCopyUnslicerFactoryNonCyclic():
            return NonCyclicRemoteCopyUnslicer(factory, stateSchema)
        registerRemoteCopyUnslicerFactory(typename,
                                          _RemoteCopyUnslicerFactoryNonCyclic,
                                          registry)

# this keeps track of everything submitted to registerRemoteCopy, which may
# be useful when you're wondering what's been auto-registered by the
# RemoteCopy metaclass magic
debug_RemoteCopyClasses = {}
def registerRemoteCopy(typename, remote_copy_class, registry=None):
    """Tell PB that remote_copy_class is the appropriate RemoteCopy class to
    use when deserializing a Copyable sequence that is tagged with
    'typename'. 'remote_copy_class' should be a RemoteCopy subclass or
    implement the same interface, which means its constructor takes no
    arguments and it has a setCopyableState(state) method to actually set the
    instance's state after initialization. It must also have a nonCyclic
    attribute.
    """
    assert IRemoteCopy.implementedBy(remote_copy_class)
    assert type(typename) is str

    debug_RemoteCopyClasses[typename] = remote_copy_class
    def _RemoteCopyFactory(state):
        obj = remote_copy_class()
        obj.setCopyableState(state)
        return obj

    registerRemoteCopyFactory(typename, _RemoteCopyFactory,
                              remote_copy_class.stateSchema,
                              not remote_copy_class.nonCyclic,
                              registry)

class RemoteCopyClass(type):
    # auto-register RemoteCopy classes
    def __init__(self, name, bases, dict):
        type.__init__(self, name, bases, dict)
        # don't try to register RemoteCopy itself
        if name == "RemoteCopy" and _RemoteCopyBase in bases:
            #print "not auto-registering %s %s" % (name, bases)
            return
        # 'copytype' must appear in the class body itself; an inherited
        # value is not enough
        if "copytype" not in dict:
            # TODO: provide a file/line-number for the class
            raise RuntimeError("RemoteCopy subclass %s must specify 'copytype'"
                               % name)
        copytype = dict['copytype']
        # a false copytype (e.g. None) disables auto-registration
        if copytype:
            registry = dict.get('copyableRegistry', None)
            registerRemoteCopy(copytype, self, registry)

@implementer(IRemoteCopy)
class _RemoteCopyBase:
    stateSchema = None # always a class attribute
    nonCyclic = False

    def __init__(self):
        # the constructor will always be called without arguments
        pass

    def setCopyableState(self, state):
        """Adopt 'state' as this instance's attribute dictionary."""
        self.__dict__ = state

class RemoteCopyOldStyle(_RemoteCopyBase):
    # note that these will not auto-register for you, because old-style
    # classes do not do metaclass magic
    copytype = None

@six.add_metaclass(RemoteCopyClass)
class RemoteCopy(_RemoteCopyBase, object):
    # Set 'copytype' to a unique string that is shared between the
    # sender-side Copyable and the receiver-side RemoteCopy. This RemoteCopy
    # subclass will be auto-registered using the 'copytype' name. Set
    # copytype to None to disable auto-registration.
    pass


class AttributeDictConstraint(OpenerConstraint):
    """This is a constraint for dictionaries that are used for attributes.
    All keys are short strings, and each value has a separate constraint.
    It could be used to describe instance state, but could also be used
    to constrain arbitrary dictionaries with string keys.

    Some special constraints are legal here: Optional.
    """
    opentypes = [("attrdict",)]
    name = "AttributeDictConstraint"

    def __init__(self, *attrTuples, **kwargs):
        """Build the schema from (name, constraint) tuples and/or an
        'attributes' dict keyword argument.

        Recognized keyword arguments: 'ignoreUnknown', 'acceptUnknown'
        (both default to False) and 'attributes'. Each constraint is
        adapted with IConstraint. A repeated name is an assertion failure.
        """
        self.ignoreUnknown = kwargs.get('ignoreUnknown', False)
        self.acceptUnknown = kwargs.get('acceptUnknown', False)
        self.keys = {}
        for name, constraint in (list(attrTuples) +
                                 list(kwargs.get('attributes', {}).items())):
            assert name not in self.keys
            self.keys[name] = IConstraint(constraint)

    def getAttrConstraint(self, attrname):
        """Return an (accept, constraint) tuple for 'attrname'.

        A known attribute gives (True, constraint), with an Optional
        wrapper replaced by the constraint it wraps. An unknown attribute
        gives (False, None) if ignoreUnknown is set, (True, None) if
        acceptUnknown is set, and otherwise raises Violation.
        """
        c = self.keys.get(attrname)
        if c:
            if isinstance(c, Optional):
                c = c.constraint
            return (True, c)
        # unknown attribute
        if self.ignoreUnknown:
            return (False, None)
        if self.acceptUnknown:
            return (True, None)
        raise Violation("unknown attribute '%s'" % attrname)

    def checkObject(self, obj, inbound):
        """Check a complete dictionary against the schema.

        Raises Violation if 'obj' is not exactly a dict, if it has a key
        that is not in the schema (unless ignoreUnknown is set), if a value
        fails its constraint, or if a non-Optional key is missing.
        """
        # exact type match on purpose: dict subclasses are rejected
        if type(obj) is not dict:
            raise Violation("'%s' (%s) is not a Dictionary" % (obj,
                                                               type(obj)))
        for k in list(obj.keys()):
            if k not in self.keys:
                if not self.ignoreUnknown:
                    raise Violation("key '%s' not in schema" % k)
                # hmm. kind of a soft violation. allow it for now.
                continue
            self.keys[k].checkObject(obj[k], inbound)

        # required keys that never showed up, reported in schema order
        missing = [k for k in self.keys
                   if k not in obj
                   and not isinstance(self.keys[k], Optional)]
        if missing:
            raise Violation("object is missing required keys: %s" %
                            ",".join(missing))