1import six
2from pyrsistent._checked_types import (InvariantException, CheckedType, _restore_pickle, store_invariants)
3from pyrsistent._field_common import (
4    set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
5)
6from pyrsistent._transformations import transform
7
8
def _is_pclass(bases):
    """
    Return True when *bases* consists of exactly one class, CheckedType.

    PClassMeta uses this to recognise the root of a PClass hierarchy, which is the
    only class that should receive a ``__weakref__`` slot.
    """
    if len(bases) != 1:
        return False
    return bases[0] == CheckedType
11
12
class PClassMeta(type):
    """
    Metaclass for PClass: gathers field declarations and invariants into class
    attributes and derives ``__slots__`` from the collected field names.
    """
    def __new__(mcs, name, bases, dct):
        set_fields(dct, bases, name='_pclass_fields')
        store_invariants(dct, bases, '_pclass_invariants', '__invariant__')

        # One slot for the "frozen" flag, then one slot per field name.
        # NOTE(review): set_fields receives the bases, so _pclass_fields presumably includes
        # inherited fields; subclasses would then re-declare slots a base class already
        # defines, which the Python docs call undefined behaviour -- confirm and consider filtering.
        slots = ['_pclass_frozen']
        slots.extend(dct['_pclass_fields'])

        # Python rejects a second __weakref__ slot anywhere in a hierarchy, so only the
        # class whose sole base is CheckedType (in this module, PClass itself) declares it.
        if _is_pclass(bases):
            slots.append('__weakref__')

        dct['__slots__'] = tuple(slots)
        return super(PClassMeta, mcs).__new__(mcs, name, bases, dct)
25
# Unique sentinel meaning "attribute not set"; compared by identity, so it can never be
# confused with a real field value (including None).
_MISSING_VALUE = object()


def _check_and_set_attr(cls, field, name, value, result, invariant_errors):
    """
    Type-check *value* for *field* and, if the field invariant accepts it, assign it to
    ``result.<name>``.

    Anything raised by ``check_type`` propagates. When the invariant rejects the value,
    its error code is appended to *invariant_errors* and the attribute is left unset,
    so the caller can report all failures together.
    """
    check_type(cls, field, name, value)
    ok, code = field.invariant(value)
    if ok:
        setattr(result, name, value)
    else:
        invariant_errors.append(code)
36
37
@six.add_metaclass(PClassMeta)
class PClass(CheckedType):
    """
    A PClass is a python class with a fixed set of specified fields. PClasses are declared as python classes inheriting
    from PClass. It is defined the same way that PRecords are and behaves like a PRecord in all aspects except that it
    is not a PMap and hence not a collection but rather a plain Python object.

    Instances are immutable: attribute assignment raises AttributeError once construction has finished, and
    attribute deletion always does. Use set(), remove(), transform() or evolver() to derive modified copies.

    More documentation and examples of PClass usage are available at https://github.com/tobgu/pyrsistent
    """
    def __new__(cls, **kwargs):    # NOTE(review): keyword arguments only; positional *args are not supported.
        """
        Create, validate and freeze a new instance from keyword arguments.

        Two reserved keyword arguments are removed from kwargs before field processing, so fields with
        these names cannot be populated through the constructor:

        * ``_factory_fields`` -- when None (the default) every supplied value is passed through its field's
          factory; otherwise only the named fields are, and other supplied values are used as-is.
        * ``ignore_extra`` -- forwarded to a field's factory when ``is_field_ignore_extra_complaint``
          returns true for that field.

        Declared fields absent from kwargs fall back to their ``initial`` value (called first if it is
        callable). A mandatory field with neither a supplied nor an initial value is reported as missing.

        :raises InvariantException: if any field invariant fails or mandatory fields are missing; all such
            failures are collected before raising.
        :raises AttributeError: if kwargs contains names that are not declared fields.

        Errors raised by ``check_type`` and ``check_global_invariants`` propagate unchanged. Global
        invariants run only after every field has been set, and only then is the instance frozen.
        """
        result = super(PClass, cls).__new__(cls)
        # Reserved, non-field keyword arguments (see docstring).
        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('ignore_extra', None)
        missing_fields = []
        invariant_errors = []
        for name, field in cls._pclass_fields.items():
            if name in kwargs:
                if factory_fields is None or name in factory_fields:
                    if is_field_ignore_extra_complaint(PClass, field, ignore_extra):
                        value = field.factory(kwargs[name], ignore_extra=ignore_extra)
                    else:
                        value = field.factory(kwargs[name])
                else:
                    # Not listed in _factory_fields: use as-is (set() and the evolver rely on this
                    # for values carried over unchanged from an existing instance).
                    value = kwargs[name]
                _check_and_set_attr(cls, field, name, value, result, invariant_errors)
                # Consume the key so whatever remains in kwargs afterwards is an unknown field.
                del kwargs[name]
            elif field.initial is not PFIELD_NO_INITIAL:
                # Initial values bypass the field factory but are still type- and invariant-checked.
                initial = field.initial() if callable(field.initial) else field.initial
                _check_and_set_attr(
                    cls, field, name, initial, result, invariant_errors)
            elif field.mandatory:
                missing_fields.append('{0}.{1}'.format(cls.__name__, name))

        if invariant_errors or missing_fields:
            raise InvariantException(tuple(invariant_errors), tuple(missing_fields), 'Field invariant failed')

        if kwargs:
            raise AttributeError("'{0}' are not among the specified fields for {1}".format(
                ', '.join(kwargs), cls.__name__))

        check_global_invariants(result, cls._pclass_invariants)

        # From here on __setattr__ rejects any further assignment.
        result._pclass_frozen = True
        return result

    def set(self, *args, **kwargs):
        """
        Set a field in the instance. Returns a new instance with the updated value. The original instance remains
        unmodified. Accepts key-value pairs or single string representing the field name and a value.

        >>> from pyrsistent import PClass, field
        >>> class AClass(PClass):
        ...     x = field()
        ...
        >>> a = AClass(x=1)
        >>> a2 = a.set(x=2)
        >>> a3 = a.set('x', 3)
        >>> a
        AClass(x=1)
        >>> a2
        AClass(x=2)
        >>> a3
        AClass(x=3)
        """
        # Positional form set('x', value). NOTE(review): a single positional argument raises IndexError,
        # and positional arguments beyond the second are silently ignored.
        if args:
            kwargs[args[0]] = args[1]

        # Only the fields being set are re-run through their factories; carried-over values are used as-is.
        factory_fields = set(kwargs)

        # Carry over every currently-set field that is not being overwritten.
        for key in self._pclass_fields:
            if key not in kwargs:
                value = getattr(self, key, _MISSING_VALUE)
                if value is not _MISSING_VALUE:
                    kwargs[key] = value

        return self.__class__(_factory_fields=factory_fields, **kwargs)

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Will create a new PClass of the current type and assign the values
        specified in kwargs.

        If kwargs is already an instance of this class it is returned unchanged.

        :param kwargs: A mapping of field names to values (or an existing instance of this class).
        :param _factory_fields: Forwarded to the constructor; when not None, only the named fields are passed
                                through their field factories.
        :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
                             in the set of fields on the PClass. The flag is also forwarded to the constructor, which
                             passes it on to field factories selected by ``is_field_ignore_extra_complaint``.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            # Keep only declared fields so the constructor does not reject the extras.
            kwargs = {k: kwargs[k] for k in cls._pclass_fields if k in kwargs}

        return cls(_factory_fields=_factory_fields, ignore_extra=ignore_extra, **kwargs)

    def serialize(self, format=None):
        """
        Serialize the current PClass using custom serializer functions for fields where
        such have been supplied.

        :param format: Passed, together with each field's serializer, to the ``serialize`` helper.
        :return: A dict mapping each set field name to its serialized value; unset fields are omitted.
        """
        result = {}
        for name in self._pclass_fields:
            value = getattr(self, name, _MISSING_VALUE)
            if value is not _MISSING_VALUE:
                result[name] = serialize(self._pclass_fields[name].serializer, format, value)

        return result

    def transform(self, *transformations):
        """
        Apply transformations to the current PClass. For more details on transformations see
        the documentation for PMap. Transformations on PClasses do not support key matching
        since the PClass is not a collection. Apart from that the transformations available
        for other persistent types work as expected.
        """
        return transform(self, transformations)

    def __eq__(self, other):
        # Equal when other is an instance of this class (or a subclass) and every field declared on
        # self's class compares equal; unset fields compare as the shared _MISSING_VALUE sentinel.
        # NOTE(review): only self's fields are compared, so a subclass instance carrying extra fields can
        # compare equal to a base-class instance while hashing differently.
        if isinstance(other, self.__class__):
            for name in self._pclass_fields:
                if getattr(self, name, _MISSING_VALUE) != getattr(other, name, _MISSING_VALUE):
                    return False

            return True

        # Let Python try the reflected comparison for unrelated types.
        return NotImplemented

    def __ne__(self, other):
        # Explicit inverse of __eq__ (presumably kept for Python 2, which does not derive __ne__; six is used here).
        return not self == other

    def __hash__(self):
        # Hash over (field name, value) pairs for every declared field, mirroring what __eq__ compares.
        # May want to optimize this by caching the hash somehow
        return hash(tuple((key, getattr(self, key, _MISSING_VALUE)) for key in self._pclass_fields))

    def __setattr__(self, key, value):
        # Assignment is allowed only while __new__ populates the instance; the _pclass_frozen slot is
        # unset until the end of __new__, hence the getattr default.
        if getattr(self, '_pclass_frozen', False):
            raise AttributeError("Can't set attribute, key={0}, value={1}".format(key, value))

        super(PClass, self).__setattr__(key, value)

    def __delattr__(self, key):
            """Always raise AttributeError: instances are immutable; use remove() to get a copy without a field."""
            raise AttributeError("Can't delete attribute, key={0}, use remove()".format(key))

    def _to_dict(self):
        """Return a new plain dict of field name -> value for the fields that are set; unset fields are omitted."""
        result = {}
        for key in self._pclass_fields:
            value = getattr(self, key, _MISSING_VALUE)
            if value is not _MISSING_VALUE:
                result[key] = value

        return result

    def __repr__(self):
        # e.g. AClass(x=1); only set fields are shown.
        return "{0}({1})".format(self.__class__.__name__,
                                 ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self._to_dict().items()))

    def __reduce__(self):
        # Pickling support: rebuild via _restore_pickle(cls, data), where data holds only the set fields.
        data = dict((key, getattr(self, key)) for key in self._pclass_fields if hasattr(self, key))
        return _restore_pickle, (self.__class__, data,)

    def evolver(self):
        """
        Returns an evolver for this object.

        The evolver starts from a fresh dict of the currently set fields; changes made through it do not
        affect this instance.
        """
        return _PClassEvolver(self, self._to_dict())

    def remove(self, name):
        """
        Remove attribute given by name from the current instance. Raises AttributeError if the
        attribute doesn't exist.

        The result is built through the constructor, so a removed field that declares an ``initial``
        value is reset to it, and removing a mandatory field without one raises InvariantException.
        """
        evolver = self.evolver()
        del evolver[name]
        return evolver.persistent()
213
214
class _PClassEvolver(object):
    """
    Mutable builder that accumulates changes to a PClass instance.

    Fields can be read and written as items (``evolver['x']``) or as attributes
    (``evolver.x``). Changes are recorded in a plain dict; ``persistent()`` builds a
    new instance of the original's class from it, or returns the original unchanged
    when nothing was modified.
    """
    __slots__ = ('_pclass_evolver_original', '_pclass_evolver_data', '_pclass_evolver_data_is_dirty', '_factory_fields')

    def __init__(self, original, initial_dict):
        # original: the instance being evolved. initial_dict: field name -> value; it is used
        # directly, not copied (PClass.evolver() passes a fresh dict from _to_dict()).
        self._pclass_evolver_original = original
        self._pclass_evolver_data = initial_dict
        self._pclass_evolver_data_is_dirty = False
        # Names assigned through set(); only these are re-run through their field factories.
        self._factory_fields = set()

    def __getitem__(self, item):
        return self._pclass_evolver_data[item]

    def set(self, key, value):
        """Set field *key* to *value* and return the evolver, allowing chained calls."""
        # Identity rather than equality: re-assigning the very same object is a no-op and
        # does not mark the evolver dirty.
        if self._pclass_evolver_data.get(key, _MISSING_VALUE) is not value:
            self._pclass_evolver_data[key] = value
            self._factory_fields.add(key)
            self._pclass_evolver_data_is_dirty = True

        return self

    def __setitem__(self, key, value):
        self.set(key, value)

    def remove(self, item):
        """Remove field *item* and return the evolver; raise AttributeError if it is not set."""
        if item in self._pclass_evolver_data:
            del self._pclass_evolver_data[item]
            self._factory_fields.discard(item)
            self._pclass_evolver_data_is_dirty = True
            return self

        raise AttributeError(item)

    def __delitem__(self, item):
        self.remove(item)

    def persistent(self):
        """
        Return a new instance of the original's class built from the accumulated data, or
        the original itself if nothing changed. Only fields assigned through this evolver
        are passed through their field factories.
        """
        if self._pclass_evolver_data_is_dirty:
            return self._pclass_evolver_original.__class__(_factory_fields=self._factory_fields,
                                                           **self._pclass_evolver_data)

        return self._pclass_evolver_original

    def __setattr__(self, key, value):
        # Slot names are real attributes; any other name is treated as a field assignment.
        if key not in self.__slots__:
            self.set(key, value)
        else:
            super(_PClassEvolver, self).__setattr__(key, value)

    def __getattr__(self, item):
        """
        Attribute-style read access to the evolver's fields.

        Unknown names raise AttributeError, as the attribute protocol requires.
        Previously a KeyError escaped, which broke hasattr(),
        getattr(obj, name, default) and copy.deepcopy().
        """
        # __getattr__ only runs after normal lookup fails. For one of our own slots that means
        # the slot is unset (e.g. an instance created via __new__ while copying/unpickling);
        # routing that through the data dict would recurse without bound.
        if item in _PClassEvolver.__slots__:
            raise AttributeError(item)
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item)
265