1from ._compat import Iterable
2import six
3
4from pyrsistent._compat import Enum, string_types
5from pyrsistent._pmap import PMap, pmap
6from pyrsistent._pset import PSet, pset
7from pyrsistent._pvector import PythonPVector, python_pvector
8
9
class CheckedType(object):
    """
    Marker class to enable creation and serialization of checked object graphs.

    Concrete checked types override :py:meth:`create` and :py:meth:`serialize`.
    """
    __slots__ = ()

    @classmethod
    def create(cls, source_data, _factory_fields=None, ignore_extra=False):
        """
        Build an instance of ``cls`` from ``source_data``.

        ``ignore_extra`` is declared here because the checked collections may
        forward it to the ``create`` of their nested checked element types;
        declaring it keeps the base contract in line with those calls.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()

    def serialize(self, format=None):
        """
        Return a serialized (plain container) form of this instance.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError()
22
23
def _restore_pickle(cls, data):
    """
    Unpickling hook referenced by the checked collections' ``__reduce__``.

    Rebuilds the object by running ``data`` back through ``cls.create`` with
    an empty ``_factory_fields`` set.
    """
    factory_fields = set()
    return cls.create(data, _factory_fields=factory_fields)
26
27
class InvariantException(Exception):
    """
    Exception raised from a :py:class:`CheckedType` when invariant tests fail or when a mandatory
    field is missing.

    Attributes of interest:
    invariant_errors -- tuple of error data for the failing invariants; every callable
                        entry of ``error_codes`` is called and its result stored instead
    missing_fields -- the names of the missing fields, stored exactly as passed in
    """

    def __init__(self, error_codes=(), missing_fields=(), *args, **kwargs):
        resolved = []
        for code in error_codes:
            # Callables are resolved lazily into their error data.
            if callable(code):
                code = code()
            resolved.append(code)
        self.invariant_errors = tuple(resolved)
        self.missing_fields = missing_fields
        super(InvariantException, self).__init__(*args, **kwargs)

    def __str__(self):
        prefix = super(InvariantException, self).__str__()
        errors_text = ', '.join(str(err) for err in self.invariant_errors)
        missing_text = ', '.join(self.missing_fields)
        return prefix + ", invariant_errors=[{invariant_errors}], missing_fields=[{missing_fields}]".format(
            invariant_errors=errors_text,
            missing_fields=missing_text)
48
49
# Some types are themselves iterable (iterating an Enum class yields its
# members), but in a type specification we want the type itself, not its
# members. maybe_parse_user_type keeps the types listed here whole.
#
# Strings are deliberately not listed: the string inputs passed in are values
# (type names to be resolved later), not types.
_preserved_iterable_types = (
    Enum,
)
60
61
def maybe_parse_user_type(t):
    """Normalize one user-supplied type directive into a sequence of types.

    Every place that accepts a user type specification should go through this
    function so all of them accept the same inputs:

    * a type listed in ``_preserved_iterable_types`` -> ``[t]``
    * a string (a type name, resolved later) -> ``[t]``
    * any other non-iterable type -> ``[t]``
    * any other iterable -> a flat tuple of its members, each parsed recursively
    * anything else -> ``TypeError``
    """
    if isinstance(t, type) and issubclass(t, _preserved_iterable_types):
        return [t]
    if isinstance(t, string_types):
        return [t]

    iterable = isinstance(t, Iterable)
    if isinstance(t, type) and not iterable:
        return [t]
    if not iterable:
        # If this raises because `t` cannot be formatted, so be it.
        raise TypeError(
            'Type specifications must be types or strings. Input: {}'.format(t)
        )

    # Recurse so nested specifications are validated and flattened as well.
    return tuple(parsed for member in t for parsed in maybe_parse_user_type(member))
90
91
def maybe_parse_many_user_types(ts):
    """Parse several user-supplied type directives at once.

    This is an alias of ``maybe_parse_user_type``, whose iterable branch
    already handles a collection of directives. The separate name only makes
    the intent clearer at call sites.
    """
    parsed = maybe_parse_user_type(ts)
    return parsed
96
97
def _store_types(dct, bases, destination_name, source_name):
    """
    Collect type declarations for a class being created and store them parsed.

    The declaration named ``source_name`` (e.g. ``__type__``) is gathered from
    the new class body ``dct`` and from every ancestor reachable from
    ``bases``. The declarations are parsed with ``maybe_parse_many_user_types``
    and the result is stored in ``dct[destination_name]``.

    Every ancestor is visited, not just the direct bases. An intermediate class
    that does not redeclare ``source_name`` only carries the parsed
    ``destination_name`` attribute in its ``__dict__``. Looking at direct bases
    alone would therefore give a grandchild class an empty type tuple and
    silently disable its inherited type checks.
    """
    namespaces = [dct] + list(_all_dicts(bases))
    declared = [ns[source_name] for ns in namespaces if source_name in ns]
    dct[destination_name] = maybe_parse_many_user_types(declared)
105
106
def _merge_invariant_results(result):
    """
    Fold several ``(verdict, data)`` invariant outcomes into one pair.

    The merged verdict is true only when no outcome failed. The merged data is
    a tuple of the data of the failed outcomes, in order.
    """
    failures = tuple(dat for verd, dat in result if not verd)
    return not failures, failures
116
117
def wrap_invariant(invariant):
    """
    Adapt ``invariant`` so that calling it always yields one ``(verdict, data)`` pair.

    An invariant may return a single pair (first item a ``bool``), which is
    passed through unchanged. Otherwise it returns several pairs, which are
    merged by ``_merge_invariant_results``.
    """
    def f(*args, **kwargs):
        outcome = invariant(*args, **kwargs)
        if not isinstance(outcome[0], bool):
            return _merge_invariant_results(outcome)
        return outcome

    return f
130
131
def _all_dicts(bases, seen=None):
    """
    Yield the ``__dict__`` of each class in ``bases`` and of all their ancestors.

    Traversal is depth-first and pre-order, in declaration order. Each class
    is visited at most once, tracked in ``seen``, which is created if not
    given and updated in place otherwise.
    """
    if seen is None:
        seen = set()
    # Explicit stack; reversed so the leftmost base is popped first.
    pending = list(bases)
    pending.reverse()
    while pending:
        klass = pending.pop()
        if klass in seen:
            continue
        seen.add(klass)
        yield klass.__dict__
        parents = list(klass.__bases__)
        parents.reverse()
        pending.extend(parents)
145
146
def store_invariants(dct, bases, destination_name, source_name):
    """
    Gather the invariant named ``source_name`` from the new class body and all
    its ancestors, wrap each one, and store the tuple in ``dct[destination_name]``.

    :raises TypeError: if any gathered invariant is not callable.
    """
    # Invariants are inherited, so every ancestor namespace is consulted.
    namespaces = [dct]
    namespaces.extend(_all_dicts(bases))
    invariants = [ns[source_name] for ns in namespaces if source_name in ns]

    if any(not callable(inv) for inv in invariants):
        raise TypeError('Invariants must be callable')
    dct[destination_name] = tuple(map(wrap_invariant, invariants))
160
161
class _CheckedTypeMeta(type):
    """
    Metaclass for checked sequence/set types.

    At class creation it parses ``__type__`` into ``_checked_types``, collects
    the inherited ``__invariant__`` callables into ``_checked_invariants``,
    installs a default ``__serializer__`` if the class has none, and forces
    ``__slots__`` to ``()``.
    """
    def __new__(mcs, name, bases, dct):
        _store_types(dct, bases, '_checked_types', '__type__')
        store_invariants(dct, bases, '_checked_invariants', '__invariant__')

        def default_serializer(self, _, value):
            # Nested checked values serialize themselves; anything else is kept as-is.
            return value.serialize() if isinstance(value, CheckedType) else value

        if '__serializer__' not in dct:
            dct['__serializer__'] = default_serializer
        dct['__slots__'] = ()

        return super(_CheckedTypeMeta, mcs).__new__(mcs, name, bases, dct)
177
178
class CheckedTypeError(TypeError):
    """
    Base class for type errors raised by checked collections.

    Extra positional/keyword arguments (typically the message) go to TypeError.

    Attributes:
    source_class -- The class of the collection
    expected_types -- Allowed types
    actual_type -- The non matching type
    actual_value -- Value of the variable with the non matching type
    """
    def __init__(self, source_class, expected_types, actual_type, actual_value, *args, **kwargs):
        super(CheckedTypeError, self).__init__(*args, **kwargs)
        self.source_class, self.expected_types = source_class, expected_types
        self.actual_type, self.actual_value = actual_type, actual_value
186
187
class CheckedKeyTypeError(CheckedTypeError):
    """
    Raised when a key whose type is not among the declared key types is used
    to set a value.

    Attributes:
    source_class -- The class of the collection
    expected_types -- Allowed key types
    actual_type -- The non matching type
    actual_value -- The offending key
    """
199
200
class CheckedValueTypeError(CheckedTypeError):
    """
    Raised when trying to store a value (or element) whose type doesn't match the declared type.

    Attributes:
    source_class -- The class of the collection
    expected_types  -- Allowed types
    actual_type -- The non matching type
    actual_value -- Value of the variable with the non matching type
    """
    pass
212
213
def _get_class(type_name):
    """
    Resolve a dotted path such as ``'package.module.ClassName'`` to the object it names.

    :raises ValueError: if ``type_name`` contains no dot.
    """
    parts = type_name.rsplit('.', 1)
    module_name, class_name = parts
    # A non-empty fromlist makes __import__ return the leaf module itself.
    return getattr(__import__(module_name, fromlist=[class_name]), class_name)
218
219
def get_type(typ):
    """Return ``typ`` if it is already a type, otherwise resolve it as a dotted-path name."""
    return typ if isinstance(typ, type) else _get_class(typ)
225
226
def get_types(typs):
    """Resolve every entry of ``typs`` with ``get_type`` and return them as a list."""
    resolved = []
    for typ in typs:
        resolved.append(get_type(typ))
    return resolved
229
230
def _check_types(it, expected_types, source_class, exception_type=CheckedValueTypeError):
    """
    Verify that every item of ``it`` is an instance of one of ``expected_types``.

    An empty ``expected_types`` disables checking. Entries may be types or
    dotted-path names and are resolved with ``get_type``.

    :raises exception_type: for the first item that matches no expected type.
    """
    if not expected_types:
        return

    for item in it:
        if any(isinstance(item, get_type(t)) for t in expected_types):
            continue
        offending_type = type(item)
        msg = "Type {source_class} can only be used with {expected_types}, not {actual_type}".format(
            source_class=source_class.__name__,
            expected_types=tuple(get_type(et).__name__ for et in expected_types),
            actual_type=offending_type.__name__)
        raise exception_type(source_class, expected_types, offending_type, item, msg)
241
242
def _invariant_errors(elem, invariants):
    """Return, in order, the error data of every invariant that rejects ``elem``."""
    errors = []
    for invariant in invariants:
        valid, data = invariant(elem)
        if not valid:
            errors.append(data)
    return errors
245
246
def _invariant_errors_iterable(it, invariants):
    """
    Return the invariant error data for all elements of ``it``, in element order.

    The per-element lists are accumulated with ``extend``. The previous
    ``sum(lists, [])`` rebuilt the accumulated list on every step, which is
    quadratic in the number of elements.
    """
    errors = []
    for elem in it:
        errors.extend(_invariant_errors(elem, invariants))
    return errors
249
250
def optional(*typs):
    """Return a type specification accepting any of ``typs`` or ``None``."""
    none_type = type(None)
    return typs + (none_type,)
254
255
def _checked_type_create(cls, source_data, _factory_fields=None, ignore_extra=False):
    """
    Build an instance of the checked collection ``cls`` from ``source_data``.

    ``source_data`` is returned unchanged if it already is a ``cls``. If one of
    the declared element types is a CheckedType, each element that matches
    none of the declared types is converted with that type's ``create``.

    :param _factory_fields: accepted for interface compatibility; unused here.
    :param ignore_extra: forwarded to the nested ``create`` calls when true.
    """
    if isinstance(source_data, cls):
        return source_data

    types = get_types(cls._checked_types)
    checked_type = next((t for t in types if issubclass(t, CheckedType)), None)
    if not checked_type:
        return cls(source_data)

    # Forward ignore_extra only when it was requested. Nested CheckedType
    # implementations are not guaranteed to accept the keyword, and passing it
    # unconditionally made collections of such types fail with TypeError even
    # with the default ignore_extra=False.
    create_kwargs = {'ignore_extra': ignore_extra} if ignore_extra else {}
    return cls([data if any(isinstance(data, t) for t in types)
                else checked_type.create(data, **create_kwargs)
                for data in source_data])
270
@six.add_metaclass(_CheckedTypeMeta)
class CheckedPVector(PythonPVector, CheckedType):
    """
    A CheckedPVector is a PVector which allows specifying type and invariant checks.

    Subclasses may declare ``__type__`` (allowed element types, given as types
    or dotted-path strings) and ``__invariant__``. ``__invariant__`` is a
    callable taking one element and returning a ``(bool, error_data)`` pair or
    a sequence of such pairs. An element of the wrong type raises
    CheckedValueTypeError when it is written. Failed invariants are collected
    and raised as InvariantException when the new vector is produced.

    >>> class Positives(CheckedPVector):
    ...     __type__ = (int, float)
    ...     __invariant__ = lambda n: (n >= 0, 'Negative')
    ...
    >>> Positives([1, 2, 3])
    Positives([1, 2, 3])
    """

    __slots__ = ()

    def __new__(cls, initial=()):
        # Fast path: adopt the internals of an exact PythonPVector directly.
        # Evolver.persistent() hands the vector it builds back through here.
        # NOTE(review): this path runs no type or invariant checks, so a plain
        # PythonPVector passed in by a caller is accepted unchecked -- confirm
        # that is intended.
        if type(initial) == PythonPVector:
            return super(CheckedPVector, cls).__new__(cls, initial._count, initial._shift, initial._root, initial._tail)

        # General path: route every element through a checking evolver.
        return CheckedPVector.Evolver(cls, python_pvector()).extend(initial).persistent()

    # The write operations below go through a checking evolver and return a new vector.
    def set(self, key, value):
        return self.evolver().set(key, value).persistent()

    def append(self, val):
        return self.evolver().append(val).persistent()

    def extend(self, it):
        return self.evolver().extend(it).persistent()

    create = classmethod(_checked_type_create)

    def serialize(self, format=None):
        # Bound __serializer__ (the metaclass installs a default unless the
        # class defines one) is applied to every element.
        serializer = self.__serializer__
        return list(serializer(format, v) for v in self)

    def __reduce__(self):
        # Pickling support: restored by passing the element list to cls.create.
        return _restore_pickle, (self.__class__, list(self),)

    class Evolver(PythonPVector.Evolver):
        # Evolver that type-checks every element written through it and records
        # invariant failures until persistent() is called.
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, vector):
            super(CheckedPVector.Evolver, self).__init__(vector)
            # Class whose _checked_types/_checked_invariants are enforced and
            # which persistent() instantiates.
            self._destination_class = destination_class
            # Invariant error data accumulated across all writes.
            self._invariant_errors = []

        def _check(self, it):
            # Type errors raise immediately; invariant failures are only recorded.
            _check_types(it, self._destination_class._checked_types, self._destination_class)
            error_data = _invariant_errors_iterable(it, self._destination_class._checked_invariants)
            self._invariant_errors.extend(error_data)

        def __setitem__(self, key, value):
            self._check([value])
            return super(CheckedPVector.Evolver, self).__setitem__(key, value)

        def append(self, elem):
            self._check([elem])
            return super(CheckedPVector.Evolver, self).append(elem)

        def extend(self, it):
            # Materialize first: the elements are iterated by _check (twice)
            # and again by the base extend.
            it = list(it)
            self._check(it)
            return super(CheckedPVector.Evolver, self).extend(it)

        def persistent(self):
            # Deferred invariant failures surface here.
            # NOTE(review): the recorded errors are not cleared, so later calls
            # on this evolver keep raising.
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            result = self._orig_pvector
            # Rebuild when modified, or when the source is not exactly the destination class.
            if self.is_dirty() or (self._destination_class != type(self._orig_pvector)):
                # Relies on PythonPVector.Evolver internals (_extra_tail, _reset)
                # defined outside this file.
                pv = super(CheckedPVector.Evolver, self).persistent().extend(self._extra_tail)
                result = self._destination_class(pv)
                self._reset(result)

            return result

    def __repr__(self):
        return self.__class__.__name__ + "({0})".format(self.tolist())

    __str__ = __repr__

    def evolver(self):
        return CheckedPVector.Evolver(self.__class__, self)
356
357
@six.add_metaclass(_CheckedTypeMeta)
class CheckedPSet(PSet, CheckedType):
    """
    A CheckedPSet is a PSet which allows specifying type and invariant checks.

    Subclasses may declare ``__type__`` (allowed element types, given as types
    or dotted-path strings) and ``__invariant__`` (a callable taking one
    element and returning a ``(bool, error_data)`` pair or a sequence of such
    pairs). An element of the wrong type raises CheckedValueTypeError when it
    is added. Failed invariants are collected and raised as
    InvariantException when the new set is produced.

    >>> class Positives(CheckedPSet):
    ...     __type__ = (int, float)
    ...     __invariant__ = lambda n: (n >= 0, 'Negative')
    ...
    >>> Positives([1, 2, 3])
    Positives([1, 2, 3])
    """

    __slots__ = ()

    def __new__(cls, initial=()):
        # An exact PMap is taken as ready-made backing storage and wrapped as
        # is. This is presumably what Evolver.persistent() passes in.
        # NOTE(review): a caller passing a plain PMap here skips all type and
        # invariant checks -- confirm that is intended.
        if type(initial) is PMap:
            return super(CheckedPSet, cls).__new__(cls, initial)

        # Otherwise add the elements one at a time through a checking evolver.
        evolver = CheckedPSet.Evolver(cls, pset())
        for e in initial:
            evolver.add(e)

        return evolver.persistent()

    def __repr__(self):
        # Reuse PSet's repr and replace its first 4 characters (presumably the
        # 'pset' prefix) with the class name.
        return self.__class__.__name__ + super(CheckedPSet, self).__repr__()[4:]

    def __str__(self):
        return self.__repr__()

    def serialize(self, format=None):
        # NOTE(review): serialized elements are collected into a set, so a
        # serializer result that is unhashable (e.g. the list produced by a
        # nested CheckedPVector.serialize) would raise TypeError.
        serializer = self.__serializer__
        return set(serializer(format, v) for v in self)

    create = classmethod(_checked_type_create)

    def __reduce__(self):
        # Pickling support: restored by passing the element list to cls.create.
        return _restore_pickle, (self.__class__, list(self),)

    def evolver(self):
        return CheckedPSet.Evolver(self.__class__, self)

    class Evolver(PSet._Evolver):
        # Evolver that type-checks every added element and records invariant
        # failures until persistent() is called.
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, original_set):
            super(CheckedPSet.Evolver, self).__init__(original_set)
            # Class whose checks are enforced and which persistent() instantiates.
            self._destination_class = destination_class
            # Invariant error data accumulated across all additions.
            self._invariant_errors = []

        def _check(self, it):
            # Type errors raise immediately; invariant failures are only recorded.
            _check_types(it, self._destination_class._checked_types, self._destination_class)
            error_data = _invariant_errors_iterable(it, self._destination_class._checked_invariants)
            self._invariant_errors.extend(error_data)

        def add(self, element):
            self._check([element])
            # The set is backed by a map whose keys are the elements.
            self._pmap_evolver[element] = True
            return self

        def persistent(self):
            # Deferred invariant failures surface here.
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            # Rebuild when modified, or when the source is not exactly the destination class.
            if self.is_dirty() or self._destination_class != type(self._original_pset):
                return self._destination_class(self._pmap_evolver.persistent())

            return self._original_pset
428
429
class _CheckedMapTypeMeta(type):
    """
    Metaclass for checked map types.

    At class creation it parses ``__key_type__``/``__value_type__`` into
    ``_checked_key_types``/``_checked_value_types``, collects the inherited
    ``__invariant__`` callables into ``_checked_invariants``, installs a
    default ``__serializer__`` if the class has none, and forces
    ``__slots__`` to ``()``.
    """
    def __new__(mcs, name, bases, dct):
        _store_types(dct, bases, '_checked_key_types', '__key_type__')
        _store_types(dct, bases, '_checked_value_types', '__value_type__')
        store_invariants(dct, bases, '_checked_invariants', '__invariant__')

        def default_serializer(self, _, key, value):
            # Nested checked keys/values serialize themselves; anything else is kept as-is.
            def _one(item):
                return item.serialize() if isinstance(item, CheckedType) else item

            return _one(key), _one(value)

        if '__serializer__' not in dct:
            dct['__serializer__'] = default_serializer
        dct['__slots__'] = ()

        return super(_CheckedMapTypeMeta, mcs).__new__(mcs, name, bases, dct)
452
# Sentinel default for CheckedPMap.__new__'s ``size`` argument. An identity
# check against it distinguishes "no size given" from any real value a caller passes.
_UNDEFINED_CHECKED_PMAP_SIZE = object()
455
456
@six.add_metaclass(_CheckedMapTypeMeta)
class CheckedPMap(PMap, CheckedType):
    """
    A CheckedPMap is a PMap which allows specifying type and invariant checks.

    Subclasses may declare ``__key_type__`` and ``__value_type__`` (types or
    dotted-path strings) and ``__invariant__``. ``__invariant__`` is a
    callable taking ``(key, value)`` and returning a ``(bool, error_data)``
    pair or a sequence of such pairs. A key of the wrong type raises
    CheckedKeyTypeError and a value of the wrong type raises
    CheckedValueTypeError when it is set. Failed invariants are collected and
    raised as InvariantException when the new map is produced.

    >>> class IntToFloatMap(CheckedPMap):
    ...     __key_type__ = int
    ...     __value_type__ = float
    ...     __invariant__ = lambda k, v: (int(v) == k, 'Invalid mapping')
    ...
    >>> IntToFloatMap({1: 1.5, 2: 2.25})
    IntToFloatMap({1: 1.5, 2: 2.25})
    """

    __slots__ = ()

    def __new__(cls, initial={}, size=_UNDEFINED_CHECKED_PMAP_SIZE):
        # With an explicit size, ``initial`` is bucket storage (as handed over
        # by Evolver.persistent()) and the map is built directly, without checks.
        if size is not _UNDEFINED_CHECKED_PMAP_SIZE:
            return super(CheckedPMap, cls).__new__(cls, size, initial)

        # Otherwise validate every key/value pair through a checking evolver.
        # ``initial`` is only read here, so the shared ``{}`` default is never mutated.
        evolver = CheckedPMap.Evolver(cls, pmap())
        for k, v in initial.items():
            evolver.set(k, v)

        return evolver.persistent()

    def evolver(self):
        return CheckedPMap.Evolver(self.__class__, self)

    def __repr__(self):
        return self.__class__.__name__ + "({0})".format(str(dict(self)))

    __str__ = __repr__

    def serialize(self, format=None):
        # The bound __serializer__ maps each (key, value) pair to a serialized pair.
        serializer = self.__serializer__
        return dict(serializer(format, k, v) for k, v in self.items())

    @classmethod
    def create(cls, source_data, _factory_fields=None, ignore_extra=False):
        """
        Build an instance of ``cls`` from the mapping ``source_data``.

        ``source_data`` is returned unchanged if it already is a ``cls``. If a
        declared key or value type is a CheckedType, each key/value that
        matches none of its declared types is converted with that type's
        ``create``.

        :param _factory_fields: accepted for interface compatibility; unused here.
        :param ignore_extra: accepted so that checked collections can pass it
            when creating nested maps; forwarded to nested ``create`` calls when true.
        """
        if isinstance(source_data, cls):
            return source_data

        key_types = get_types(cls._checked_key_types)
        checked_key_type = next((t for t in key_types if issubclass(t, CheckedType)), None)
        value_types = get_types(cls._checked_value_types)
        checked_value_type = next((t for t in value_types if issubclass(t, CheckedType)), None)

        if not (checked_key_type or checked_value_type):
            return cls(source_data)

        # Forward ignore_extra only when requested. Nested CheckedType
        # implementations are not guaranteed to accept the keyword.
        create_kwargs = {'ignore_extra': ignore_extra} if ignore_extra else {}

        def convert(item, checked_type, types):
            # Recursively apply create only when the item matches none of the valid types.
            if checked_type and not any(isinstance(item, t) for t in types):
                return checked_type.create(item, **create_kwargs)
            return item

        return cls(dict((convert(key, checked_key_type, key_types),
                         convert(value, checked_value_type, value_types))
                        for key, value in source_data.items()))

    def __reduce__(self):
        # Pickling support: restored by passing a plain dict to cls.create.
        return _restore_pickle, (self.__class__, dict(self),)

    class Evolver(PMap._Evolver):
        # Evolver that type-checks every key/value set through it and records
        # invariant failures until persistent() is called.
        __slots__ = ('_destination_class', '_invariant_errors')

        def __init__(self, destination_class, original_map):
            super(CheckedPMap.Evolver, self).__init__(original_map)
            # Class whose checks are enforced and which persistent() instantiates.
            self._destination_class = destination_class
            # Invariant error data accumulated across all writes.
            self._invariant_errors = []

        def set(self, key, value):
            destination = self._destination_class
            # Type errors raise immediately; invariant failures are only recorded.
            _check_types([key], destination._checked_key_types, destination, CheckedKeyTypeError)
            _check_types([value], destination._checked_value_types, destination)
            for invariant in destination._checked_invariants:
                valid, data = invariant(key, value)
                if not valid:
                    self._invariant_errors.append(data)

            return super(CheckedPMap.Evolver, self).set(key, value)

        def persistent(self):
            # Deferred invariant failures surface here.
            if self._invariant_errors:
                raise InvariantException(error_codes=self._invariant_errors)

            # Rebuild when modified, or when the source is not exactly the destination class.
            if self.is_dirty() or type(self._original_pmap) != self._destination_class:
                return self._destination_class(self._buckets_evolver.persistent(), self._size)

            return self._original_pmap
543