1"""This module provides an API to validate and to some extent
2   manipulate data structures, such as JSON and XML parsing results.
3
4
5   Example usage:
6
7   >>> validate(int, 5)
8   5
9
10   >>> validate({text: int}, {'foo': '1'})
11   ValueError: Type of '1' should be 'int' but is 'str'
12
13   >>> validate({'foo': transform(int)}, {'foo': '1'})
14   {'foo': 1}
15
16"""
17
18
19from xml.etree import ElementTree as ET
20from copy import copy as copy_obj
21
22try:
23    from functools import singledispatch
24except ImportError:
25    from singledispatch import singledispatch
26
27from ...compat import is_py2, urlparse
28from ...exceptions import PluginError
29
# Names exported by a star-import of this module. "attr" and "map" are
# public validators defined below and are listed alongside the others.
__all__ = [
    "any", "all", "attr", "filter", "get", "getattr", "hasattr", "length",
    "map", "optional", "transform", "text", "union", "url", "startswith",
    "endswith", "xml_element", "xml_find", "xml_findall", "xml_findtext",
    "validate", "Schema", "SchemaContainer"
]
36
#: Alias for the text type on each Python version: ``basestring`` when
#: ``is_py2`` is set, ``str`` otherwise. Only the selected branch is
#: evaluated, so ``basestring`` is never looked up on Python 3.
text = basestring if is_py2 else str
39
# Keep handles on the builtins before the same-named validators defined
# further down in this module rebind those names.
_all, _getattr, _hasattr, _filter, _map = all, getattr, hasattr, filter, map
46
47
# Attributes an object must expose to be treated as a regex match object.
_re_match_attr = ("group", "groups", "groupdict", "re")


def _is_re_match(value):
    """Return True if value has every attribute in _re_match_attr."""
    for name in _re_match_attr:
        if not _hasattr(value, name):
            return False
    return True
51
52
class any(tuple):
    """Tuple of schemas of which at least one must be valid.

    The schemas are passed as positional arguments and become the items
    of the tuple.
    """

    def __new__(cls, *schemas):
        return tuple.__new__(cls, schemas)
57
58
class all(tuple):
    """Tuple of schemas that must all be valid.

    The schemas are passed as positional arguments and become the items
    of the tuple, in the order given.
    """

    def __new__(cls, *schemas):
        return tuple.__new__(cls, schemas)
63
64
class SchemaContainer(object):
    """Wrapper that stores one schema on its ``schema`` attribute."""

    def __init__(self, schema):
        # Kept as given; no validation or copying happens here.
        self.schema = schema
68
69
class transform(object):
    """Schema that replaces the value with the result of ``func(value)``."""

    def __init__(self, func):
        # On Python 2 ``text`` is basestring, which cannot be instantiated
        # and so cannot convert a value; unicode is used in its place.
        self.func = unicode if (is_py2 and func == text) else func
80
81
class optional(object):
    """Marks a key of a dict or union-dict schema as optional.

    The wrapped key is available as the ``key`` attribute.
    """

    def __init__(self, key):
        self.key = key
86
87
class union(SchemaContainer):
    """Applies several schemas to one and the same value.

    The wrapped schema is a container of schemas (a dict, list, tuple,
    set or frozenset); the result has the same container type.
    """
90
91
class attr(SchemaContainer):
    """Validates attributes of an object.

    The wrapped schema is a mapping of attribute name to the schema that
    attribute's value is validated with.
    """
94
95
class xml_element(object):
    """Schema for an XML element.

    ``tag``, ``text`` and ``attrib`` are the schemas for the corresponding
    parts of the element; each defaults to ``None``.
    """

    def __init__(self, tag=None, text=None, attrib=None):
        self.tag, self.text, self.attrib = tag, text, attrib
102
103
def length(length):
    """Build a check that ``len(value)`` is at least ``length``.

    The returned callable returns True on success and raises ValueError
    when the value is too short.
    """
    def min_len(value):
        actual = len(value)
        if actual >= length:
            return True

        raise ValueError(
            "Minimum length is {0} but value is {1}".format(length, actual)
        )

    return min_len
114
115
def startswith(string):
    """Build a check that a text value starts with ``string``.

    The returned callable first validates that the value is text, then
    returns True or raises ValueError.
    """
    def starts_with(value):
        validate(text, value)
        if value.startswith(string):
            return True

        raise ValueError("'{0}' does not start with '{1}'".format(value, string))

    return starts_with
125
126
def endswith(string):
    """Build a check that a text value ends with ``string``.

    The returned callable first validates that the value is text, then
    returns True or raises ValueError.
    """
    def ends_with(value):
        validate(text, value)
        if value.endswith(string):
            return True

        raise ValueError("'{0}' does not end with '{1}'".format(value, string))

    return ends_with
136
137
def get(item, default=None):
    """Build a transform that looks up ``item`` in the value.

    XML elements are looked up in their attributes, regex match objects
    through ``.group(item)`` and everything else through ``value[item]``.
    A KeyError or IndexError yields ``default``; a TypeError or
    AttributeError is re-raised as ValueError.
    """
    def getter(value):
        # For XML elements the lookup goes against the attribute mapping.
        source = value.attrib if ET.iselement(value) else value

        try:
            if _is_re_match(source):
                return source.group(item)
            return source[item]
        except (KeyError, IndexError):
            return default
        except (TypeError, AttributeError) as err:
            raise ValueError(err)

    return transform(getter)
162
163
def getattr(attr, default=None):
    """Build a transform that reads attribute ``attr`` from the value.

    If the attribute does not exist, ``default`` (None unless specified)
    is returned instead.
    """
    def getter(value):
        # _getattr is the builtin; this function shadows its name.
        return _getattr(value, attr, default)

    return transform(getter)
174
175
def hasattr(attr):
    """Build a check that the value has an attribute named ``attr``.

    The returned callable returns the result of the builtin hasattr().
    """
    def has_attr(value):
        # _hasattr is the builtin; this function shadows its name.
        return _hasattr(value, attr)

    return has_attr
182
183
def filter(func):
    """Build a transform that keeps only the items ``func`` accepts.

    For a dict, ``func`` is called with key and value as two arguments;
    for anything else it is called with each item. The result is rebuilt
    with the type of the input value.
    """
    def expand_kv(kv):
        return func(*kv)

    def filter_values(value):
        if isinstance(value, dict):
            predicate, items = expand_kv, value.items()
        else:
            predicate, items = func, value

        return type(value)(_filter(predicate, items))

    return transform(filter_values)
201
202
def map(func):
    """Build a transform that applies ``func`` to every item.

    For a dict, ``func`` is called with key and value as two arguments;
    for anything else it is called with each item. The result is rebuilt
    with the type of the input value.
    """
    # On Python 2 ``text`` is basestring, which cannot be instantiated and
    # so cannot convert a value; unicode is used in its place.
    mapper = unicode if (is_py2 and text == func) else func

    def expand_kv(kv):
        return mapper(*kv)

    def map_values(value):
        if isinstance(value, dict):
            apply, items = expand_kv, value.items()
        else:
            apply, items = mapper, value

        return type(value)(_map(apply, items))

    return transform(map_values)
226
227
def url(**attributes):
    """Build a check that a text value is a URL with the given attributes.

    Each keyword names an attribute of the parsed URL and gives the schema
    its value is validated with. The returned callable returns True or
    raises ValueError.
    """
    # For convenience a plain "http" scheme accepts "https" as well.
    if attributes.get("scheme") == "http":
        attributes["scheme"] = any("http", "https")

    def check_url(value):
        validate(text, value)
        parsed = urlparse(value)
        if not parsed.netloc:
            raise ValueError("'{0}' is not a valid URL".format(value))

        for name, schema in attributes.items():
            if not _hasattr(parsed, name):
                raise ValueError("Invalid URL attribute '{0}'".format(name))

            try:
                validate(schema, _getattr(parsed, name))
            except ValueError as err:
                message = "Unable to validate URL attribute '{0}': {1}".format(
                    name, err
                )
                raise ValueError(message)

        return True

    return check_url
256
257
def xml_find(xpath):
    """Build a transform that returns the element ``find(xpath)`` yields.

    The value must be an XML element; ValueError is raised when the
    lookup returns None.
    """
    def xpath_find(value):
        validate(ET.iselement, value)
        found = value.find(xpath)
        if found is None:
            raise ValueError("XPath '{0}' did not return an element".format(xpath))

        return validate(ET.iselement, found)

    return transform(xpath_find)
269
270
def xml_findall(xpath):
    """Build a transform that returns the result of ``findall(xpath)``.

    The value must be an XML element.
    """
    def xpath_findall(value):
        validate(ET.iselement, value)
        matches = value.findall(xpath)
        return matches

    return transform(xpath_findall)
278
279
def xml_findtext(xpath):
    """Build a schema that finds an element via xpath and reads its text.

    Combines xml_find(xpath) with reading the ``text`` attribute of the
    element it returns.
    """
    find_element = xml_find(xpath)
    read_text = getattr("text")
    return all(find_element, read_text)
286
287
@singledispatch
def validate(schema, value):
    """Validate ``value`` against ``schema`` and return the result.

    This is the fallback used for schema types without a registered
    handler:

    - a callable schema is called with ``value`` and must return a true
      result;
    - any other schema must compare equal to ``value``.

    Returns ``value`` on success and raises ValueError otherwise.
    """
    if callable(schema):
        if schema(value):
            return value

        # Not every callable has a __name__ (functools.partial objects,
        # instances defining __call__). Fall back to the type name so the
        # failure is still reported as ValueError, not AttributeError.
        try:
            name = schema.__name__
        except AttributeError:
            name = type(schema).__name__

        raise ValueError("{0}({1!r}) is not true".format(name, value))

    if schema == value:
        return value

    raise ValueError("{0!r} does not equal {1!r}".format(value, schema))
300
301
@validate.register(any)
def validate_any(schema, value):
    """Return the result of the first subschema that validates ``value``.

    If none does, raise ValueError with every collected error message
    joined by " or ".
    """
    failures = []
    for subschema in schema:
        try:
            return validate(subschema, value)
        except ValueError as err:
            failures.append(err)

    # Only reached when no subschema returned above.
    raise ValueError(" or ".join([str(failure) for failure in failures]))
313
314
@validate.register(all)
def validate_all(schemas, value):
    """Run ``value`` through every schema in order.

    Each schema receives the result of the previous one; the final result
    is returned.
    """
    result = value
    for subschema in schemas:
        result = validate(subschema, result)

    return result
321
322
@validate.register(transform)
def validate_transform(schema, value):
    """Return ``schema.func(value)`` after checking the function is callable."""
    func = schema.func
    validate(callable, func)
    return func(value)
327
328
@validate.register(list)
@validate.register(tuple)
@validate.register(set)
@validate.register(frozenset)
def validate_sequence(schema, value):
    """Validate every item of ``value`` against any of the schema's items.

    ``value`` must be an instance of the schema's own type; the result is
    a new container of that type holding the validated items.
    """
    container = type(schema)
    validate(container, value)

    # The same any-of schema applies to every item, so build it once.
    item_schema = any(*schema)
    return container(validate(item_schema, item) for item in value)
336
337
@validate.register(dict)
def validate_dict(schema, value):
    """Validate a mapping against a dict schema and return a new mapping.

    ``value`` must be an instance of the schema's type. Literal keys must
    exist in ``value`` unless wrapped in ``optional``. When a schema key
    is itself a schema (a type, transform, any, all or union), every
    key/value pair of ``value`` is validated with that key schema and its
    subschema, and the remaining schema entries are not processed.
    """
    dict_type = type(schema)
    validate(dict_type, value)
    result = dict_type()

    for key, subschema in schema.items():
        if isinstance(key, optional):
            key = key.key
            if key not in value:
                continue

        if type(key) in (type, transform, any, all, union):
            for subkey, subvalue in value.items():
                result[validate(key, subkey)] = validate(subschema, subvalue)
            break

        if key not in value:
            raise ValueError("Key '{0}' not found in {1!r}".format(key, value))

        try:
            result[key] = validate(subschema, value[key])
        except ValueError as err:
            raise ValueError("Unable to validate key '{0}': {1}".format(key, err))

    return result
363
364
@validate.register(type)
def validate_type(schema, value):
    """Return ``value`` if it is an instance of ``schema``, else raise ValueError."""
    if not isinstance(value, schema):
        expected = schema.__name__
        actual = type(value).__name__
        raise ValueError(
            "Type of {0!r} should be '{1}' but is '{2}'".format(
                value, expected, actual
            )
        )

    return value
375
376
@validate.register(xml_element)
def validate_xml_element(schema, value):
    """Validate an XML element and return a new element with the results.

    ``value`` must be an XML element. The new element starts out with the
    tag, attributes, text and tail of ``value``; the attributes, tag and
    text are then replaced by their validated counterparts for every part
    the schema defines (is not None). Children of ``value`` are appended
    to the new element as they are.

    Raises ValueError when any part fails validation.
    """
    validate(ET.iselement, value)
    new = ET.Element(value.tag, attrib=value.attrib)

    # Carry over text and tail so they are not lost when the schema does
    # not validate the text.
    new.text = value.text
    new.tail = value.tail

    if schema.attrib is not None:
        try:
            new.attrib = validate(schema.attrib, value.attrib)
        except ValueError as err:
            raise ValueError("Unable to validate XML attributes: {0}".format(err))

    if schema.tag is not None:
        try:
            new.tag = validate(schema.tag, value.tag)
        except ValueError as err:
            raise ValueError("Unable to validate XML tag: {0}".format(err))

    if schema.text is not None:
        try:
            new.text = validate(schema.text, value.text)
        except ValueError as err:
            raise ValueError("Unable to validate XML text: {0}".format(err))

    for child in value:
        new.append(child)

    return new
404
405
@validate.register(attr)
def validate_attr(schema, value):
    """Validate attributes of ``value`` and return a shallow copy of it.

    Every attribute named in the wrapped mapping must exist on ``value``;
    its validated result is set on the copy. Raises ValueError when an
    attribute is missing.
    """
    new = copy_obj(value)

    for name, subschema in schema.schema.items():
        if not _hasattr(value, name):
            raise ValueError("Attribute '{0}' not found on object '{1}'".format(
                name, value
            ))

        validated = validate(subschema, _getattr(value, name))
        setattr(new, name, validated)

    return new
419
420
@singledispatch
def validate_union(schema, value):
    """Fallback for union schemas of a type with no registered handler.

    Always raises ValueError naming the type of ``schema``.
    """
    type_name = type(schema).__name__
    raise ValueError("Invalid union type: {0}".format(type_name))
424
425
@validate_union.register(dict)
def validate_union_dict(schema, value):
    """Validate the same ``value`` with every schema of a dict union.

    Returns a new mapping of the schema's type from key to validated
    result. A failing schema raises ValueError, unless its key is wrapped
    in ``optional``, in which case the key is left out of the result.
    """
    new = type(schema)()
    for key, subschema in schema.items():
        is_optional = isinstance(key, optional)
        name = key.key if is_optional else key

        try:
            new[name] = validate(subschema, value)
        except ValueError as err:
            if not is_optional:
                raise ValueError(
                    "Unable to validate union '{0}': {1}".format(name, err)
                )

    return new
443
444
@validate_union.register(list)
@validate_union.register(tuple)
@validate_union.register(set)
@validate_union.register(frozenset)
def validate_union_sequence(schemas, value):
    """Validate the same ``value`` with every schema of a sequence union.

    Returns a container of the same type as ``schemas`` holding one result
    per schema.
    """
    container = type(schemas)
    return container(validate(subschema, value) for subschema in schemas)
451
452
@validate.register(union)
def validate_unions(schema, value):
    """Dispatch a union on the type of the schema container it wraps."""
    wrapped = schema.schema
    return validate_union(wrapped, value)
456
457
class Schema(object):
    """Bundles one or more schemas into an object with a validate() method."""

    def __init__(self, *schemas):
        # All given schemas must be valid; they are applied in order.
        self.schema = all(*schemas)

    def validate(self, value, name="result", exception=PluginError):
        """Validate ``value`` and return the result.

        A validation failure is re-raised as ``exception`` with a message
        that mentions ``name``.
        """
        try:
            result = validate(self.schema, value)
        except ValueError as err:
            message = "Unable to validate {0}: {1}".format(name, err)
            raise exception(message)

        return result
469
470
@validate.register(Schema)
def validate_schema(schema, value):
    """Validate ``value`` with a nested Schema object.

    Failures are raised as ValueError so that an enclosing schema can
    handle them like any other validation error.
    """
    result = schema.validate(value, exception=ValueError)
    return result
474
475