# -*- test-case-name: foolscap.test.test_banana -*-

from twisted.python import log
from twisted.internet.defer import Deferred
from foolscap.tokens import Violation, BananaError
from foolscap.slicer import BaseSlicer, BaseUnslicer
from foolscap.constraint import OpenerConstraint, Any, IConstraint
from foolscap.util import AsyncAND


class DictSlicer(BaseSlicer):
    """Slicer that emits a dict as an alternating key, value sequence."""

    opentype = ('dict',)
    trackReferences = True
    # OrderedDictSlicer (below) sets this to ``dict`` instead.
    slices = None

    def sliceBody(self, streamable, banana):
        """Yield each key of ``self.obj`` followed by its value.

        The items are copied into a list up front, so the iteration runs
        over a snapshot rather than over the live dict.
        """
        for key, value in list(self.obj.items()):
            yield key
            yield value


class DictUnslicer(BaseUnslicer):
    """Unslicer that rebuilds a dict from alternating key/value children.

    ``gettingKey`` tracks whether the next child is a key (True) or the
    value that belongs to the most recently received key (False).
    """

    opentype = ('dict',)

    gettingKey = True
    keyConstraint = None
    valueConstraint = None
    maxKeys = None

    def setConstraint(self, constraint):
        """Adopt the key/value/size limits of a DictConstraint.

        An ``Any`` constraint imposes no limits and is ignored.
        """
        if isinstance(constraint, Any):
            return
        assert isinstance(constraint, DictConstraint)
        self.keyConstraint = constraint.keyConstraint
        self.valueConstraint = constraint.valueConstraint
        self.maxKeys = constraint.maxKeys

    def start(self, count):
        """Create the empty result dict and register it under ``count``."""
        self.d = {}
        self.protocol.setObject(count, self.d)
        self.key = None
        self._ready_deferreds = []

    def _checkNotFull(self):
        """Raise Violation if the dict already holds ``maxKeys`` entries."""
        if self.maxKeys is not None and len(self.d) >= self.maxKeys:
            raise Violation("the dict is full")

    def _currentConstraint(self):
        """Return the constraint that applies to the next child."""
        if self.gettingKey:
            return self.keyConstraint
        return self.valueConstraint

    def checkToken(self, typebyte, size):
        """Check an incoming token against the size and key/value limits.

        Raises Violation when the dict is full; otherwise delegates to the
        key or value constraint, if one is set.
        """
        self._checkNotFull()
        constraint = self._currentConstraint()
        if constraint:
            constraint.checkToken(typebyte, size)

    def doOpen(self, opentype):
        """Check an incoming OPEN and return the child unslicer for it.

        The applicable key or value constraint (if any) first validates
        the opentype and is then handed to the child unslicer.
        """
        self._checkNotFull()
        constraint = self._currentConstraint()
        if constraint:
            constraint.checkOpentype(opentype)
        unslicer = self.open(opentype)
        if unslicer and constraint:
            unslicer.setConstraint(constraint)
        return unslicer

    def update(self, value, key):
        """Store ``value`` under ``key`` once a deferred value resolves."""
        # this is run as a Deferred callback, hence the backwards arguments
        self.d[key] = value

    def receiveChild(self, obj, ready_deferred=None):
        """Accept the next child, treating it as a key or value in turn."""
        if ready_deferred:
            self._ready_deferreds.append(ready_deferred)
        if self.gettingKey:
            self.receiveKey(obj)
        else:
            self.receiveValue(obj)
        self.gettingKey = not self.gettingKey

    def receiveKey(self, key):
        """Validate ``key`` and remember it for the value that follows.

        Raises BananaError if the key is an unresolved Deferred, is not
        hashable, or is already present in the dict.
        """
        # I don't think it is legal (in python) to use an incomplete object
        # as a dictionary key, because you must have all the contents to
        # hash it. Someone could fake up a token stream to hit this case,
        # however: OPEN(dict), OPEN(tuple), OPEN(reference), 0, CLOSE, CLOSE,
        # "value", CLOSE
        if isinstance(key, Deferred):
            raise BananaError("incomplete object as dictionary key")
        # Only the membership test belongs in the try: it is the one
        # operation expected to raise TypeError (for unhashable keys).
        try:
            duplicate = key in self.d
        except TypeError:
            # Wrap the key in a 1-tuple so that a tuple key is formatted as
            # a single value instead of being unpacked as the % arguments.
            raise BananaError("unhashable key '%s'" % (key,))
        if duplicate:
            raise BananaError("duplicate key '%s'" % (key,))
        self.key = key

    def receiveValue(self, value):
        """Store ``value`` under the pending key.

        A Deferred value is stored as a placeholder and replaced through
        ``update`` when it fires; failures are passed to ``log.err``.
        """
        if isinstance(value, Deferred):
            value.addCallback(self.update, self.key)
            value.addErrback(log.err)
        self.d[self.key] = value  # placeholder

    def receiveClose(self):
        """Return ``(dict, ready_deferred)``.

        ``ready_deferred`` is None unless children supplied ready
        deferreds, in which case they are combined with AsyncAND.
        """
        ready_deferred = None
        if self._ready_deferreds:
            ready_deferred = AsyncAND(self._ready_deferreds)
        return self.d, ready_deferred

    def describe(self):
        """Return a short label for the current position in the dict."""
        if self.gettingKey:
            return "{}"
        # 1-tuple wrapper: a tuple key must not be unpacked by %.
        return "{}[%s]" % (self.key,)


class OrderedDictSlicer(DictSlicer):
    """DictSlicer variant that emits the entries in sorted key order."""

    slices = dict

    def sliceBody(self, streamable, banana):
        """Yield each key followed by its value, in sorted key order."""
        for key in sorted(self.obj.keys()):
            yield key
            yield self.obj[key]


class DictConstraint(OpenerConstraint):
    """Constraint on a dict's key type, value type and number of keys."""

    opentypes = [("dict",)]
    name = "DictConstraint"

    def __init__(self, keyConstraint, valueConstraint, maxKeys=None):
        """Adapt both constraints through IConstraint and store maxKeys.

        ``maxKeys=None`` means the number of keys is not limited.
        """
        self.keyConstraint = IConstraint(keyConstraint)
        self.valueConstraint = IConstraint(valueConstraint)
        self.maxKeys = maxKeys

    def checkObject(self, obj, inbound):
        """Raise Violation unless ``obj`` is a dict within the limits.

        Every key and value is also checked against its own constraint.
        """
        if not isinstance(obj, dict):
            raise Violation("'%s' (%s) is not a Dictionary" % (obj, type(obj)))
        if self.maxKeys is not None and len(obj) > self.maxKeys:
            raise Violation("Dict keys=%d > maxKeys=%d" % (len(obj), self.maxKeys))
        for key, value in obj.items():
            self.keyConstraint.checkObject(key, inbound)
            self.valueConstraint.checkObject(value, inbound)