1import six
2from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants
3from pyrsistent._field_common import (
4    set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
5)
6from pyrsistent._pmap import PMap, pmap
7
8
class _PRecordMeta(type):
    """
    Metaclass for PRecord.

    At class-creation time it populates the class namespace with:

    * ``_precord_fields`` - via ``set_fields`` (presumably the fields declared on the class and
      its bases; see that helper).
    * ``_precord_invariants`` - via ``store_invariants``, from ``__invariant__`` definitions.
    * ``_precord_mandatory_fields`` - set of names whose field is ``mandatory``.
    * ``_precord_initial_values`` - name -> declared ``initial`` for fields that have one.

    It also forces an empty ``__slots__`` on every record class.
    """
    def __new__(mcs, name, bases, dct):
        set_fields(dct, bases, name='_precord_fields')
        store_invariants(dct, bases, '_precord_invariants', '__invariant__')

        # Derive both lookup tables in one pass over the declared fields.
        mandatory_names = set()
        defaults = {}
        for field_name, declared in dct['_precord_fields'].items():
            if declared.mandatory:
                mandatory_names.add(field_name)
            if declared.initial is not PFIELD_NO_INITIAL:
                defaults[field_name] = declared.initial

        dct['_precord_mandatory_fields'] = mandatory_names
        dct['_precord_initial_values'] = defaults

        # Empty __slots__ so record subclasses do not add a per-instance __dict__.
        dct['__slots__'] = ()

        return super(_PRecordMeta, mcs).__new__(mcs, name, bases, dct)
24
25
@six.add_metaclass(_PRecordMeta)
class PRecord(PMap, CheckedType):
    """
    A PRecord is a PMap with a fixed set of specified fields. Records are declared as python classes inheriting
    from PRecord. Because it is a PMap it has full support for all Mapping methods such as iteration and element
    access using subscript notation.

    More documentation and examples of PRecord usage are available at https://github.com/tobgu/pyrsistent
    """
    def __new__(cls, **kwargs):
        """
        Create a new record from keyword arguments.

        Declared initial values are applied first (callable initials are called to produce the
        value) and are then overridden by ``kwargs``. Every value is assigned through a
        ``_PRecordEvolver``, so field factories, type checks and invariants apply.

        Private keyword arguments recognised here:

        * ``_factory_fields`` - forwarded to the evolver; when not None it restricts which fields
          have their factory applied.
        * ``_ignore_extra`` - forwarded to the evolver, which passes it on to field factories
          that accept ``ignore_extra``.
        """
        # Internal fast path: _PRecordEvolver.persistent() calls the class with these two keys
        # to wrap already-validated PMap internals directly, bypassing the evolver.
        if '_precord_size' in kwargs and '_precord_buckets' in kwargs:
            return super(PRecord, cls).__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets'])

        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('_ignore_extra', False)

        initial_values = kwargs
        if cls._precord_initial_values:
            # Callable initials are invoked per instance; explicit kwargs take precedence.
            initial_values = {k: v() if callable(v) else v
                              for k, v in cls._precord_initial_values.items()}
            initial_values.update(kwargs)

        e = _PRecordEvolver(cls, pmap(), _factory_fields=factory_fields, _ignore_extra=ignore_extra)
        for k, v in initial_values.items():
            e[k] = v

        return e.persistent()

    def set(self, *args, **kwargs):
        """
        Set one or more fields in the record and return the resulting record.

        This differs slightly from ``PMap.set``: besides a positional key and value it accepts
        keyword arguments, so several fields can be updated in one atomic operation::

            record.set('x', 1)
            record.set(x=1, y=2)

        :raises TypeError: if the positional form is not exactly one key and one value, or if
            positional and keyword arguments are mixed (previously the extra arguments were
            silently ignored).
        """
        # The PRecord set() can accept kwargs since all fields that have been declared are
        # valid python identifiers. Also allow multiple fields to be set in one operation.
        if args:
            if len(args) != 2:
                raise TypeError(
                    "set() takes a key and a value as positional arguments, {0} given".format(len(args)))
            if kwargs:
                raise TypeError(
                    "set() does not accept positional and keyword arguments at the same time")
            return super(PRecord, self).set(args[0], args[1])

        return self.update(kwargs)

    def evolver(self):
        """
        Returns an evolver of this object.
        """
        return _PRecordEvolver(self.__class__, self)

    def __repr__(self):
        # Renders as ClassName(field=repr(value), ...) in the record's iteration order.
        return "{0}({1})".format(self.__class__.__name__,
                                 ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items()))

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Will create a new PRecord of the current type and assign the values
        specified in kwargs.

        :param kwargs: Mapping of field names to values. An instance of this class is returned
                       unchanged.
        :param _factory_fields: Internal. Forwarded to the evolver; when not None, only fields
                                it contains have their factory applied.
        :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
                             in the set of fields on the PRecord.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            kwargs = {k: kwargs[k] for k in cls._precord_fields if k in kwargs}

        return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs)

    def __reduce__(self):
        # Pickling support: reconstructed via _restore_pickle(cls, plain dict of the fields).
        return _restore_pickle, (self.__class__, dict(self),)

    def serialize(self, format=None):
        """
        Serialize the current PRecord using custom serializer functions for fields where
        such have been supplied.

        :param format: Passed through to each field's serializer.
        :returns: A plain dict mapping field names to serialized values.
        """
        return {k: serialize(self._precord_fields[k].serializer, format, v) for k, v in self.items()}
108
109
class _PRecordEvolver(PMap._Evolver):
    """
    Evolver that produces instances of a specific PRecord subclass.

    Each value set goes through the destination field's factory, type check and field
    invariant. Invariant failures (including those reported by factories through
    InvariantException) and missing mandatory fields are collected and raised together,
    as one InvariantException, from persistent().
    """
    __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra')

    def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False):
        super(_PRecordEvolver, self).__init__(original_pmap)
        # PRecord subclass that persistent() builds.
        self._destination_cls = cls
        # Errors gathered by set(); reported by persistent().
        self._invariant_error_codes = []
        self._missing_fields = []
        # None means "apply every field's factory"; otherwise only fields contained in it.
        self._factory_fields = _factory_fields
        self._ignore_extra = _ignore_extra

    def __setitem__(self, key, original_value):
        self.set(key, original_value)

    def set(self, key, original_value):
        """
        Set ``key`` to ``original_value`` after applying the field's factory, type check and
        invariant. Returns the evolver.

        If the factory raises InvariantException, its errors are recorded for persistent()
        and ``key`` is left unchanged.

        :raises AttributeError: if ``key`` is not a declared field of the destination class.
        """
        field = self._destination_cls._precord_fields.get(key)
        if field is None:
            raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__))

        # NOTE(review): membership is tested with the field object, not the key name;
        # confirm against the callers that pass _factory_fields.
        if self._factory_fields is None or field in self._factory_fields:
            try:
                if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra):
                    value = field.factory(original_value, ignore_extra=self._ignore_extra)
                else:
                    value = field.factory(original_value)
            except InvariantException as e:
                # Defer the factory's errors to persistent() and skip setting this key.
                self._invariant_error_codes += e.invariant_errors
                self._missing_fields += e.missing_fields
                return self
        else:
            value = original_value

        check_type(self._destination_cls, field, key, value)

        is_ok, error_code = field.invariant(value)
        if not is_ok:
            self._invariant_error_codes.append(error_code)

        return super(_PRecordEvolver, self).set(key, value)

    def persistent(self):
        """
        Build a record of the destination class from the evolver's current contents.

        :raises InvariantException: if any field invariant failed, a factory reported errors,
            or mandatory fields are missing. Global invariants are checked only after these
            field-level checks pass.
        """
        cls = self._destination_cls
        is_dirty = self.is_dirty()
        pm = super(_PRecordEvolver, self).persistent()
        if is_dirty or not isinstance(pm, cls):
            result = cls(_precord_buckets=pm._buckets, _precord_size=pm._size)
        else:
            result = pm

        # Collect missing mandatory fields in a local list instead of appending to
        # self._missing_fields. Otherwise a field reported missing on one call would keep
        # failing later calls even after it had been set, and repeated calls would
        # duplicate entries.
        missing_fields = list(self._missing_fields)
        if cls._precord_mandatory_fields:
            missing_fields.extend('{0}.{1}'.format(cls.__name__, f) for f
                                  in (cls._precord_mandatory_fields - set(result.keys())))

        if self._invariant_error_codes or missing_fields:
            raise InvariantException(tuple(self._invariant_error_codes), tuple(missing_fields),
                                     'Field invariant failed')

        check_global_invariants(result, cls._precord_invariants)

        return result
170