1#
2# Copyright (c), 2016-2021, SISSA (International School for Advanced Studies).
3# All rights reserved.
4# This file is distributed under the terms of the MIT License.
5# See the file 'LICENSE' in the root directory of the present
6# distribution, or http://opensource.org/licenses/MIT.
7#
8# @author Davide Brunato <brunato@sissa.it>
9#
10import sys
11if sys.version_info < (3, 7):
12    from typing import GenericMeta as ABCMeta
13else:
14    from abc import ABCMeta
15
16from itertools import count
17from typing import TYPE_CHECKING, cast, overload, Any, Dict, List, Iterator, \
18    Optional, Union, Tuple, Type, MutableMapping, MutableSequence
19from elementpath import XPathContext, XPath2Parser
20
21from .exceptions import XMLSchemaAttributeError, XMLSchemaTypeError, XMLSchemaValueError
22from .etree import ElementData, etree_tostring
23from .aliases import ElementType, XMLSourceType, NamespacesType, BaseXsdType, DecodeType
24from .helpers import get_namespace, get_prefixed_qname, local_name, raw_xml_encode
25from .converters import XMLSchemaConverter
26from .resources import XMLResource
27from . import validators
28
29if TYPE_CHECKING:
30    from .validators import XMLSchemaValidationError, XsdElement
31
32
class DataElement(MutableSequence['DataElement']):
    """
    Data Element, an Element like object with decoded data and schema bindings.

    Child data elements are accessed through the mutable sequence protocol,
    while the tag, typed value, attributes and namespace map are stored as
    instance attributes.

    :param tag: a string containing a QName in extended format.
    :param value: the simple typed value of the element.
    :param attrib: the typed attributes of the element.
    :param nsmap: an optional map from prefixes to namespaces.
    :param xsd_element: an optional XSD element association.
    :param xsd_type: an optional XSD type association. Can be provided \
    also if the instance is not bound with an XSD element.
    """
    _children: List['DataElement']
    tag: str
    attrib: Dict[str, Any]
    nsmap: Dict[str, str]

    value: Optional[Any] = None
    tail: Optional[str] = None
    xsd_element: Optional['XsdElement'] = None
    xsd_type: Optional[BaseXsdType] = None
    # The XSD element used by iter_errors() and encode(); it is kept in sync
    # with the xsd_element/xsd_type bindings by __setattr__().
    _encoder: Optional['XsdElement'] = None

    def __init__(self, tag: str,
                 value: Optional[Any] = None,
                 attrib: Optional[Dict[str, Any]] = None,
                 nsmap: Optional[MutableMapping[str, str]] = None,
                 xsd_element: Optional['XsdElement'] = None,
                 xsd_type: Optional[BaseXsdType] = None) -> None:

        super().__init__()
        self._children = []
        self.tag = tag
        self.attrib = {}
        self.nsmap = {}

        if value is not None:
            self.value = value
        if attrib is not None:
            self.attrib.update(attrib)
        if nsmap is not None:
            self.nsmap.update(nsmap)

        if xsd_element is not None:
            self.xsd_element = xsd_element
            self.xsd_type = xsd_type or xsd_element.type
        elif xsd_type is not None:
            self.xsd_type = xsd_type
        elif self.xsd_element is not None:
            # The XSD element is a class attribute (e.g. a data binding class
            # created by DataBindingMeta): use it directly as the encoder.
            self._encoder = self.xsd_element

    @overload
    def __getitem__(self, i: int) -> 'DataElement': ...

    @overload
    def __getitem__(self, s: slice) -> MutableSequence['DataElement']: ...

    def __getitem__(self, i: Union[int, slice]) \
            -> Union['DataElement', MutableSequence['DataElement']]:
        return self._children[i]

    def __setitem__(self, i: Union[int, slice], child: Any) -> None:
        self._children[i] = child

    def __delitem__(self, i: Union[int, slice]) -> None:
        del self._children[i]

    def __len__(self) -> int:
        return len(self._children)

    def insert(self, i: int, child: 'DataElement') -> None:
        """Inserts a child data element at position *i* (also used by append())."""
        # NOTE(review): an assert-based type check is skipped under `python -O`.
        assert isinstance(child, DataElement)
        self._children.insert(i, child)

    def __repr__(self) -> str:
        return '%s(tag=%r)' % (self.__class__.__name__, self.tag)

    def __iter__(self) -> Iterator['DataElement']:
        yield from self._children

    def __setattr__(self, key: str, value: Any) -> None:
        # Schema bindings are validated before being stored, and the private
        # encoder is updated so that validation/encoding use the right element.
        if key == 'xsd_element':
            if not isinstance(value, validators.XsdElement):
                raise XMLSchemaTypeError("invalid type for attribute 'xsd_element'")
            elif self.xsd_element is value:
                pass
            elif self.xsd_element is not None:
                raise XMLSchemaValueError("the instance is already bound to another XSD element")
            elif self.xsd_type is not None and self.xsd_type is not value.type:
                raise XMLSchemaValueError("the instance is already bound to another XSD type")
            else:
                # A newly bound element becomes the encoder; a later assignment
                # of a different xsd_type replaces it with a derived encoder.
                self._encoder = value

        elif key == 'xsd_type':
            if not isinstance(value, (validators.XsdSimpleType, validators.XsdComplexType)):
                raise XMLSchemaTypeError("invalid type for attribute 'xsd_type'")
            elif self.xsd_type is not None and self.xsd_type is not value:
                raise XMLSchemaValueError("the instance is already bound to another XSD type")
            elif self.xsd_element is None or value is not self.xsd_element.type:
                # No element or a type that differs from the element's type:
                # build a local unqualified element with this type as encoder.
                self._encoder = value.schema.create_element(
                    self.tag, parent=value, form='unqualified'
                )
                self._encoder.type = value
            else:
                self._encoder = self.xsd_element

        super().__setattr__(key, value)

    @property
    def text(self) -> Optional[str]:
        """The string value of the data element."""
        return raw_xml_encode(self.value)

    def get(self, key: str, default: Any = None) -> Any:
        """Gets a data element attribute."""
        return self.attrib.get(key, default)

    def set(self, key: str, value: Any) -> None:
        """Sets a data element attribute."""
        self.attrib[key] = value

    @property
    def xsd_version(self) -> str:
        """The XSD version of the bound XSD element, '1.0' if the instance is unbound."""
        return '1.0' if self.xsd_element is None else self.xsd_element.xsd_version

    @property
    def namespace(self) -> str:
        """The element's namespace."""
        if self.xsd_element is None:
            return get_namespace(self.tag)
        return get_namespace(self.tag) or self.xsd_element.target_namespace

    @property
    def name(self) -> str:
        """The element's name, that matches the tag."""
        return self.tag

    @property
    def prefixed_name(self) -> str:
        """The prefixed name, or the tag if no prefix is defined for its namespace."""
        return get_prefixed_qname(self.tag, self.nsmap)

    @property
    def local_name(self) -> str:
        """The local part of the tag."""
        return local_name(self.tag)

    def validate(self, use_defaults: bool = True,
                 namespaces: Optional[NamespacesType] = None,
                 max_depth: Optional[int] = None) -> None:
        """
        Validates the XML data object.

        :param use_defaults: whether to use default values for filling missing data.
        :param namespaces: is an optional mapping from namespace prefix to URI. \
        For default uses the namespace map of the XML data object.
        :param max_depth: maximum depth for validation, for default there is no limit.
        :raises: :exc:`XMLSchemaValidationError` if XML data object is not valid.
        :raises: :exc:`XMLSchemaValueError` if the instance has no schema bindings.
        """
        for error in self.iter_errors(use_defaults, namespaces, max_depth):
            raise error

    def is_valid(self, use_defaults: bool = True,
                 namespaces: Optional[NamespacesType] = None,
                 max_depth: Optional[int] = None) -> bool:
        """
        Like :meth:`validate` except it does not raise an exception on validation
        error but returns ``True`` if the XML data object is valid, ``False`` if
        it's invalid.

        :raises: :exc:`XMLSchemaValueError` if the instance has no schema bindings.
        """
        error = next(self.iter_errors(use_defaults, namespaces, max_depth), None)
        return error is None

    def iter_errors(self, use_defaults: bool = True,
                    namespaces: Optional[NamespacesType] = None,
                    max_depth: Optional[int] = None) -> Iterator['XMLSchemaValidationError']:
        """
        Generates a sequence of validation errors if the XML data object is invalid.
        Accepts the same arguments of :meth:`validate`.
        """
        if self._encoder is None:
            raise XMLSchemaValueError("{!r} has no schema bindings".format(self))

        kwargs: Dict[str, Any] = {
            'converter': DataElementConverter,
            'use_defaults': use_defaults,
        }
        if namespaces:
            kwargs['namespaces'] = namespaces
        if isinstance(max_depth, int) and max_depth >= 0:
            kwargs['max_depth'] = max_depth

        # Encoding is run only for its side product: the validation errors.
        for result in self._encoder.iter_encode(self, **kwargs):
            if isinstance(result, validators.XMLSchemaValidationError):
                yield result
            else:
                del result

    def encode(self, validation: str = 'strict', **kwargs: Any) \
            -> Union[ElementType, Tuple[ElementType, List['XMLSchemaValidationError']]]:
        """
        Encodes the data object to XML.

        :param validation: the validation mode. Can be 'lax', 'strict' or 'skip'.
        :param kwargs: optional keyword arguments for the method :func:`iter_encode` \
        of :class:`XsdElement`.
        :return: An ElementTree's Element. If *validation* argument is 'lax' a \
        2-items tuple is returned, where the first item is the encoded object and \
        the second item is a list with validation errors.
        :raises: :exc:`XMLSchemaValidationError` if the object is invalid \
        and ``validation='strict'``.
        :raises: :exc:`XMLSchemaValueError` if the instance has no schema bindings \
        and *validation* is not 'skip'.
        """
        if 'converter' not in kwargs:
            kwargs['converter'] = DataElementConverter

        encoder: Union['XsdElement', BaseXsdType]
        if self._encoder is not None:
            encoder = self._encoder
        elif validation == 'skip':
            # Unbound instances can still be serialized without validation.
            encoder = validators.XMLSchema.builtin_types()['anyType']
        else:
            raise XMLSchemaValueError("{!r} has no schema bindings".format(self))

        return encoder.encode(self, validation=validation, **kwargs)

    to_etree = encode

    def tostring(self, indent: str = '', max_lines: Optional[int] = None,
                 spaces_for_tab: int = 4) -> Any:
        """Serializes the data element tree to an XML source string."""
        root, _ = self.encode(validation='lax')  # validation errors are ignored
        return etree_tostring(root, self.nsmap, indent, max_lines, spaces_for_tab)

    def find(self, path: str,
             namespaces: Optional[NamespacesType] = None) -> Optional['DataElement']:
        """
        Finds the first data element matching the path.

        :param path: an XPath expression that considers the data element as the root.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: the first matching data element or ``None`` if there is no match.
        """
        parser = XPath2Parser(namespaces, strict=False)
        context = XPathContext(cast(Any, self))
        results = parser.parse(path).select_results(context)
        # Skip non-element results (e.g. atomic values), like findall()/iterfind() do.
        return next((e for e in results if isinstance(e, DataElement)), None)

    def findall(self, path: str,
                namespaces: Optional[NamespacesType] = None) -> List['DataElement']:
        """
        Finds all data elements matching the path.

        :param path: an XPath expression that considers the data element as the root.
        :param namespaces: an optional mapping from namespace prefix to full name.
        :return: a list containing all matching data elements in document order, \
        an empty list is returned if there is no match.
        """
        parser = XPath2Parser(namespaces, strict=False)
        context = XPathContext(cast(Any, self))
        results = parser.parse(path).get_results(context)
        if not isinstance(results, list):
            return []
        return [e for e in results if isinstance(e, DataElement)]

    def iterfind(self, path: str,
                 namespaces: Optional[NamespacesType] = None) -> Iterator['DataElement']:
        """
        Creates an iterator for all data elements matching the path.

        :param path: an XPath expression that considers the data element as the root.
        :param namespaces: is an optional mapping from namespace prefix to full name.
        :return: an iterable yielding all matching data elements in document order.
        """
        parser = XPath2Parser(namespaces, strict=False)
        context = XPathContext(cast(Any, self))
        results = parser.parse(path).select_results(context)
        yield from filter(lambda x: isinstance(x, DataElement), results)  # type: ignore[misc]

    def iter(self, tag: Optional[str] = None) -> Iterator['DataElement']:
        """
        Creates an iterator for the data element and its subelements. If tag
        is not `None` or '*', only data elements whose tag matches are returned
        from the iterator.
        """
        if tag == '*':
            tag = None
        if tag is None or tag == self.tag:
            yield self
        for child in self._children:
            yield from child.iter(tag)

    def iterchildren(self, tag: Optional[str] = None) -> Iterator['DataElement']:
        """
        Creates an iterator for the child data elements. If *tag* is not `None` or '*',
        only data elements whose name matches tag are returned from the iterator.
        """
        if tag == '*':
            tag = None
        for child in self:
            if tag is None or tag == child.tag:
                yield child
336
class DataBindingMeta(ABCMeta):
    """
    Metaclass for creating classes with bindings to XSD elements.

    A class created with this metaclass must define an ``xsd_element`` class
    attribute holding an :class:`XsdElement` instance; the XSD version and the
    target namespace of that element are copied onto the new class.
    """

    xsd_element: 'XsdElement'

    def __new__(mcs, name: str, bases: Tuple[Type[Any], ...],
                attrs: Dict[str, Any]) -> 'DataBindingMeta':
        if 'xsd_element' not in attrs:
            raise XMLSchemaAttributeError(
                "attribute 'xsd_element' is required for an XSD data binding class"
            )

        bound_element = attrs['xsd_element']
        if not isinstance(bound_element, validators.XsdElement):
            raise XMLSchemaTypeError(f"{bound_element!r} is not an XSD element")

        # The binding class is not associated with any module.
        attrs['__module__'] = None
        return super().__new__(mcs, name, bases, attrs)

    def __init__(cls, name: str, bases: Tuple[Type[Any], ...], attrs: Dict[str, Any]) -> None:
        super().__init__(name, bases, attrs)
        bound_element = cls.xsd_element
        cls.xsd_version = bound_element.xsd_version
        cls.namespace = bound_element.target_namespace

    def fromsource(cls, source: Union[XMLSourceType, XMLResource],
                   allow: str = 'all', defuse: str = 'remote',
                   timeout: int = 300, **kwargs: Any) -> DecodeType[Any]:
        """
        Decodes an XML source with the schema of the bound XSD element.

        :param source: an XML source or an :class:`XMLResource` instance.
        :param allow: the *allow* option for building the XMLResource, \
        used only if *source* is not already an XMLResource.
        :param defuse: the *defuse* option for building the XMLResource, \
        used only if *source* is not already an XMLResource.
        :param timeout: the *timeout* option for building the XMLResource, \
        used only if *source* is not already an XMLResource.
        :param kwargs: other keyword arguments for the schema's decode method. \
        If no *converter* is provided :class:`DataBindingConverter` is used.
        """
        if isinstance(source, XMLResource):
            resource = source
        else:
            resource = XMLResource(source, allow=allow, defuse=defuse, timeout=timeout)

        kwargs.setdefault('converter', DataBindingConverter)
        return cls.xsd_element.schema.decode(resource, **kwargs)
369
370
class DataElementConverter(XMLSchemaConverter):
    """
    XML Schema based converter class for DataElement objects.

    :param namespaces: a dictionary map from namespace prefixes to URI.
    :param data_element_class: MutableSequence subclass to use for decoded data. \
    Default is `DataElement`.
    """
    __slots__ = ('data_element_class',)

    def __init__(self, namespaces: Optional[NamespacesType] = None,
                 data_element_class: Optional[Type['DataElement']] = None,
                 **kwargs: Any) -> None:
        self.data_element_class = \
            DataElement if data_element_class is None else data_element_class

        # Empty prefixes/keys are always forced, overriding caller values:
        # element_decode() recognizes text items by their bare numeric names
        # (presumably produced by map_content() with an empty cdata_prefix).
        kwargs.update({'attr_prefix': '', 'text_key': '', 'cdata_prefix': ''})
        super().__init__(namespaces, **kwargs)

    @property
    def lossy(self) -> bool:
        """Always `False`: this converter does not lose information."""
        return False

    @property
    def losslessly(self) -> bool:
        """Always `True`: decoded data can be encoded back."""
        return True

    def copy(self, **kwargs: Any) -> 'DataElementConverter':
        """Returns a copy of the converter, optionally with another *data_element_class*."""
        new_converter = cast(DataElementConverter, super().copy(**kwargs))
        if 'data_element_class' in kwargs:
            new_converter.data_element_class = kwargs['data_element_class']
        else:
            new_converter.data_element_class = self.data_element_class
        return new_converter

    def element_decode(self, data: ElementData, xsd_element: 'XsdElement',
                       xsd_type: Optional[BaseXsdType] = None, level: int = 0) -> 'DataElement':
        """
        Builds a data element from decoded element data.

        Attributes are copied into *attrib*. For types with a model group, child
        data elements are appended and each text item is stored as the tail of
        the preceding child, or as the value if no child precedes it.
        """
        data_element = self.data_element_class(
            tag=data.tag, value=data.text, nsmap=self.namespaces,
            xsd_element=xsd_element, xsd_type=xsd_type,
        )
        for attr_name, attr_value in self.map_attributes(data.attributes):
            data_element.attrib[attr_name] = attr_value

        content_type = xsd_type or xsd_element.type
        if content_type.model_group is None:
            return data_element

        for name, value, _ in self.map_content(data.content):
            if name.isdigit():
                if len(data_element):
                    data_element[-1].tail = value
                else:
                    data_element.value = value
            else:
                data_element.append(value)

        return data_element

    def element_encode(self, data_element: 'DataElement', xsd_element: 'XsdElement',
                       level: int = 0) -> ElementData:
        """
        Converts a data element to an ElementData for encoding.

        The namespace map of the data element is merged into the converter's
        namespaces. Text items (the value and the children's tails) are put in
        the content with progressive integer keys starting from 1.

        :raises: :exc:`XMLSchemaValueError` if the tag doesn't match *xsd_element*.
        """
        self.namespaces.update(data_element.nsmap)
        default_namespace = self._namespaces.get('')
        if not xsd_element.is_matching(data_element.tag, default_namespace):
            raise XMLSchemaValueError("Unmatched tag")

        attributes = {}
        for attr_name, attr_value in data_element.attrib.items():
            attributes[self.unmap_qname(attr_name, xsd_element.attributes)] = attr_value

        if len(data_element) == 0:
            return ElementData(data_element.tag, data_element.value, None, attributes)

        text_index = count(1)
        content: List[Tuple[Union[str, int], Any]] = []
        if data_element.value is not None:
            content.append((next(text_index), data_element.value))

        for child in data_element:
            content.append((child.tag, child))
            if child.tail is not None:
                content.append((next(text_index), child.tail))

        return ElementData(data_element.tag, None, content, attributes)
451
452
class DataBindingConverter(DataElementConverter):
    """
    A :class:`DataElementConverter` that uses XML data binding classes for
    decoding. Takes the same arguments of its parent class, but the argument
    *data_element_class* is used to define the base for creating the missing
    XML binding classes.
    """
    __slots__ = ()

    def element_decode(self, data: ElementData, xsd_element: 'XsdElement',
                       xsd_type: Optional[BaseXsdType] = None, level: int = 0) -> 'DataElement':
        """
        Builds an instance of the binding class of *xsd_element* from decoded data.

        The binding class is obtained from ``xsd_element.get_binding()`` using the
        converter's *data_element_class*. The XSD element is not passed to the
        constructor (presumably the binding class carries it as a class attribute).
        """
        binding_class = xsd_element.get_binding(self.data_element_class)
        data_element = binding_class(
            tag=data.tag, value=data.text, nsmap=self.namespaces, xsd_type=xsd_type,
        )
        for attr_name, attr_value in self.map_attributes(data.attributes):
            data_element.attrib[attr_name] = attr_value

        content_type = xsd_type or xsd_element.type
        if content_type.model_group is None:
            return data_element

        for name, value, _ in self.map_content(data.content):
            if name.isdigit():
                # Text item: tail of the last child, or the value if no child yet.
                if len(data_element):
                    data_element[-1].tail = value
                else:
                    data_element.value = value
            else:
                data_element.append(value)

        return data_element
484