1# -*- test-case-name: foolscap.test.test_copyable -*-
2
3# this module is responsible for all copy-by-value objects
4import six
5from zope.interface import interface, implementer
6from twisted.python import reflect, log
7from twisted.python.components import registerAdapter
8from twisted.internet import defer
9
10from . import slicer, tokens
11from .tokens import BananaError, Violation
12from foolscap.constraint import OpenerConstraint, IConstraint, Optional
13
# Shorthand for zope.interface's Interface base class, used by the interface
# declarations (ICopyable, IRemoteCopy) in this module.
Interface = interface.Interface

############################################################
# the first half of this file is sending/serialization

class ICopyable(Interface):
    """I represent an object which is passed-by-value across PB connections.

    Providers of this interface are adapted to tokens.ISlicer by
    CopyableSlicer, which sends a 'copyable' sequence made of the
    getTypeToCopy() name followed by the key/value pairs returned by
    getStateToCopy().
    """

    def getTypeToCopy():
        """Return a string which names the class. This string must match the
        one that gets registered at the receiving end. This is typically a
        URL of some sort, in a namespace which you control.

        CopyableSlicer asserts that the result is a native str."""
    def getStateToCopy():
        """Return a state dictionary (with plain-string keys) which will be
        serialized and sent to the remote end. On the receiving end this
        dictionary is handed to the factory registered for the type name;
        for RemoteCopy classes that means their setCopyableState method."""
31
@implementer(ICopyable)
class Copyable(object):
    """Convenience base class for objects sent by value.

    Subclasses must define a 'typeToCopy' attribute naming the type that
    the receiving end registered. By default the state that gets sent is
    the instance's own attribute dictionary.
    """

    def getTypeToCopy(self):
        try:
            return self.typeToCopy
        except AttributeError:
            raise RuntimeError("Copyable subclasses must specify 'typeToCopy'")

    def getStateToCopy(self):
        # the live instance dict is returned, not a snapshot
        return self.__dict__
44
class CopyableSlicer(slicer.BaseSlicer):
    """I handle ICopyable objects (things which are copied by value).

    The emitted sequence is: the b'copyable' open type, the object's type
    name, then alternating attribute-name / attribute-value items taken from
    getStateToCopy().
    """

    def slice(self, streamable, banana):
        self.streamable = streamable
        yield b'copyable'
        typename = self.obj.getTypeToCopy()
        assert isinstance(typename, str)
        yield six.ensure_binary(typename)
        for name, value in self.obj.getStateToCopy().items():
            # attribute names go out as bytes, values are sliced as-is
            yield six.ensure_binary(name)
            yield value

    def describe(self):
        return "<%s>" % self.obj.getTypeToCopy()

registerAdapter(CopyableSlicer, ICopyable, tokens.ISlicer)
60
61
class Copyable2(slicer.BaseSlicer):
    """A copy-by-value object that acts as its own Slicer.

    Bundling the Slicer methods into the base class is heavier than usual,
    but it is the fallback when an Adapter cannot be registered for a whole
    class hierarchy. The object is sent as an 'instance' sequence: the
    getTypeToCopy() name (by default reflect.qual of the class) followed by
    the getStateToCopy() dictionary as a single item.
    """

    def getTypeToCopy(self):
        cls = self.__class__
        return reflect.qual(cls)

    def getStateToCopy(self):
        return self.__dict__

    def slice(self, streamable, banana):
        self.streamable = streamable
        yield b'instance'
        typename = self.getTypeToCopy()
        yield six.ensure_binary(typename)
        state = self.getStateToCopy()
        yield state

    def describe(self):
        return "<%s>" % self.getTypeToCopy()
77
78#registerRemoteCopy(typename, factory)
79#registerUnslicer(typename, factory)
80
def registerCopier(klass, copier):
    """This is a shortcut for arranging to serialize third-party classes.
    'copier' must be a callable which accepts an instance of the class you
    want to serialize, and returns a tuple of (typename, state_dictionary).
    If it returns a typename of None, the original class's fully-qualified
    classname is used.

    This registers an adapter from 'klass' to ICopyable. 'copier' is called
    each time an instance is adapted (once per adapter construction).
    """
    # Fail at registration time rather than when the first instance is
    # sent, matching the callable() check in registerRemoteCopyFactory.
    assert callable(copier)
    klassname = reflect.qual(klass)

    @implementer(ICopyable)
    class _CopierAdapter:
        def __init__(self, original):
            # copier returns (typename-or-None, state dict)
            self.nameToCopy, self.state = copier(original)
            if self.nameToCopy is None:
                self.nameToCopy = klassname
        def getTypeToCopy(self):
            return self.nameToCopy
        def getStateToCopy(self):
            return self.state
    registerAdapter(_CopierAdapter, klass, ICopyable)
100
101############################################################
102# beyond here is the receiving/deserialization side
103
class RemoteCopyUnslicer(slicer.BaseUnslicer):
    """Rebuild a copy-by-value object from a stream of attributes.

    Children arrive as alternating attribute-name / attribute-value items.
    Names are checked against the optional stateSchema (expected to be an
    AttributeDictConstraint, per IRemoteCopy.stateSchema). Values are checked
    against the per-attribute constraint that the schema returns. When the
    sequence closes, the collected dict is passed to 'factory' to build the
    final object.

    start() registers a Deferred with the protocol as a placeholder, so that
    the object can take part in reference cycles. receiveClose() replaces the
    placeholder with the finished object and fires the Deferred with it.
    """
    # name of the attribute whose value is pending, or None when the next
    # child is expected to be an attribute name
    attrname = None
    # constraint for the pending attribute's value (None = unconstrained)
    attrConstraint = None
    # True when the schema asked us to drop the pending attribute's value
    ignoringAttr = False
    # read by describe(); nothing in this class assigns it, so describe()
    # reports "<??>" unless a subclass or caller sets it
    classname = None

    def __init__(self, factory, stateSchema):
        self.factory = factory
        self.schema = stateSchema

    def start(self, count):
        self.d = {}
        self.count = count
        self.deferred = defer.Deferred()
        # placeholder so references to this object can be resolved later
        self.protocol.setObject(count, self.deferred)

    def checkToken(self, typebyte, size):
        if self.attrname is None:
            # attribute names must be plain strings
            if typebyte not in (tokens.STRING, tokens.VOCAB):
                raise BananaError("RemoteCopyUnslicer keys must be STRINGs")
        else:
            if self.attrConstraint:
                self.attrConstraint.checkToken(typebyte, size)

    def doOpen(self, opentype):
        # nested sequences inherit the pending attribute's constraint
        if self.attrConstraint:
            self.attrConstraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer:
            if self.attrConstraint:
                unslicer.setConstraint(self.attrConstraint)
        return unslicer

    def receiveChild(self, obj, ready_deferred=None):
        assert not isinstance(obj, defer.Deferred)
        assert ready_deferred is None
        if self.attrname is None:
            attrname = six.ensure_str(obj)
            if attrname in self.d:
                raise BananaError("duplicate attribute name '%s'" % attrname)
            accept = True
            s = self.schema
            if s:
                # AttributeDictConstraint.getAttrConstraint returns
                # (False, None) for unknown names when ignoreUnknown is set:
                # the value must then be received and discarded, not
                # treated as an error.
                accept, self.attrConstraint = s.getAttrConstraint(attrname)
            self.ignoringAttr = not accept
            self.attrname = attrname
        else:
            if isinstance(obj, defer.Deferred):
                # TODO: this is an artificial restriction, and it might
                # be possible to remove it, but I need to think through
                # it carefully first
                raise BananaError("unreferenceable object in attribute")
            if not self.ignoringAttr:
                self.setAttribute(self.attrname, obj)
            self.attrname = None
            self.attrConstraint = None
            self.ignoringAttr = False

    def setAttribute(self, name, value):
        self.d[name] = value

    def receiveClose(self):
        try:
            obj = self.factory(self.d)
        except Exception:
            log.msg("%s.receiveClose: problem in factory %s" %
                    (self.__class__.__name__, self.factory))
            log.err()
            raise
        # replace the placeholder registered in start() with the real object
        self.protocol.setObject(self.count, obj)
        self.deferred.callback(obj)
        return obj, None

    def describe(self):
        if self.classname is None:
            return "<??>"
        me = "<%s>" % self.classname
        if self.attrname is None:
            return "%s.attrname??" % me
        else:
            return "%s.%s" % (me, self.attrname)
180
181
class NonCyclicRemoteCopyUnslicer(RemoteCopyUnslicer):
    """A RemoteCopyUnslicer that does not use a placeholder Deferred.

    The parent class registers a Deferred with the protocol so that a
    RemoteCopy can take part in a reference cycle (say 'obj.foo = obj').
    That Deferred cannot carry a Failure, because Failures cannot be passed
    through Deferred.callback, so this variant is used for Failures instead.
    In exchange it cannot handle reference cycles: following one causes a
    KeyError.
    """

    def start(self, count):
        # unlike the parent, nothing is registered with the protocol here
        self.count = count
        self.d = {}
        self.gettingAttrname = True

    def receiveClose(self):
        # build the object directly; there is no placeholder to fire
        return self.factory(self.d), None
198
199
class IRemoteCopy(Interface):
    """This interface defines what a RemoteCopy class must do. RemoteCopy
    subclasses are used as factories to create objects that correspond to
    Copyables sent over the wire.

    Note that the constructor of an IRemoteCopy class will be called without
    any arguments.
    """

    def setCopyableState(statedict):
        """I accept a dictionary of attribute name/value pairs and use it to
        set my internal state.

        Some of the values may be Deferreds, which are placeholders for the
        as-yet-unreferenceable object which will eventually go there. If you
        receive a Deferred, you are responsible for adding a callback to
        update the attribute when it fires. [note:
        RemoteCopyUnslicer.receiveChild currently has a restriction which
        prevents this from happening, but that may go away in the future]

        Some of the objects referenced by the attribute values may have
        Deferreds in them (e.g. containers which reference recursive tuples).
        Such containers are responsible for updating their own state when
        those Deferreds fire, but until that point their state is still
        subject to change. Therefore you must be careful about how much state
        inspection you perform within this method."""

    # registerRemoteCopy reads this from the class (not an instance) and
    # passes it on as the stateSchema; None means no restrictions.
    stateSchema = interface.Attribute("""I return an AttributeDictConstraint
    object which places restrictions on incoming attribute values. These
    restrictions are enforced as the tokens are received, before the state is
    passed to setCopyableState.""")

    # registerRemoteCopy also requires this class attribute: it registers the
    # class with cyclic=(not nonCyclic).
    nonCyclic = interface.Attribute("""A boolean. If true, instances are
    built without the placeholder Deferred that normally lets them take part
    in reference cycles. This is required for Failure subclasses, which
    cannot be passed through Deferred.callback. Such instances cannot be
    part of a reference cycle.""")
231
232
# This maps typename to an Unslicer factory
CopyableRegistry = {}
def registerRemoteCopyUnslicerFactory(typename, unslicerfactory,
                                      registry=None):
    """Tell PB that unslicerfactory can be used to handle Copyable objects
    that provide a getTypeToCopy name of 'typename'. 'unslicerfactory' must
    be a callable which takes no arguments and returns an object which
    provides IUnslicer.

    'registry' is the dict to register into. It defaults to the module-level
    CopyableRegistry. Registering the same typename twice in one registry is
    rejected (via assert).
    """
    assert callable(unslicerfactory)
    # in addition, it must produce a tokens.IUnslicer . This is safe to do
    # because Unslicers don't do anything significant when they are created.
    test_unslicer = unslicerfactory()
    assert tokens.IUnslicer.providedBy(test_unslicer)
    assert type(typename) is str

    if registry is None:
        registry = CopyableRegistry
    assert typename not in registry
    registry[typename] = unslicerfactory
253
# this keeps track of everything submitted to registerRemoteCopyFactory
debug_CopyableFactories = {}
def registerRemoteCopyFactory(typename, factory, stateSchema=None,
                              cyclic=True, registry=None):
    """Register 'factory' to build objects for Copyables named 'typename'.

    'factory' is called with the received state dictionary and must return
    a fully-formed instance. 'stateSchema', if given, constrains the
    incoming attributes.

    'cyclic' is a boolean, which should be set to False to avoid using a
    Deferred to provide the resulting RemoteCopy instance. This is needed to
    deserialize Failures (or instances which inherit from one, like
    CopiedFailure). In exchange for this, it cannot handle reference cycles.
    """
    assert callable(factory)
    debug_CopyableFactories[typename] = (factory, stateSchema, cyclic)

    if not cyclic:
        def _RemoteCopyUnslicerFactoryNonCyclic():
            return NonCyclicRemoteCopyUnslicer(factory, stateSchema)
        unslicer_factory = _RemoteCopyUnslicerFactoryNonCyclic
    else:
        def _RemoteCopyUnslicerFactory():
            return RemoteCopyUnslicer(factory, stateSchema)
        unslicer_factory = _RemoteCopyUnslicerFactory

    registerRemoteCopyUnslicerFactory(typename, unslicer_factory, registry)
281
# this keeps track of everything submitted to registerRemoteCopy, which may
# be useful when you're wondering what's been auto-registered by the
# RemoteCopy metaclass magic
debug_RemoteCopyClasses = {}
def registerRemoteCopy(typename, remote_copy_class, registry=None):
    """Register 'remote_copy_class' as the receiver for 'typename'.

    Incoming Copyable sequences tagged with 'typename' are rebuilt by
    calling remote_copy_class() with no arguments and then passing the
    received state to its setCopyableState(state) method. The class must
    implement IRemoteCopy (a RemoteCopy subclass does), and must provide
    'stateSchema' and 'nonCyclic' class attributes.
    """
    assert IRemoteCopy.implementedBy(remote_copy_class)
    assert type(typename) is str

    debug_RemoteCopyClasses[typename] = remote_copy_class

    def _RemoteCopyFactory(state):
        instance = remote_copy_class()
        instance.setCopyableState(state)
        return instance

    schema = remote_copy_class.stateSchema
    cyclic = not remote_copy_class.nonCyclic
    registerRemoteCopyFactory(typename, _RemoteCopyFactory,
                              schema, cyclic, registry)
308
class RemoteCopyClass(type):
    """Metaclass that auto-registers RemoteCopy subclasses.

    Every class created with this metaclass, other than RemoteCopy itself,
    must define 'copytype' directly in its own class body. A truthy copytype
    registers the class through registerRemoteCopy, using the class body's
    optional 'copyableRegistry' as the registry. A falsy one (e.g. None)
    skips registration.
    """
    def __init__(self, name, bases, dict):
        type.__init__(self, name, bases, dict)
        namespace = dict
        # the RemoteCopy base class itself is never registered
        is_base_class = name == "RemoteCopy" and _RemoteCopyBase in bases
        if is_base_class:
            return
        if "copytype" not in namespace:
            # TODO: provide a file/line-number for the class
            raise RuntimeError("RemoteCopy subclass %s must specify 'copytype'"
                               % name)
        copytype = namespace['copytype']
        if not copytype:
            return
        registerRemoteCopy(copytype, self, namespace.get('copyableRegistry'))
325
@implementer(IRemoteCopy)
class _RemoteCopyBase:
    """Shared implementation behind RemoteCopy and RemoteCopyOldStyle.

    Instances are created with no constructor arguments. Their attribute
    dictionary is then replaced wholesale by setCopyableState().
    """
    # class-level defaults read by registerRemoteCopy
    nonCyclic = False
    stateSchema = None  # always a class attribute

    def __init__(self):
        """Take no arguments: the receiving side always constructs
        instances this way before calling setCopyableState()."""

    def setCopyableState(self, state):
        # adopt the received dict itself as the instance __dict__
        self.__dict__ = state
337
class RemoteCopyOldStyle(_RemoteCopyBase):
    """RemoteCopy base class that is never auto-registered.

    Neither this class nor _RemoteCopyBase uses the RemoteCopyClass
    metaclass. Subclasses are therefore not registered automatically,
    whatever they set 'copytype' to, and this holds under Python 2 and 3
    alike. Register them explicitly with registerRemoteCopy(typename, cls).
    """
    # kept for parity with RemoteCopy; no registration code reads it for
    # this class hierarchy
    copytype = None
342
@six.add_metaclass(RemoteCopyClass)
class RemoteCopy(_RemoteCopyBase, object):
    """Base class for receive-side copies of Copyable objects.

    Subclasses are registered automatically by the RemoteCopyClass
    metaclass. Each subclass must define 'copytype' in its own class body,
    otherwise RemoteCopyClass raises RuntimeError. A subclass may also
    define 'copyableRegistry' to register somewhere other than the default
    CopyableRegistry.
    """
    # Set 'copytype' to a unique string that is shared between the
    # sender-side Copyable and the receiver-side RemoteCopy. This RemoteCopy
    # subclass will be auto-registered using the 'copytype' name. Set
    # copytype to None to disable auto-registration.
    pass
350
351
class AttributeDictConstraint(OpenerConstraint):
    """This is a constraint for dictionaries that are used for attributes.
    All keys are short strings, and each value has a separate constraint.
    It could be used to describe instance state, but could also be used
    to constrain arbitrary dictionaries with string keys.

    Some special constraints are legal here: Optional. An Optional key may
    be absent. When it is present, its value is checked against the wrapped
    constraint.

    Keyword arguments:
      attributes    -- dict of name -> constraint, merged with attrTuples
      ignoreUnknown -- tolerate unknown keys; receivers drop their values
      acceptUnknown -- tolerate unknown keys; receivers keep their values
    """
    opentypes = [("attrdict",)]
    name = "AttributeDictConstraint"

    def __init__(self, *attrTuples, **kwargs):
        self.ignoreUnknown = kwargs.get('ignoreUnknown', False)
        self.acceptUnknown = kwargs.get('acceptUnknown', False)
        self.keys = {}
        for name, constraint in (list(attrTuples) +
                                 list(kwargs.get('attributes', {}).items())):
            assert name not in self.keys
            self.keys[name] = IConstraint(constraint)

    def getAttrConstraint(self, attrname):
        """Return (accept, constraint) for an incoming attribute name.

        accept=False means the value should be discarded (ignoreUnknown).
        A constraint of None means the value is unconstrained. Raises
        Violation for unknown names unless ignoreUnknown or acceptUnknown
        is set.
        """
        c = self.keys.get(attrname)
        if c is not None:
            if isinstance(c, Optional):
                c = c.constraint
            return (True, c)
        # unknown attribute
        if self.ignoreUnknown:
            return (False, None)
        if self.acceptUnknown:
            return (True, None)
        raise Violation("unknown attribute '%s'" % attrname)

    def checkObject(self, obj, inbound):
        """Validate an existing dict against this schema.

        Raises Violation if obj is not a dict, has an unknown key while
        neither ignoreUnknown nor acceptUnknown is set, has a value that
        fails its constraint, or lacks a non-Optional key.
        """
        if type(obj) is not dict:
            raise Violation("'%s' (%s) is not a Dictionary" % (obj,
                                                               type(obj)))
        tolerate_unknown = self.ignoreUnknown or self.acceptUnknown
        for k, value in obj.items():
            try:
                constraint = self.keys[k]
            except KeyError:
                if not tolerate_unknown:
                    raise Violation("key '%s' not in schema" % (k,))
                # hmm. kind of a soft violation. allow it for now.
                continue
            # Optional only governs presence; check the wrapped constraint,
            # exactly as getAttrConstraint does for inbound tokens
            if isinstance(constraint, Optional):
                constraint = constraint.constraint
            constraint.checkObject(value, inbound)

        # preserve schema declaration order in the error message
        missing = [k for k, c in self.keys.items()
                   if k not in obj and not isinstance(c, Optional)]
        if missing:
            raise Violation("object is missing required keys: %s" %
                            ",".join(missing))
409
410