1# -*- test-case-name: foolscap.test.test_banana -*-
2
3from __future__ import print_function
4import six
5from zope.interface import implementer
6from twisted.internet.defer import Deferred
7from foolscap import tokens
8from foolscap.tokens import Violation, BananaError
9from foolscap.slicer import BaseUnslicer, ReferenceSlicer
10from foolscap.slicer import UnslicerRegistry, BananaUnslicerRegistry
11from foolscap.slicers.vocab import ReplaceVocabularyTable, AddToVocabularyTable
12from foolscap.util import ensure_tuple_str
13from foolscap import copyable # does this create a cycle?
14from twisted.python import log
15from functools import reduce
16
@implementer(tokens.ISlicer, tokens.IRootSlicer)
class RootSlicer:
    """Bottom-most Slicer on a Banana protocol's slicer stack.

    Objects passed to send() are queued and handed out one at a time, in the
    order they were queued, by iterating this slicer. When nothing is queued,
    iteration yields ``producingDeferred`` instead; send() fires that Deferred
    to wake the producer once new work arrives.
    """
    # default for self.streamable whenever a queued object is handed out;
    # changed via allowStreaming()
    streamableInGeneral = True
    # Deferred handed to the producer when the queue is empty; fired by send()
    producingDeferred = None
    # Deferred for the object currently being serialized; fired on the next
    # __next__() call, or errbacked by childAborted()/connectionLost()
    objectSentDeferred = None
    # maps type(obj) -> slicer factory, consulted after the adapter lookups.
    # This dict lives on the class; presumably subclasses replace it.
    slicerTable = {}
    debug = False

    def __init__(self, protocol):
        self.protocol = protocol
        # FIFO of (obj, Deferred) pairs waiting to be serialized
        self.sendQueue = []

    def allowStreaming(self, streamable):
        """Set the streamable flag applied to subsequently dequeued objects."""
        self.streamableInGeneral = streamable

    def registerRefID(self, refid, obj):
        """No-op: the plain RootSlicer does not track references."""
        pass

    def slicerForObject(self, obj):
        """Return an ISlicer for obj.

        Lookup order: a registered ISlicer adapter, then an ICopyable adapter
        (manually chained to ISlicer), then ``slicerTable`` keyed on the exact
        type of obj.

        :raises Violation: if none of the lookups produces a slicer.
        """
        # could use a table here if you think it'd be faster than an
        # adapter lookup
        if self.debug: log.msg("slicerForObject(%s)" % type(obj))

        # do the adapter lookup first, so that registered adapters override
        # UnsafeSlicerTable's InstanceSlicer. Compare against None rather
        # than relying on truthiness, so a falsy adapter result still counts.
        slicer = tokens.ISlicer(obj, None)
        if slicer is not None:
            if self.debug: log.msg("got ISlicer %s" % slicer)
            return slicer

        # zope.interface doesn't do transitive adaptation, which is a shame
        # because we want to let people register ICopyable adapters for
        # third-party code, and there is an ICopyable->ISlicer adapter
        # defined in copyable.py, but z.i won't do the transitive
        #  ThirdPartyClass -> ICopyable -> ISlicer
        # so instead we manually do it here. ICopyable(obj) may return obj
        # itself, which could be falsy (e.g. an empty container subclass),
        # hence the explicit None check.
        copier = copyable.ICopyable(obj, None)
        if copier is not None:
            return tokens.ISlicer(copier)

        slicerFactory = self.slicerTable.get(type(obj))
        if slicerFactory is not None:
            if self.debug: log.msg(" got slicerFactory %s" % slicerFactory)
            return slicerFactory(obj)
        name = str(type(obj))
        if self.debug: log.msg("cannot serialize %s (%s)" % (obj, name))
        raise Violation("cannot serialize %s (%s)" % (obj, name))

    sliceAlreadyCalled = False
    def slice(self):
        """Return an iterator over this slicer; may only be called once."""
        assert not self.sliceAlreadyCalled
        self.sliceAlreadyCalled = True
        return iter(self)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next queued object, or a Deferred if the queue is empty.

        Being asked for the next item means the previous object has been
        fully serialized, so its objectSentDeferred is fired first.
        """
        if self.objectSentDeferred:
            self.objectSentDeferred.callback(None)
            self.objectSentDeferred = None
        if self.sendQueue:
            # take from the front: send() appends to the back, so objects go
            # out in the order they were sent (FIFO)
            (obj, self.objectSentDeferred) = self.sendQueue.pop(0)
            self.streamable = self.streamableInGeneral
            return obj
        if self.protocol.debugSend:
            print("LAST BAG")
        # nothing to send: park the producer until send() fires this
        self.producingDeferred = Deferred()
        self.streamable = True
        return self.producingDeferred
    next = __next__  # Python 2 iterator protocol

    def childAborted(self, f):
        """Errback the in-flight object's Deferred with failure f."""
        assert self.objectSentDeferred
        self.objectSentDeferred.errback(f)
        self.objectSentDeferred = None
        return None

    def send(self, obj):
        """Queue obj for serialization.

        obj can also be a Slicer, say, a CallSlicer.

        :returns: a Deferred which fires when the object has been fully
            serialized.
        """
        # idle: only this root slicer is on the stack and nothing is waiting
        idle = (len(self.protocol.slicerStack) == 1) and not self.sendQueue
        objectSentDeferred = Deferred()
        self.sendQueue.append((obj, objectSentDeferred))
        if idle:
            # wake up
            if self.protocol.debugSend:
                print(" waking up to send")
            if self.producingDeferred:
                d = self.producingDeferred
                self.producingDeferred = None
                # TODO: consider reactor.callLater(0, d.callback, None)
                # I'm not sure it's actually necessary, though
                d.callback(None)
        return objectSentDeferred

    def describe(self):
        return "<RootSlicer>"

    def connectionLost(self, why):
        """Abandon everything we wanted to send, errbacking each Deferred."""
        if self.objectSentDeferred:
            self.objectSentDeferred.errback(why)
            self.objectSentDeferred = None
        for obj, d in self.sendQueue:
            d.errback(why)
        self.sendQueue = []
127
class ScopedRootSlicer(RootSlicer):
    """RootSlicer that sends repeated objects as back-references.

    This folds foolscap.slicer.ScopedSlicer's behaviour into RootSlicer.
    Because slicerForObject() delegates to itself, the two classes cannot
    simply be combined through multiple inheritance. It would be nice to
    refactor everything to make this cleaner.
    """

    def __init__(self, obj):
        # NOTE(review): despite its name, ``obj`` is forwarded to
        # RootSlicer.__init__ as the protocol.
        RootSlicer.__init__(self, obj)
        # id(obj) -> (obj, refid). Storing obj itself keeps it alive, so its
        # id() cannot be recycled while the entry exists.
        self.references = {}

    def registerRefID(self, refid, obj):
        """Remember that obj has been (or is being) sent under refid."""
        self.references[id(obj)] = (obj, refid)

    def slicerForObject(self, obj):
        """Return a ReferenceSlicer if obj was registered, else a full slicer."""
        entry = self.references.get(id(obj))
        if entry is None:
            # first time we see this object: serialize it completely
            return RootSlicer.slicerForObject(self, obj)
        # sending of this object has at least started: emit a reference
        _, refid = entry
        return ReferenceSlicer(refid)
150
151
152
class RootUnslicer(BaseUnslicer):
    """Bottom-most Unslicer on a Banana protocol's receive stack.

    Chooses child unslicers for OPEN sequences, optionally applies a
    top-level constraint, and hands finished top-level objects to the
    protocol via receivedObject().
    """
    # topRegistries is used for top-level objects
    topRegistries = [UnslicerRegistry, BananaUnslicerRegistry]
    # openRegistries is used for everything at lower levels
    openRegistries = [UnslicerRegistry]
    constraint = None
    openCount = None

    def __init__(self, protocol):
        self.protocol = protocol
        self.objects = {}
        # Length of the longest first element of any registered opentype.
        # openerCheckToken() rejects STRING index tokens longer than this.
        # Default to 0 rather than crashing if every registry is empty.
        lengths = [len(opentype[0])
                   for registry in self.topRegistries + self.openRegistries
                   for opentype in registry.keys()]
        self.maxIndexLength = max(lengths) if lengths else 0

    def start(self, count):
        pass

    def setConstraint(self, constraint):
        """Constrain top-level objects.

        E.g. if this is an IntegerConstraint, then only integers will be
        accepted.
        """
        self.constraint = constraint

    def checkToken(self, typebyte, size):
        """Apply the top-level constraint (if any) to an incoming token."""
        if self.constraint:
            self.constraint.checkToken(typebyte, size)

    def openerCheckToken(self, typebyte, size, opentype):
        """Validate a token that forms part of an OPEN sequence's index.

        Only VOCAB tokens and STRING tokens of at most maxIndexLength bytes
        are accepted.

        :raises Violation: for an over-long STRING or any other token type.
        """
        if typebyte == tokens.VOCAB:
            return
        if typebyte != tokens.STRING:
            # TODO: hack for testing -- this arguably ought to be a
            # BananaError (protocol error) rather than a Violation.
            raise Violation("index token 0x%02x not STRING or VOCAB" %
                            six.byte2int(typebyte))
        if size > self.maxIndexLength:
            raise Violation("STRING token is too long, %d>%d" %
                            (size, self.maxIndexLength))

    def open(self, opentype):
        """Return a child unslicer for a non-top-level OPEN.

        Called (by delegation) by the top Unslicer on the stack, regardless
        of what kind of unslicer it is. This is only used for "internal"
        objects: non-top-level nodes.

        :returns: the new child unslicer, or None while a ('copyable',)
            opentype is still waiting for its class name.
        :raises Violation: for an unknown copyable name or OPEN type.
        """
        assert len(self.protocol.receiveStack) > 1
        opentype = ensure_tuple_str(opentype)

        if opentype[0] == 'copyable':
            if len(opentype) > 1:
                copyablename = opentype[1]
                try:
                    factory = copyable.CopyableRegistry[copyablename]
                except KeyError:
                    raise Violation("unknown RemoteCopy name '%s'" \
                                    % copyablename)
                child = factory()
                return child
            return None # still waiting for copyablename

        for reg in self.openRegistries:
            opener = reg.get(opentype)
            if opener is not None:
                child = opener()
                return child

        raise Violation("unknown OPEN type %s" % (opentype,))

    def doOpen(self, opentype):
        """Return a child unslicer for a top-level OPEN.

        The top-level constraint (if any) is checked against opentype and
        then passed on to the new child.

        :raises Violation: if no top-level registry knows opentype.
        """
        assert len(self.protocol.receiveStack) == 1
        opentype = ensure_tuple_str(opentype)
        if self.constraint:
            self.constraint.checkOpentype(opentype)
        for reg in self.topRegistries:
            opener = reg.get(opentype)
            if opener is not None:
                child = opener()
                break
        else:
            raise Violation("unknown top-level OPEN type %s" % (opentype,))

        if self.constraint:
            child.setConstraint(self.constraint)
        return child

    def receiveChild(self, obj, ready_deferred=None):
        """Deliver a finished top-level object to the protocol.

        Vocabulary-table markers are swallowed (the unslicer has already
        applied them). If the protocol has exploded, the exploded value is
        delivered instead of obj.
        """
        assert not isinstance(obj, Deferred)
        assert ready_deferred is None
        if self.protocol.debugReceive:
            print("RootUnslicer.receiveChild(%s)" % (obj,))
        self.objects = {}
        # Identity comparison: `in` would fall back to ==, invoking the
        # received object's own __eq__.
        if obj is ReplaceVocabularyTable or obj is AddToVocabularyTable:
            # the unslicer has already changed the vocab table
            return
        if self.protocol.exploded:
            print("protocol exploded, can't deliver object")
            print(self.protocol.exploded)
            self.protocol.receivedObject(self.protocol.exploded)
            return
        self.protocol.receivedObject(obj) # give finished object to Banana

    def receiveClose(self):
        raise BananaError("top-level should never receive CLOSE tokens")

    def reportViolation(self, why):
        return self.protocol.reportViolation(why)

    def describe(self):
        return "<RootUnslicer>"

    def setObject(self, counter, obj):
        """No-op: the plain RootUnslicer does not track references."""
        pass

    def getObject(self, counter):
        """Always None: the plain RootUnslicer does not track references."""
        return None
272
class ScopedRootUnslicer(RootUnslicer):
    """RootUnslicer that also records objects by counter.

    Combines RootUnslicer with ScopedUnslicer: objects stored with
    setObject() can later be retrieved with getObject().
    """

    def __init__(self, protocol):
        RootUnslicer.__init__(self, protocol)
        # counter -> previously received object
        self.references = {}

    def setObject(self, counter, obj):
        """Record obj under the given counter."""
        self.references[counter] = obj

    def getObject(self, counter):
        """Return the object recorded under counter, or None if unknown."""
        return self.references.get(counter)
286