1# -*- test-case-name: foolscap.test.test_banana -*-
2
3from twisted.python import log
4from twisted.internet.defer import Deferred
5from foolscap.tokens import Violation, BananaError
6from foolscap.slicer import BaseSlicer, BaseUnslicer
7from foolscap.constraint import OpenerConstraint, Any, IConstraint
8from foolscap.util import AsyncAND
9
class DictSlicer(BaseSlicer):
    """Serialize a dict as an OPEN('dict') sequence of alternating keys and
    values: key1, value1, key2, value2, ...
    """
    opentype = ('dict',)
    trackReferences = True
    # NOTE(review): slices=None appears to leave this class unregistered for
    # any type (OrderedDictSlicer sets slices = dict) -- confirm against the
    # registration logic in foolscap.slicer.
    slices = None

    def sliceBody(self, streamable, banana):
        # Copy the (key, value) pairs before yielding anything, so that the
        # dict being changed while this generator is suspended does not break
        # the iteration.
        pairs = list(self.obj.items())
        for k, v in pairs:
            yield k
            yield v
18
class DictUnslicer(BaseUnslicer):
    """Rebuild a dict from an alternating stream of key and value children.

    Children arrive as key, value, key, value, ...; ``gettingKey`` records
    which half of the pair is expected next. Optional per-key and per-value
    constraints and a maximum entry count are taken from a DictConstraint
    via setConstraint().
    """
    opentype = ('dict',)

    # True when the next child is a key; False when it is the value for
    # self.key. Flipped after every child in receiveChild().
    gettingKey = True
    # Limits copied from a DictConstraint; None means "unconstrained".
    keyConstraint = None
    valueConstraint = None
    maxKeys = None

    def setConstraint(self, constraint):
        """Adopt the key/value/size limits of *constraint*.

        An ``Any`` constraint leaves the dict unconstrained; anything else
        is expected to be a DictConstraint.
        """
        if isinstance(constraint, Any):
            return
        assert isinstance(constraint, DictConstraint)
        self.keyConstraint = constraint.keyConstraint
        self.valueConstraint = constraint.valueConstraint
        self.maxKeys = constraint.maxKeys

    def start(self, count):
        """Create the empty result dict and register it with the protocol.

        NOTE(review): registering before any children arrive presumably lets
        references inside the dict's contents resolve to it -- confirm
        against protocol.setObject.
        """
        self.d = {}
        self.protocol.setObject(count, self.d)
        self.key = None
        self._ready_deferreds = []

    def _checkNotFull(self):
        """Raise Violation if the dict already holds maxKeys entries."""
        if self.maxKeys is not None and len(self.d) >= self.maxKeys:
            raise Violation("the dict is full")

    def _currentConstraint(self):
        """Return the constraint for the next child (key or value), or None."""
        if self.gettingKey:
            return self.keyConstraint
        return self.valueConstraint

    def checkToken(self, typebyte, size):
        """Reject an incoming token that would overflow the dict or violate
        the key/value constraint for the slot currently being filled."""
        self._checkNotFull()
        constraint = self._currentConstraint()
        if constraint:
            constraint.checkToken(typebyte, size)

    def doOpen(self, opentype):
        """Open a child unslicer for the current key/value slot.

        Checks the size limit and the slot's constraint before opening, then
        passes that constraint down to the new child unslicer.
        """
        self._checkNotFull()
        constraint = self._currentConstraint()
        if constraint:
            constraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer and constraint:
            unslicer.setConstraint(constraint)
        return unslicer

    def update(self, value, key):
        # This is run as a Deferred callback (see receiveValue), hence the
        # backwards argument order: the fired result comes first.
        self.d[key] = value

    def receiveChild(self, obj, ready_deferred=None):
        """Accept one finished child and route it as a key or a value."""
        if ready_deferred:
            # Collected so receiveClose() can wait for all of them.
            self._ready_deferreds.append(ready_deferred)
        if self.gettingKey:
            self.receiveKey(obj)
        else:
            self.receiveValue(obj)
        self.gettingKey = not self.gettingKey

    def receiveKey(self, key):
        """Validate *key* and remember it until its value arrives.

        Raises BananaError for incomplete (Deferred), unhashable, or
        duplicate keys.
        """
        # I don't think it is legal (in python) to use an incomplete object
        # as a dictionary key, because you must have all the contents to
        # hash it. Someone could fake up a token stream to hit this case,
        # however: OPEN(dict), OPEN(tuple), OPEN(reference), 0, CLOSE, CLOSE,
        # "value", CLOSE
        if isinstance(key, Deferred):
            raise BananaError("incomplete object as dictionary key")
        # Only the membership test can raise TypeError (unhashable key); keep
        # the duplicate-key raise outside the try so it is not misreported.
        try:
            duplicate = key in self.d
        except TypeError:
            # Wrap key in a 1-tuple: a bare tuple key would be taken as the
            # %-argument tuple and make the formatting itself fail.
            raise BananaError("unhashable key '%s'" % (key,))
        if duplicate:
            raise BananaError("duplicate key '%s'" % (key,))
        self.key = key

    def receiveValue(self, value):
        """Store *value* under the pending key.

        If *value* is still a Deferred, it is stored as a placeholder and
        replaced by update() once it fires; failures are logged.
        """
        if isinstance(value, Deferred):
            value.addCallback(self.update, self.key)
            value.addErrback(log.err)
        self.d[self.key] = value # placeholder

    def receiveClose(self):
        """Return ``(dict, ready_deferred)``.

        ready_deferred combines (via AsyncAND) every ready_deferred received
        from children, or is None if there were none.
        """
        ready_deferred = None
        if self._ready_deferreds:
            ready_deferred = AsyncAND(self._ready_deferreds)
        return self.d, ready_deferred

    def describe(self):
        """Describe the current position: "{}" or "{}[<pending key>]"."""
        if self.gettingKey:
            return "{}"
        # 1-tuple so that a tuple key is formatted instead of unpacked.
        return "{}[%s]" % (self.key,)
117
118
class OrderedDictSlicer(DictSlicer):
    """DictSlicer variant that emits entries sorted by key when possible.

    NOTE(review): sorting presumably makes the wire form independent of
    insertion order -- confirm intent.
    """
    slices = dict

    def sliceBody(self, streamable, banana):
        # Snapshot the entries up front (as DictSlicer does), so that the
        # dict being mutated while this generator is suspended cannot cause
        # a KeyError on a later lookup.
        items = list(self.obj.items())
        try:
            # Order by key only; sorted() leaves `items` intact if it fails.
            items = sorted(items, key=lambda item: item[0])
        except TypeError:
            # Python 3 cannot order mutually unorderable keys (e.g. 1 and
            # "a"). Rather than failing to serialize such a dict, fall back
            # to the dict's own iteration order.
            pass
        for key, value in items:
            yield key
            yield value
128
129
class DictConstraint(OpenerConstraint):
    """Constraint accepting dicts whose keys and values satisfy the given
    sub-constraints.

    keyConstraint and valueConstraint are adapted through IConstraint.
    maxKeys, when not None, is the largest number of entries allowed.
    """
    opentypes = [("dict",)]
    name = "DictConstraint"

    def __init__(self, keyConstraint, valueConstraint, maxKeys=None):
        self.keyConstraint = IConstraint(keyConstraint)
        self.valueConstraint = IConstraint(valueConstraint)
        self.maxKeys = maxKeys

    def checkObject(self, obj, inbound):
        """Raise Violation unless *obj* is a dict within the size limit whose
        every key and value passes the corresponding constraint."""
        if not isinstance(obj, dict):
            raise Violation("'%s' (%s) is not a Dictionary" % (obj, type(obj)))
        limit = self.maxKeys
        if limit is not None:
            count = len(obj)
            if count > limit:
                raise Violation("Dict keys=%d > maxKeys=%d" % (count, limit))
        for k, v in obj.items():
            self.keyConstraint.checkObject(k, inbound)
            self.valueConstraint.checkObject(v, inbound)
146