import six
from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants
from pyrsistent._field_common import (
    set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
)
from pyrsistent._pmap import PMap, pmap


class _PRecordMeta(type):
    """
    Metaclass for PRecord.

    At class-creation time it collects the declared fields and invariants (``set_fields`` and
    ``store_invariants`` receive the bases as well as the class body) and precomputes the
    lookup tables used whenever an instance is built.
    """
    def __new__(mcs, name, bases, dct):
        set_fields(dct, bases, name='_precord_fields')
        store_invariants(dct, bases, '_precord_invariants', '__invariant__')

        fields = dct['_precord_fields']

        # Names of the fields flagged as mandatory; checked by _PRecordEvolver.persistent().
        dct['_precord_mandatory_fields'] = set(key for key, field in fields.items() if field.mandatory)

        # Initial values of the fields that declare one (PFIELD_NO_INITIAL means "none declared").
        dct['_precord_initial_values'] = {key: field.initial for key, field in fields.items()
                                          if field.initial is not PFIELD_NO_INITIAL}

        # Declare no instance attributes of its own; all data lives in the PMap structure.
        dct['__slots__'] = ()

        return super(_PRecordMeta, mcs).__new__(mcs, name, bases, dct)


@six.add_metaclass(_PRecordMeta)
class PRecord(PMap, CheckedType):
    """
    A PRecord is a PMap with a fixed set of specified fields. Records are declared as python classes inheriting
    from PRecord. Because it is a PMap it has full support for all Mapping methods such as iteration and element
    access using subscript notation.

    More documentation and examples of PRecord usage are available at https://github.com/tobgu/pyrsistent
    """
    def __new__(cls, **kwargs):
        # Internal fast path: _PRecordEvolver.persistent() calls the class with the already built
        # buckets and size, so the instance can be created directly. Every other construction goes
        # through an evolver, which runs factories, type checks and invariants for each field.
        if '_precord_size' in kwargs and '_precord_buckets' in kwargs:
            return super(PRecord, cls).__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets'])

        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('_ignore_extra', False)

        initial_values = kwargs
        if cls._precord_initial_values:
            # Callable initial values are invoked for every new instance; explicit kwargs win.
            initial_values = dict((k, v() if callable(v) else v)
                                  for k, v in cls._precord_initial_values.items())
            initial_values.update(kwargs)

        e = _PRecordEvolver(cls, pmap(), _factory_fields=factory_fields, _ignore_extra=ignore_extra)
        for k, v in initial_values.items():
            e[k] = v

        return e.persistent()

    def set(self, *args, **kwargs):
        """
        Set a field in the record. This set function differs slightly from that in the PMap
        class. First of all it accepts key-value pairs. Second it accepts multiple key-value
        pairs to perform one, atomic, update of multiple fields.

        Either pass exactly one key and value positionally, ``record.set('x', 1)``, or any
        number of fields as keyword arguments, ``record.set(x=1, y=2)``. When a positional
        key/value pair is given, keyword arguments are not applied.

        :raises TypeError: if positional arguments are given but are not exactly a key and a value.
        """
        # The PRecord set() can accept kwargs since all fields that have been declared are
        # valid python identifiers. Also allow multiple fields to be set in one operation.
        if args:
            if len(args) != 2:
                raise TypeError('set() takes exactly one key and one value as positional arguments, '
                                'got {0}'.format(len(args)))
            return super(PRecord, self).set(args[0], args[1])

        return self.update(kwargs)

    def evolver(self):
        """
        Returns an evolver of this object.
        """
        return _PRecordEvolver(self.__class__, self)

    def __repr__(self):
        return "{0}({1})".format(self.__class__.__name__,
                                 ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items()))

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Will create a new PRecord of the current type and assign the values
        specified in kwargs.

        :param kwargs: A mapping of field names to values. An instance of this class is returned as is.
        :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
                             in the set of fields on the PRecord.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            kwargs = {k: kwargs[k] for k in cls._precord_fields if k in kwargs}

        return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs)

    def __reduce__(self):
        # Pickling support: rebuild from the class and a plain dict of the field values.
        return _restore_pickle, (self.__class__, dict(self),)

    def serialize(self, format=None):
        """
        Serialize the current PRecord using custom serializer functions for fields where
        such have been supplied.

        :param format: Passed through unchanged to each field's serializer.
        :return: A plain dict of field name to serialized value.
        """
        return dict((k, serialize(self._precord_fields[k].serializer, format, v)) for k, v in self.items())


class _PRecordEvolver(PMap._Evolver):
    """
    Evolver that builds PRecord instances of ``_destination_cls``.

    Field-level invariant violations and missing fields are collected while values are set and
    reported together, as a single InvariantException, by ``persistent()``.
    """
    __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra')

    def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False):
        super(_PRecordEvolver, self).__init__(original_pmap)
        self._destination_cls = cls
        self._invariant_error_codes = []
        # Missing fields reported by nested factories (from their InvariantException).
        self._missing_fields = []
        self._factory_fields = _factory_fields
        self._ignore_extra = _ignore_extra

    def __setitem__(self, key, original_value):
        self.set(key, original_value)

    def set(self, key, original_value):
        """
        Validate and set a single field, returning the evolver.

        The value is passed through the field's factory (when ``_factory_fields`` is None or
        contains the field), type checked with ``check_type`` and tested against the field
        invariant. Invariant failures are recorded and raised later by ``persistent()``.

        :raises AttributeError: if ``key`` is not a declared field of the destination class.
        """
        field = self._destination_cls._precord_fields.get(key)
        if not field:
            raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__))

        if self._factory_fields is None or field in self._factory_fields:
            try:
                # Only forward ignore_extra to factories that the helper says accept it.
                if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra):
                    value = field.factory(original_value, ignore_extra=self._ignore_extra)
                else:
                    value = field.factory(original_value)
            except InvariantException as e:
                # A nested structure failed validation: defer its errors and leave the field unset.
                self._invariant_error_codes += e.invariant_errors
                self._missing_fields += e.missing_fields
                return self
        else:
            value = original_value

        check_type(self._destination_cls, field, key, value)

        is_ok, error_code = field.invariant(value)
        if not is_ok:
            self._invariant_error_codes.append(error_code)

        return super(_PRecordEvolver, self).set(key, value)

    def persistent(self):
        """
        Build the record from the current state of the evolver.

        :raises InvariantException: if any field invariant failed or a mandatory field is missing;
            global invariants are then checked through ``check_global_invariants``.
        """
        cls = self._destination_cls
        is_dirty = self.is_dirty()
        pm = super(_PRecordEvolver, self).persistent()
        if is_dirty or not isinstance(pm, cls):
            result = cls(_precord_buckets=pm._buckets, _precord_size=pm._size)
        else:
            result = pm

        # Collect into a local copy: appending to self._missing_fields would leave stale entries
        # behind, so a retry after fixing the record (or a repeated call) would still fail.
        missing_fields = list(self._missing_fields)
        if cls._precord_mandatory_fields:
            missing_fields.extend('{0}.{1}'.format(cls.__name__, f) for f
                                  in (cls._precord_mandatory_fields - set(result.keys())))

        if self._invariant_error_codes or missing_fields:
            raise InvariantException(tuple(self._invariant_error_codes), tuple(missing_fields),
                                     'Field invariant failed')

        check_global_invariants(result, cls._precord_invariants)

        return result