1# Copyright (c) 2015, The MITRE Corporation. All rights reserved.
2# See LICENSE.txt for complete terms.
3"""
4Entity field data descriptors (TypedFields) and associated classes.
5"""
6import functools
7import inspect
8
9from .datautils import is_sequence, resolve_class
10from .typedlist import TypedList
11from .dates import parse_date, parse_datetime, serialize_date, serialize_datetime
12from .xml import strip_cdata, cdata
13from .vendor import six
14from .compat import long
15
16
def unset(entity, *types):
    """Remove TypedField values from `entity`.

    Every key of ``entity._fields`` that is an instance of one of `types`
    is deleted, so the corresponding fields read as unset afterwards.

    Args:
        entity: A mixbox.Entity object (anything with a ``_fields`` dict).
        *types: TypedField subclasses to remove. When none are given, all
            TypedField instances are removed.
    """
    field_types = types or (TypedField,)

    # Collect the doomed keys first so the dict is not mutated while it is
    # being iterated.
    doomed = [key for key in list(entity._fields) if isinstance(key, field_types)]

    for key in doomed:
        del entity._fields[key]
33
34
def _matches(field, params):
    """Check whether `field` has every attribute value listed in `params`.

    Args:
        field: A TypedField instance.
        params: A dictionary mapping TypedField attribute names to the
            values they must equal.

    Returns:
        True if every named attribute of `field` compares equal to its
        expected value (vacuously True for an empty `params`), else False.
        A name that is not an attribute of `field` raises AttributeError.
    """
    for attr, expected in six.iteritems(params):
        # Stop at the first mismatch, just as all() would.
        if not getattr(field, attr) == expected:
            return False
    return True
48
49
def iterfields(klass):
    """Yield the TypedField members of `klass`.

    Members are discovered with inspect.getmembers(), so results come back
    sorted by attribute name.

    Args:
        klass: A class (usually an Entity subclass).

    Yields:
        (class attribute name, TypedField instance) tuples.
    """
    def _is_typed_field(member):
        return isinstance(member, TypedField)

    for pair in inspect.getmembers(klass, _is_typed_field):
        yield pair
63
64
def find(entity, **kwargs):
    """Return all TypedFields found on the input `entity` whose instance
    attributes match the input **kwargs.

    Example:
        >>> find(myentity, multiple=True, type_=Foo)

    Note:
        TypedField.__init__() can accept a string or a class as a type_
        argument, but this function compares against each field's resolved
        ``type_`` attribute, so it expects a class.

    Args:
        entity: The object to search. If it has a ``typed_fields`` attribute,
            that attribute is called to enumerate the TypedFields. Otherwise
            the TypedField class attributes of ``entity.__class__`` are used.
        **kwargs: TypedField attribute names mapped to the values to match.

    Returns:
        A list of TypedFields with matching **kwarg values.
    """
    typed_fields = getattr(entity, "typed_fields", None)

    if typed_fields is not None:
        # Call outside of any try/except so AttributeErrors raised inside
        # typed_fields() surface instead of silently triggering the fallback.
        candidates = typed_fields()
    else:
        # iterfields() yields (attribute name, field) pairs; keep only the
        # field objects so _matches() inspects TypedFields, not tuples.
        candidates = (field for _, field in iterfields(entity.__class__))

    return [field for field in candidates if _matches(field, kwargs)]
89
90
class TypedField(object):
    """Data descriptor for a single field of an Entity.

    A TypedField is declared as a class attribute. Field values are not kept
    on the descriptor: they are stored in the owning instance's ``_fields``
    dictionary, keyed by the TypedField object itself.
    """

    def __init__(self, name, type_=None,
                 key_name=None, comparable=True, multiple=False,
                 preset_hook=None, postset_hook=None, factory=None,
                 listfunc=None):
        """
        Create a new field.

        Args:
            `name` (str): name of the field as contained in the binding class.
            `type_` (type/str): Required type for values assigned to this
                field. If ``None``, no type checking is performed. String
                values are treated as fully qualified package paths to a class
                (e.g., "A.B.C" would be the full path to the type "C"). They
                are resolved on first access of the ``type_`` property.
            `key_name` (str): name for the field when represented as a
                dictionary. (Optional) If omitted, ``name.lower()`` is used.
            `comparable` (boolean): whether this field should be considered
                when checking Entities for equality. Default is True. If
                False, this field is not considered.
            `multiple` (boolean): whether multiple instances of this field can
                exist on the Entity. Multiple fields store their values in a
                sequence built by `listfunc`.
            `preset_hook` (callable): called with (instance, value) before the
                value is stored, but after it has been cleaned/type-checked.
                Typically used to perform additional validation, perhaps
                based on the current state of the instance.
            `postset_hook` (callable): called with the same arguments as
                `preset_hook`, but after the value is stored. Can be used,
                for example, to modify other fields of the instance to
                maintain an invariant.
            `factory` (type/str): class used by from_dict() and from_obj()
                calls for this field (see ``transformer``). Strings are
                resolved lazily, like `type_`.
            `listfunc` (callable): a datatype or function that creates the
                mutable sequence used for multiple-field storage, e.g. "list".
                Defaults to a TypedList bound to `type_` when `type_` is
                given, otherwise ``list``.
        """
        self.name = name
        self.comparable = comparable
        self.multiple = multiple
        self.preset_hook = preset_hook
        self.postset_hook = postset_hook

        # Both the value type and the factory may be dotted-path strings.
        # Resolution is deferred until the matching property is first read.
        self._unresolved_type = type_
        self._unresolved_factory = factory

        # Dictionary key: explicit key_name, else the lower-cased field name.
        self._key_name = key_name if key_name else name.lower()

        # Container constructor used for multiple-valued fields.
        if listfunc:
            list_maker = listfunc
        elif type_:
            list_maker = functools.partial(TypedList, type=type_)
        else:
            list_maker = list
        self._listfunc = list_maker

    def __get__(self, instance, owner=None):
        """Return this field's value on `instance`.

        Class-level access (``instance is None``) returns the descriptor
        itself. An unset "multiple" field is initialized to an empty sequence,
        which is stored on the instance and returned. Any other unset field
        reads as ``None``.

        Args:
            instance: An instance of the `owner` class that this TypedField
                belongs to.
            owner: The TypedField owner class.
        """
        if instance is None:
            return self

        store = instance._fields
        if self in store:
            return store[self]
        if not self.multiple:
            return None
        return store.setdefault(self, self._listfunc())

    def _clean(self, value):
        """Validate and clean a candidate value for this field.

        ``None`` passes through. Values that already satisfy check_type(), or
        any value when the field has no type, are returned unchanged. If the
        type opts in via a truthy ``_try_cast`` attribute, the value is
        converted by calling the type.

        Raises:
            TypeError: if the value has the wrong type and cannot be cast.
        """
        if value is None:
            return None
        if self.type_ is None or self.check_type(value):
            return value
        if self.is_type_castable:  # noqa
            return self.type_(value)

        message = "%s must be a %s, not a %s" % (self.name, self.type_, type(value))
        raise TypeError(message)

    def __set__(self, instance, value):
        """Clean `value` and store it on `instance` for this field.

        If the TypedField has a ``type_`` and `value` is not an instance of
        it, an attempt may be made to convert `value` (see _clean()).

        For ``multiple`` fields:
        - ``None`` becomes an empty sequence.
        - A non-sequence value is cleaned and wrapped in a one-item sequence.
        - A sequence has its ``None`` items dropped and each remaining item
          cleaned.

        The sequence type comes from `listfunc`. The preset hook (if any) runs
        before the value is stored and the postset hook (if any) runs after.
        Both receive (instance, value).
        """
        if not self.multiple:
            cleaned = self._clean(value)
        elif value is None:
            cleaned = self._listfunc()
        elif is_sequence(value):
            cleaned = self._listfunc(
                self._clean(item) for item in value if item is not None
            )
        else:
            cleaned = self._listfunc([self._clean(value)])

        if self.preset_hook:
            self.preset_hook(instance, cleaned)

        instance._fields[self] = cleaned

        if self.postset_hook:
            self.postset_hook(instance, cleaned)

    def __str__(self):
        return self.name

    def check_type(self, value):
        """Return True if `value` is acceptable for this field's type.

        Untyped fields accept everything. A type exposing ``istypeof`` decides
        for itself; otherwise isinstance() is used.
        """
        expected = self.type_
        if not expected:
            return True
        if hasattr(expected, "istypeof"):
            return expected.istypeof(value)
        return isinstance(value, expected)

    @property
    def key_name(self):
        """Key used for this field in dictionary representations."""
        return self._key_name

    @property
    def type_(self):
        """The field's value type, resolved via resolve_class() on first
        access and cached afterwards."""
        try:
            return self._resolved_type
        except AttributeError:
            resolved = resolve_class(self._unresolved_type)
            self._resolved_type = resolved
            return resolved

    @type_.setter
    def type_(self, value):
        self._resolved_type = value

    @property
    def factory(self):
        """The field's factory class, resolved via resolve_class() on first
        access and cached afterwards."""
        try:
            return self._resolved_factory
        except AttributeError:
            resolved = resolve_class(self._unresolved_factory)
            self._resolved_factory = resolved
            return resolved

    @factory.setter
    def factory(self, value):
        self._resolved_factory = value

    @property
    def transformer(self):
        """Return the class for this field that transforms non-Entity objects
        (e.g., dicts or binding objects) into Entity instances.

        Any non-None value returned from this property should implement
        from_obj() and from_dict() methods.

        Returns:
            The factory if one is set, else the type if one is set, else None.
        """
        return self.factory or self.type_ or None

    @property
    def is_type_castable(self):
        """The type's ``_try_cast`` attribute (False when absent)."""
        return getattr(self.type_, "_try_cast", False)

    def binding_value(self, value):
        """Return `value` as it should be passed to a binding object
        (identity here; subclasses may override)."""
        return value

    def dict_value(self, value):
        """Return `value` as it should appear in a dictionary representation
        (identity here; subclasses may override)."""
        return value

    def __copy__(self):
        """Return the field itself; see __deepcopy__ for the rationale."""
        return self

    def __deepcopy__(self, memo):
        """Return itself (don't actually make a copy at all).

        TypedFields store themselves as keys in an Entity._fields dictionary
        and use themselves as the key for value retrieval. A normal deepcopy()
        would descend into _fields and replace those keys with *copies* of the
        original TypedFields. The class-level descriptor would then never find
        itself in the deepcopied Entity.

        Controlling __deepcopy__ at the Entity level is possible but more
        complicated. TypedFields are class-level property descriptors, so they
        should not be copied anyway.
        """
        # Record self in the memo so deepcopy reuses it without calling again.
        memo[id(self)] = self
        return self
313
314
class BytesField(TypedField):
    """TypedField whose values are coerced to ``six.binary_type``."""

    def _clean(self, value):
        """Coerce `value` to a byte string, passing ``None`` through.

        Without the ``None`` check, assigning ``None`` raises TypeError on
        Python 3 (``bytes(None)``) and stores the string "None" on Python 2.
        The base TypedField._clean() instead keeps ``None`` to mean "unset".

        NOTE(review): on Python 3, ``bytes(n)`` for an int ``n`` builds an
        n-byte zero-filled buffer rather than the digits of ``n``; callers
        presumably pass bytes or byte-like values here. Confirm.
        """
        if value is None:
            return None
        return six.binary_type(value)
318
319
class TextField(TypedField):
    """TypedField whose values are coerced to ``six.text_type``."""

    def _clean(self, value):
        """Coerce `value` to text, passing ``None`` through unchanged.

        Without the ``None`` check, assigning ``None`` would store the literal
        string "None" instead of clearing the field. The base
        TypedField._clean() treats ``None`` as "unset".
        """
        if value is None:
            return None
        return six.text_type(value)
323
324
class BooleanField(TypedField):
    """TypedField whose values are coerced to ``bool``."""

    def _clean(self, value):
        """Return the truth value of `value`.

        Note that ``None`` is coerced to ``False`` rather than passed through.
        """
        return True if value else False
328
329
class IntegerField(TypedField):
    """TypedField whose values are coerced to ``int``."""

    def _clean(self, value):
        """Coerce `value` to an ``int``.

        ``None`` and the empty string are treated as "no value" and yield
        ``None``. Strings are first parsed with base 0, so prefixed literals
        such as "0x1F" are accepted. If that fails, the string is parsed again
        as base 10. This handles zero-padded decimals like "007", which base 0
        rejects on Python 3.

        Raises:
            ValueError: if a string is not a valid integer literal.
            TypeError: if `value` cannot be converted by ``int()``.
        """
        if value in (None, ""):
            return None
        if isinstance(value, six.string_types):
            try:
                return int(value, 0)
            except ValueError:
                return int(value, 10)
        return int(value)
338
339
class LongField(TypedField):
    """TypedField whose values are coerced with ``long`` (from the compat
    module; presumably ``int`` on Python 3. TODO confirm)."""

    def _clean(self, value):
        """Coerce `value` with ``long``.

        ``None`` and the empty string are treated as "no value" and yield
        ``None``. Strings are first parsed with base 0, so prefixed literals
        such as "0x1F" are accepted. If that fails, the string is parsed again
        as base 10. This handles zero-padded decimals like "007", which base 0
        rejects on Python 3.

        Raises:
            ValueError: if a string is not a valid integer literal.
        """
        if value in (None, ""):
            return None
        if isinstance(value, six.string_types):
            try:
                return long(value, 0)
            except ValueError:
                return long(value, 10)
        return long(value)
348
349
class FloatField(TypedField):
    """TypedField whose values are coerced to ``float``."""

    def _clean(self, value):
        """Return ``float(value)``, or ``None`` for ``None`` / empty string."""
        return None if value in (None, "") else float(value)
355
356
class DateTimeField(TypedField):
    """TypedField for datetime values.

    Input is parsed with parse_datetime(); output (both dictionary and
    binding forms) is produced with serialize_datetime().
    """

    def _clean(self, value):
        """Parse `value` with parse_datetime()."""
        parsed = parse_datetime(value)
        return parsed

    def binding_value(self, value):
        """Serialize `value` for a binding object."""
        return serialize_datetime(value)

    def dict_value(self, value):
        """Serialize `value` for a dictionary representation."""
        return serialize_datetime(value)
366
367
class DateField(TypedField):
    """TypedField for date values.

    Input is parsed with parse_date(); output (both dictionary and binding
    forms) is produced with serialize_date().
    """

    def _clean(self, value):
        """Parse `value` with parse_date()."""
        parsed = parse_date(value)
        return parsed

    def binding_value(self, value):
        """Serialize `value` for a binding object."""
        return serialize_date(value)

    def dict_value(self, value):
        """Serialize `value` for a dictionary representation."""
        return serialize_date(value)
377
378
class CDATAField(TypedField):
    """TypedField whose stored value is passed through strip_cdata() and
    whose binding value is wrapped with cdata()."""

    def _clean(self, value):
        """Return `value` after strip_cdata() (presumably removes a CDATA
        wrapper; see the xml module)."""
        stripped = strip_cdata(value)
        return stripped

    def binding_value(self, value):
        """Return `value` wrapped by cdata() for binding objects."""
        wrapped = cdata(value)
        return wrapped
385
386
class IdField(TypedField):
    """TypedField holding an id; setting it displaces any idref fields."""

    def __set__(self, instance, value):
        """Store `value` as this id field on `instance`.

        If `value` is truthy (not ``None``, not an empty string, etc.), every
        IdrefField value on `instance` is removed as well.
        """
        super(IdField, self).__set__(instance, value)

        if not value:
            # Clearing the id leaves any idref fields untouched.
            return
        unset(instance, IdrefField)
396
397
class IdrefField(TypedField):
    """TypedField holding an idref; setting it displaces any id fields."""

    def __set__(self, instance, value):
        """Store `value` as this idref field on `instance`.

        If `value` is truthy (not ``None``, not an empty string, etc.), every
        IdField value on `instance` is removed as well.
        """
        super(IdrefField, self).__set__(instance, value)

        if not value:
            # Clearing the idref leaves any id fields untouched.
            return
        unset(instance, IdField)
407