from ._compat import Iterable
import six

from pyrsistent._compat import Enum, string_types
from pyrsistent._pmap import PMap, pmap
from pyrsistent._pset import PSet, pset
from pyrsistent._pvector import PythonPVector, python_pvector


class CheckedType(object):
    """
    Marker class to enable creation and serialization of checked object graphs.
    """
    __slots__ = ()

    @classmethod
    def create(cls, source_data, _factory_fields=None, ignore_extra=False):
        # ``ignore_extra`` is part of the signature because the shared create
        # helper for checked collections always forwards it to nested types.
        raise NotImplementedError()

    def serialize(self, format=None):
        raise NotImplementedError()


def _restore_pickle(cls, data):
    """Rebuild an instance of ``cls`` from pickled ``data`` through ``cls.create``."""
    return cls.create(data, _factory_fields=set())


class InvariantException(Exception):
    """
    Exception raised from a :py:class:`CheckedType` when invariant tests fail or when a mandatory
    field is missing.

    Contains two fields of interest:
    invariant_errors, a tuple of error data for the failing invariants
    missing_fields, a tuple of strings specifying the missing names
    """

    def __init__(self, error_codes=(), missing_fields=(), *args, **kwargs):
        # Error codes may be given as callables; those are called here to
        # obtain the actual error data.
        self.invariant_errors = tuple(e() if callable(e) else e for e in error_codes)
        self.missing_fields = missing_fields
        super(InvariantException, self).__init__(*args, **kwargs)

    def __str__(self):
        return super(InvariantException, self).__str__() + \
            ", invariant_errors=[{invariant_errors}], missing_fields=[{missing_fields}]".format(
            invariant_errors=', '.join(str(e) for e in self.invariant_errors),
            missing_fields=', '.join(self.missing_fields))


_preserved_iterable_types = (
    Enum,
)
"""Some types are themselves iterable, but we want to use the type itself and
not its members for the type specification. This defines a set of such types
that we explicitly preserve.

Note that strings are not such types because the string inputs we pass in are
values, not types.
"""


def maybe_parse_user_type(t):
    """Try to coerce a user-supplied type directive into a list of types.

    This function should be used in all places where a user specifies a type,
    for consistency.

    The policy for what defines valid user input should be clear from the implementation.

    Returns a one-element list for a single type or string, and a flat tuple
    for an iterable of specifications. Raises TypeError for anything else.
    """
    is_type = isinstance(t, type)
    is_preserved = is_type and issubclass(t, _preserved_iterable_types)
    is_string = isinstance(t, string_types)
    is_iterable = isinstance(t, Iterable)

    # A single specification: a preserved iterable type, a type name given as
    # a string, or a type that is not itself iterable.
    if is_preserved or is_string or (is_type and not is_iterable):
        return [t]

    if is_iterable:
        # Recur to validate contained types as well.
        return tuple(parsed for spec in t for parsed in maybe_parse_user_type(spec))

    # If this raises because `t` cannot be formatted, so be it.
    raise TypeError(
        'Type specifications must be types or strings. Input: {}'.format(t)
    )


def maybe_parse_many_user_types(ts):
    # Just a different name to communicate that you're parsing multiple user
    # inputs. `maybe_parse_user_type` handles the iterable case anyway.
    return maybe_parse_user_type(ts)


def _store_types(dct, bases, destination_name, source_name):
    """
    Collect the ``source_name`` entries found in ``dct`` and in the ``__dict__``
    of each direct base, parse them and store the result in ``dct`` under
    ``destination_name``.
    """
    maybe_types = maybe_parse_many_user_types([
        d[source_name]
        for d in ([dct] + [b.__dict__ for b in bases]) if source_name in d
    ])

    dct[destination_name] = maybe_types


def _merge_invariant_results(result):
    """
    Merge an iterable of ``(verdict, data)`` pairs into a single pair.

    The merged verdict is False if any verdict is falsy; the merged data is a
    tuple holding the data of the failing pairs only.
    """
    verdict = True
    data = []
    for verd, dat in result:
        if not verd:
            verdict = False
            data.append(dat)

    return verdict, tuple(data)


def wrap_invariant(invariant):
    # Invariant functions may return the outcome of several tests
    # In those cases the results have to be merged before being passed
    # back to the client.
    def f(*args, **kwargs):
        result = invariant(*args, **kwargs)
        # A bool in first position means a single (verdict, data) pair.
        if isinstance(result[0], bool):
            return result

        return _merge_invariant_results(result)

    return f


def _all_dicts(bases, seen=None):
    """
    Yield the ``__dict__`` of each class in ``bases`` and of each of their base
    classes. Every class is visited at most once.
    """
    if seen is None:
        seen = set()
    for cls in bases:
        if cls in seen:
            continue
        seen.add(cls)
        yield cls.__dict__
        for b in _all_dicts(cls.__bases__, seen):
            yield b


def store_invariants(dct, bases, destination_name, source_name):
    """
    Gather the ``source_name`` invariants from ``dct`` and from all (transitive)
    bases, wrap them and store them as a tuple in ``dct`` under
    ``destination_name``.

    Raises TypeError if any collected invariant is not callable.
    """
    # Invariants are inherited
    invariants = []
    for ns in [dct] + list(_all_dicts(bases)):
        try:
            invariant = ns[source_name]
        except KeyError:
            continue
        invariants.append(invariant)

    if not all(callable(invariant) for invariant in invariants):
        raise TypeError('Invariants must be callable')
    dct[destination_name] = tuple(wrap_invariant(inv) for inv in invariants)


class _CheckedTypeMeta(type):
    """
    Metaclass that records ``__type__`` and ``__invariant__`` declarations as
    ``_checked_types`` and ``_checked_invariants`` and installs a default
    ``__serializer__`` when the class body does not define one.
    """

    def __new__(mcs, name, bases, dct):
        _store_types(dct, bases, '_checked_types', '__type__')
        store_invariants(dct, bases, '_checked_invariants', '__invariant__')

        def default_serializer(self, _, value):
            if isinstance(value, CheckedType):
                return value.serialize()
            return value

        dct.setdefault('__serializer__', default_serializer)

        dct['__slots__'] = ()

        return super(_CheckedTypeMeta, mcs).__new__(mcs, name, bases, dct)


class CheckedTypeError(TypeError):
    """
    Base class for the type errors raised by checked collections.

    Attributes:
    source_class -- The class of the collection
    expected_types  -- Allowed types
    actual_type -- The non matching type
    actual_value -- Value of the variable with the non matching type
    """

    def __init__(self, source_class, expected_types, actual_type, actual_value, *args, **kwargs):
        super(CheckedTypeError, self).__init__(*args, **kwargs)
        self.source_class = source_class
        self.expected_types = expected_types
        self.actual_type = actual_type
        self.actual_value = actual_value


class CheckedKeyTypeError(CheckedTypeError):
    """
    Raised when trying to set a value using a key with a type that doesn't match the declared type.

    Attributes:
    source_class -- The class of the collection
    expected_types  -- Allowed types
    actual_type -- The non matching type
    actual_value -- Value of the variable with the non matching type
    """
    pass


class CheckedValueTypeError(CheckedTypeError):
    """
    Raised when trying to set a value with a type that doesn't match the declared type.

    Attributes:
    source_class -- The class of the collection
    expected_types  -- Allowed types
    actual_type -- The non matching type
    actual_value -- Value of the variable with the non matching type
    """
    pass


def _get_class(type_name):
    """Import and return the class named by the dotted path ``type_name``."""
    module_name, class_name = type_name.rsplit('.', 1)
    module = __import__(module_name, fromlist=[class_name])
    return getattr(module, class_name)


def get_type(typ):
    """Return ``typ`` if it already is a type, otherwise resolve it as a dotted class path."""
    if isinstance(typ, type):
        return typ

    return _get_class(typ)


def get_types(typs):
    """Resolve every entry of ``typs`` with :func:`get_type` and return them as a list."""
    return [get_type(typ) for typ in typs]


def _check_types(it, expected_types, source_class, exception_type=CheckedValueTypeError):
    """
    Raise ``exception_type`` for the first element of ``it`` that is not an
    instance of any of ``expected_types``. Does nothing when ``expected_types``
    is empty.
    """
    if expected_types:
        for e in it:
            # Types are resolved lazily, inside the short-circuiting any(), so
            # string-declared types are only imported when actually needed.
            if not any(isinstance(e, get_type(t)) for t in expected_types):
                actual_type = type(e)
                msg = "Type {source_class} can only be used with {expected_types}, not {actual_type}".format(
                    source_class=source_class.__name__,
                    expected_types=tuple(get_type(et).__name__ for et in expected_types),
                    actual_type=actual_type.__name__)
                raise exception_type(source_class, expected_types, actual_type, e, msg)


def _invariant_errors(elem, invariants):
    """Return the error data of every invariant in ``invariants`` that ``elem`` fails."""
    return [data for valid, data in (invariant(elem) for invariant in invariants) if not valid]


def _invariant_errors_iterable(it, invariants):
    """Return, as one flat list, the invariant error data for every element of ``it``."""
    # Flattened in a single pass; repeatedly concatenating lists with sum()
    # would copy the accumulated result once per element.
    return [data for elem in it for data in _invariant_errors(elem, invariants)]


def optional(*typs):
    """ Convenience function to specify that a value may be of any of the types in type 'typs' or None """
    return tuple(typs) + (type(None),)


def _checked_type_create(cls, source_data, _factory_fields=None, ignore_extra=False):
    """
    Shared ``create`` implementation for checked vectors and sets.

    Returns ``source_data`` unchanged when it already is an instance of ``cls``.
    Otherwise builds a ``cls`` from ``source_data``; elements that match none of
    the declared types are first passed to the ``create`` method of the first
    declared :py:class:`CheckedType`, if there is one.
    """
    if isinstance(source_data, cls):
        return source_data

    # Recursively apply create methods of checked types if the types of the supplied data
    # does not match any of the valid types.
    types = get_types(cls._checked_types)
    checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
    if checked_type:
        return cls([checked_type.create(data, ignore_extra=ignore_extra)
                    if not any(isinstance(data, t) for t in types) else data
                    for data in source_data])

    return cls(source_data)


@six.add_metaclass(_CheckedTypeMeta)
class CheckedPVector(PythonPVector, CheckedType):
    """
    A CheckedPVector is a PVector which allows specifying type and invariant checks.

    >>> class Positives(CheckedPVector):
    ...     __type__ = (long, int)
    ...     __invariant__ = lambda n: (n >= 0, 'Negative')
    ...
    >>> Positives([1, 2, 3])
    Positives([1, 2, 3])
    """

    __slots__ = ()

    def __new__(cls, initial=()):
        # A plain PythonPVector is adopted as-is (its internals are reused
        # without running any checks); anything else goes through the
        # checking evolver.
        if type(initial) == PythonPVector:
            return super(CheckedPVector, cls).__new__(cls, initial._count, initial._shift, initial._root, initial._tail)

        return CheckedPVector.Evolver(cls, python_pvector()).extend(initial).persistent()

    def set(self, key, value):
        return self.evolver().set(key, value).persistent()

    def append(self, val):
        return self.evolver().append(val).persistent()

    def extend(self, it):
        return self.evolver().extend(it).persistent()

    create = classmethod(_checked_type_create)

    def serialize(self, format=None):
        serializer = self.__serializer__
        return list(serializer(format, v) for v in self)

    def __reduce__(self):
        # Pickling support
        return _restore_pickle, (self.__class__, list(self),)

    class Evolver(PythonPVector.Evolver):
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, vector):
            super(CheckedPVector.Evolver, self).__init__(vector)
            self._destination_class = destination_class
            self._invariant_errors = []

        def _check(self, it):
            # Type errors raise immediately; invariant failures are collected
            # and reported when persistent() is called.
            _check_types(it, self._destination_class._checked_types, self._destination_class)
            error_data = _invariant_errors_iterable(it, self._destination_class._checked_invariants)
            self._invariant_errors.extend(error_data)

        def __setitem__(self, key, value):
            self._check([value])
            return super(CheckedPVector.Evolver, self).__setitem__(key, value)

        def append(self, elem):
            self._check([elem])
            return super(CheckedPVector.Evolver, self).append(elem)

        def extend(self, it):
            # Materialize first: the elements are iterated once for checking
            # and once more when they are added.
            it = list(it)
            self._check(it)
            return super(CheckedPVector.Evolver, self).extend(it)

        def persistent(self):
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            result = self._orig_pvector
            if self.is_dirty() or (self._destination_class != type(self._orig_pvector)):
                pv = super(CheckedPVector.Evolver, self).persistent().extend(self._extra_tail)
                result = self._destination_class(pv)
                self._reset(result)

            return result

    def __repr__(self):
        return self.__class__.__name__ + "({0})".format(self.tolist())

    __str__ = __repr__

    def evolver(self):
        return CheckedPVector.Evolver(self.__class__, self)


@six.add_metaclass(_CheckedTypeMeta)
class CheckedPSet(PSet, CheckedType):
    """
    A CheckedPSet is a PSet which allows specifying type and invariant checks.

    >>> class Positives(CheckedPSet):
    ...     __type__ = (long, int)
    ...     __invariant__ = lambda n: (n >= 0, 'Negative')
    ...
    >>> Positives([1, 2, 3])
    Positives([1, 2, 3])
    """

    __slots__ = ()

    def __new__(cls, initial=()):
        # A plain PMap is adopted directly as the backing map without running
        # any checks; anything else is added element by element through the
        # checking evolver.
        if type(initial) is PMap:
            return super(CheckedPSet, cls).__new__(cls, initial)

        evolver = CheckedPSet.Evolver(cls, pset())
        for e in initial:
            evolver.add(e)

        return evolver.persistent()

    def __repr__(self):
        # Drop the first four characters of the inherited repr and put this
        # class's name in their place.
        return self.__class__.__name__ + super(CheckedPSet, self).__repr__()[4:]

    def __str__(self):
        return self.__repr__()

    def serialize(self, format=None):
        serializer = self.__serializer__
        return set(serializer(format, v) for v in self)

    create = classmethod(_checked_type_create)

    def __reduce__(self):
        # Pickling support
        return _restore_pickle, (self.__class__, list(self),)

    def evolver(self):
        return CheckedPSet.Evolver(self.__class__, self)

    class Evolver(PSet._Evolver):
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, original_set):
            super(CheckedPSet.Evolver, self).__init__(original_set)
            self._destination_class = destination_class
            self._invariant_errors = []

        def _check(self, it):
            # Type errors raise immediately; invariant failures are collected
            # and reported when persistent() is called.
            _check_types(it, self._destination_class._checked_types, self._destination_class)
            error_data = _invariant_errors_iterable(it, self._destination_class._checked_invariants)
            self._invariant_errors.extend(error_data)

        def add(self, element):
            self._check([element])
            self._pmap_evolver[element] = True
            return self

        def persistent(self):
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            if self.is_dirty() or self._destination_class != type(self._original_pset):
                return self._destination_class(self._pmap_evolver.persistent())

            return self._original_pset


class _CheckedMapTypeMeta(type):
    """
    Metaclass that records ``__key_type__``, ``__value_type__`` and
    ``__invariant__`` declarations as ``_checked_key_types``,
    ``_checked_value_types`` and ``_checked_invariants`` and installs a default
    ``__serializer__`` when the class body does not define one.
    """

    def __new__(mcs, name, bases, dct):
        _store_types(dct, bases, '_checked_key_types', '__key_type__')
        _store_types(dct, bases, '_checked_value_types', '__value_type__')
        store_invariants(dct, bases, '_checked_invariants', '__invariant__')

        def default_serializer(self, _, key, value):
            sk = key
            if isinstance(key, CheckedType):
                sk = key.serialize()

            sv = value
            if isinstance(value, CheckedType):
                sv = value.serialize()

            return sk, sv

        dct.setdefault('__serializer__', default_serializer)

        dct['__slots__'] = ()

        return super(_CheckedMapTypeMeta, mcs).__new__(mcs, name, bases, dct)

# Marker object
_UNDEFINED_CHECKED_PMAP_SIZE = object()


@six.add_metaclass(_CheckedMapTypeMeta)
class CheckedPMap(PMap, CheckedType):
    """
    A CheckedPMap is a PMap which allows specifying type and invariant checks.

    >>> class IntToFloatMap(CheckedPMap):
    ...     __key_type__ = int
    ...     __value_type__ = float
    ...     __invariant__ = lambda k, v: (int(v) == k, 'Invalid mapping')
    ...
    >>> IntToFloatMap({1: 1.5, 2: 2.25})
    IntToFloatMap({1: 1.5, 2: 2.25})
    """

    __slots__ = ()

    # NOTE: the ``{}`` default is only ever read (``.items()``) or handed to the
    # base constructor together with an explicit size; it is never mutated here.
    def __new__(cls, initial={}, size=_UNDEFINED_CHECKED_PMAP_SIZE):
        # With an explicit size, ``initial`` is passed straight to the base
        # constructor and no checks are run.
        if size is not _UNDEFINED_CHECKED_PMAP_SIZE:
            return super(CheckedPMap, cls).__new__(cls, size, initial)

        evolver = CheckedPMap.Evolver(cls, pmap())
        for k, v in initial.items():
            evolver.set(k, v)

        return evolver.persistent()

    def evolver(self):
        return CheckedPMap.Evolver(self.__class__, self)

    def __repr__(self):
        return self.__class__.__name__ + "({0})".format(str(dict(self)))

    __str__ = __repr__

    def serialize(self, format=None):
        serializer = self.__serializer__
        return dict(serializer(format, k, v) for k, v in self.items())

    @classmethod
    def create(cls, source_data, _factory_fields=None, ignore_extra=False):
        """
        Build an instance of ``cls`` from the mapping ``source_data``.

        Returns ``source_data`` unchanged when it already is an instance of
        ``cls``. Keys and values that match none of the declared types are
        first passed to the ``create`` method of the first declared
        :py:class:`CheckedType` for keys and values respectively, if any.

        ``ignore_extra`` is accepted so that this method can be called by the
        shared ``create`` of checked vectors and sets, which always passes it.
        """
        if isinstance(source_data, cls):
            return source_data

        # Recursively apply create methods of checked types if the types of the supplied data
        # does not match any of the valid types.
        key_types = get_types(cls._checked_key_types)
        checked_key_type = next((t for t in key_types if issubclass(t, CheckedType)), None)
        value_types = get_types(cls._checked_value_types)
        checked_value_type = next((t for t in value_types if issubclass(t, CheckedType)), None)

        if checked_key_type or checked_value_type:
            # Forward ``ignore_extra`` only when it was requested, so nested
            # checked types whose create() does not take it keep working.
            create_kwargs = {'ignore_extra': ignore_extra} if ignore_extra else {}

            def convert(item, checked_type, valid_types):
                if checked_type and not any(isinstance(item, t) for t in valid_types):
                    return checked_type.create(item, **create_kwargs)
                return item

            return cls(dict((convert(key, checked_key_type, key_types),
                             convert(value, checked_value_type, value_types))
                            for key, value in source_data.items()))

        return cls(source_data)

    def __reduce__(self):
        # Pickling support
        return _restore_pickle, (self.__class__, dict(self),)

    class Evolver(PMap._Evolver):
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, original_map):
            super(CheckedPMap.Evolver, self).__init__(original_map)
            self._destination_class = destination_class
            self._invariant_errors = []

        def set(self, key, value):
            # Type errors raise immediately; invariant failures are collected
            # and reported when persistent() is called.
            _check_types([key], self._destination_class._checked_key_types, self._destination_class, CheckedKeyTypeError)
            _check_types([value], self._destination_class._checked_value_types, self._destination_class)
            self._invariant_errors.extend(data for valid, data in (invariant(key, value)
                                                                   for invariant in self._destination_class._checked_invariants)
                                          if not valid)

            return super(CheckedPMap.Evolver, self).set(key, value)

        def persistent(self):
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            if self.is_dirty() or type(self._original_pmap) != self._destination_class:
                return self._destination_class(self._buckets_evolver.persistent(), self._size)

            return self._original_pmap