1#
2# Copyright (c), 2016-2020, SISSA (International School for Advanced Studies).
3# All rights reserved.
4# This file is distributed under the terms of the MIT License.
5# See the file 'LICENSE' in the root directory of the present
6# distribution, or http://opensource.org/licenses/MIT.
7#
8# @author Davide Brunato <brunato@sissa.it>
9#
10"""
11This module defines a proxy class and a mixin class for enabling XPath on schemas.
12"""
13import sys
14from abc import abstractmethod
15from typing import cast, overload, Any, Dict, Iterator, List, Optional, \
16    Sequence, Set, TypeVar, Union
17import re
18
19from elementpath import AttributeNode, TypedElement, XPath2Parser, \
20    XPathSchemaContext, AbstractSchemaProxy, protocols
21
22from .exceptions import XMLSchemaValueError, XMLSchemaTypeError
23from .names import XSD_NAMESPACE
24from .aliases import NamespacesType, SchemaType, BaseXsdType, XPathElementType
25from .helpers import get_qname, local_name, get_prefixed_qname
26
if sys.version_info >= (3, 8):
    from typing import runtime_checkable, Protocol

    XsdTypeProtocol = protocols.XsdTypeProtocol

    # Structural protocols extending the elementpath ones with the
    # `attributes` map that xmlschema components provide.
    class XMLSchemaProtocol(protocols.XMLSchemaProtocol, Protocol):
        attributes: Dict[str, Any]

    @runtime_checkable
    class ElementProtocol(protocols.ElementProtocol, Protocol):
        schema: XMLSchemaProtocol
        attributes: Dict[str, Any]
else:
    # typing.Protocol is not available before Python 3.8: fall back
    # to the concrete type aliases of the package.
    XMLSchemaProtocol = SchemaType
    ElementProtocol = XPathElementType
    XsdTypeProtocol = BaseXsdType


# Matches a positional predicate such as `[1]` placed right after a word
# character (e.g. `xs:element[2]`), used to strip positions from paths.
_REGEX_TAG_POSITION = re.compile(r'\b\[\d+]')
46
47
def iter_schema_nodes(root: Union[XMLSchemaProtocol, ElementProtocol],
                      with_root: bool = True,
                      with_attributes: bool = False) \
        -> Iterator[Union[XMLSchemaProtocol, ElementProtocol, AttributeNode]]:
    """
    Depth-first iteration over the nodes of a schema or of a schema element.

    Text nodes are never produced (they are always `None` for schema elements),
    and every component is produced at most once, so already visited nodes are
    skipped. When an element reference is met, the referencing element is
    produced first, followed by the referenced element and its subtree, unless
    the referenced element has already been visited.

    :param root: schema or schema element; a `TypedElement` is unwrapped \
    to its underlying element.
    :param with_root: if `True` yields initial element.
    :param with_attributes: if `True` yields also attribute nodes. The \
    attributes of the root are yielded even when *with_root* is `False`.
    """
    def attribute_nodes(component: Any) -> Iterator[AttributeNode]:
        # The `attributes` map is accessed only when attribute nodes are requested.
        if with_attributes:
            for attribute in component.attributes.items():
                yield AttributeNode(*attribute)

    def descendants(parent: Any) -> Iterator[Any]:
        for child in parent:
            if child in visited:
                continue

            target = child.ref
            visited.add(child)
            yield child

            if target is None:
                yield from attribute_nodes(child)
                yield from descendants(child)
            elif target not in visited:
                visited.add(target)
                yield target
                # NOTE(review): these are the attributes of the referencing
                # element, not of `target`; presumably they coincide for
                # element references -- confirm.
                yield from attribute_nodes(child)
                yield from descendants(target)

    if isinstance(root, TypedElement):
        root = cast(ElementProtocol, root.elem)

    visited = {root}
    if with_root:
        yield root
    yield from attribute_nodes(root)
    yield from descendants(root)
93
94
class XMLSchemaContext(XPathSchemaContext):
    """
    XPath dynamic schema context for the *xmlschema* library.

    It's an elementpath schema context whose `_iter_nodes` function is
    :func:`iter_schema_nodes`.
    """
    # Wrapped with staticmethod so that accessing it from an instance doesn't
    # pass the context as first argument. NOTE(review): presumably the
    # elementpath base context uses `_iter_nodes` for iterating nodes -- confirm.
    _iter_nodes = staticmethod(iter_schema_nodes)
98
99
class XMLSchemaProxy(AbstractSchemaProxy):
    """XPath schema proxy for the *xmlschema* library."""
    _schema: SchemaType  # type: ignore[assignment]

    def __init__(self, schema: Optional[XMLSchemaProtocol] = None,
                 base_element: Optional[ElementProtocol] = None) -> None:
        """
        :param schema: the schema to bind. If `None` the `meta_schema` of \
        :class:`xmlschema.XMLSchema` is used, if available.
        :param base_element: an optional element of *schema*, used as the \
        item of the XPath contexts created by this proxy.
        :raises XMLSchemaTypeError: if *base_element* has no `schema` attribute.
        :raises XMLSchemaValueError: if *base_element* doesn't belong to *schema*.
        """
        if schema is None:
            # Local import, presumably to avoid a circular import with the package root.
            from xmlschema import XMLSchema
            schema = getattr(XMLSchema, 'meta_schema', None)
        super().__init__(schema, base_element)

        if base_element is not None:
            # Keep the try body minimal: only the attribute access may fail here.
            try:
                base_schema = base_element.schema
            except AttributeError:
                raise XMLSchemaTypeError("%r is not an XsdElement" % base_element) from None

            if base_schema is not schema:
                raise XMLSchemaValueError("%r is not an element of %r" % (base_element, schema))

    def bind_parser(self, parser: XPath2Parser) -> None:
        """
        Binds *parser* to this proxy, giving it a private copy of its class symbol
        table extended with constructor tokens for the schema's atomic types.
        """
        parser.schema = self
        parser.symbol_table = dict(parser.__class__.symbol_table)

        # The type constructor tokens are built once and cached on the schema.
        # The value obtained under the lock is the one used afterwards, so the
        # cache is never re-read outside the lock that guards its creation.
        with self._schema.lock:
            xpath_tokens = self._schema.xpath_tokens
            if xpath_tokens is None:
                xpath_tokens = self._schema.xpath_tokens = {
                    xsd_type.name: parser.schema_constructor(xsd_type.name)
                    for xsd_type in self.iter_atomic_types() if xsd_type.name
                }

        parser.symbol_table.update(xpath_tokens)

    def get_context(self) -> XMLSchemaContext:
        """Returns a new schema context rooted at the schema, with the base element as item."""
        return XMLSchemaContext(
            root=self._schema,  # type: ignore[arg-type]
            namespaces=dict(self._schema.namespaces),
            item=self._base_element
        )

    def _get_built_type(self, type_qname: str) -> BaseXsdType:
        """
        Returns the XSD type of the schema maps named *type_qname*.

        :raises XMLSchemaNotBuiltError: if the type is still an unbuilt \
        placeholder (a tuple) in the schema maps.
        """
        xsd_type = self._schema.maps.types[type_qname]
        if isinstance(xsd_type, tuple):
            from .validators import XMLSchemaNotBuiltError
            raise XMLSchemaNotBuiltError(xsd_type[1], f"XSD type {type_qname} is not built")
        return xsd_type

    def is_instance(self, obj: Any, type_qname: str) -> bool:
        """
        Returns `True` if *obj* can be encoded with the XSD type named *type_qname*,
        `False` if the encoding raises a `ValueError`.
        """
        # FIXME: use elementpath.datatypes for checking atomic datatypes
        xsd_type = self._get_built_type(type_qname)
        try:
            xsd_type.encode(obj)
        except ValueError:
            return False
        else:
            return True

    def cast_as(self, obj: Any, type_qname: str) -> Any:
        """Decodes *obj* with the XSD type named *type_qname*."""
        return self._get_built_type(type_qname).decode(obj)

    def iter_atomic_types(self) -> Iterator[XsdTypeProtocol]:
        """
        Yields the built types of the schema maps that are not in the XSD
        namespace and that have a `primitive_type` attribute.
        """
        for xsd_type in self._schema.maps.types.values():
            if not isinstance(xsd_type, tuple) and \
                    xsd_type.target_namespace != XSD_NAMESPACE and \
                    hasattr(xsd_type, 'primitive_type'):
                yield cast(XsdTypeProtocol, xsd_type)

    def get_primitive_type(self, xsd_type: XsdTypeProtocol) -> XsdTypeProtocol:
        """Returns the `root_type` of *xsd_type*."""
        primitive_type = cast(BaseXsdType, xsd_type).root_type
        return cast(XsdTypeProtocol, primitive_type)
170
171
# Type of the child components of an ElementPathMixin: it is the item type of
# the Sequence that the mixin implements, bound to ElementPathMixin itself.
E = TypeVar('E', bound='ElementPathMixin[Any]')
173
174
class ElementPathMixin(Sequence[E]):
    """
    Mixin abstract class for enabling ElementTree and XPath 2.0 API on XSD components.

    Subclasses provide the children through :meth:`__iter__`; sequence access,
    ElementTree-like accessors and XPath selection are built on top of it.

    :cvar text: the Element text, for compatibility with the ElementTree API.
    :cvar tail: the Element tail, for compatibility with the ElementTree API.
    """
    text: Optional[str] = None
    tail: Optional[str] = None
    name: Optional[str] = None
    # Class-level defaults shared by all instances that don't set their own
    # values: they must never be mutated in place.
    attributes: Any = {}
    namespaces: Any = {}
    xpath_default_namespace = ''

    @abstractmethod
    def __iter__(self) -> Iterator[E]:
        raise NotImplementedError

    @overload
    def __getitem__(self, i: int) -> E: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[E]: ...

    def __getitem__(self, i: Union[int, slice]) -> Union[E, Sequence[E]]:
        # Children are materialized on each access. A comprehension is used
        # instead of list(self) because list() queries __len__ for a size
        # hint, and __len__ iterates the children too.
        try:
            return [e for e in self][i]
        except IndexError:
            raise IndexError('child index out of range') from None

    def __reversed__(self) -> Iterator[E]:
        return reversed([e for e in self])

    def __len__(self) -> int:
        # Count the children without building an intermediate list.
        return sum(1 for _ in self)

    @property
    def tag(self) -> str:
        """Alias of the *name* attribute. For compatibility with the ElementTree API."""
        return self.name or ''

    @property
    def attrib(self) -> Any:
        """Returns the Element attributes. For compatibility with the ElementTree API."""
        return self.attributes

    def get(self, key: str, default: Any = None) -> Any:
        """Gets an Element attribute. For compatibility with the ElementTree API."""
        return self.attributes.get(key, default)

    @property
    def xpath_proxy(self) -> XMLSchemaProxy:
        """Returns an XPath proxy instance bound with the schema."""
        raise NotImplementedError

    def _get_xpath_namespaces(self, namespaces: Optional[NamespacesType] = None) \
            -> Dict[str, str]:
        """
        Returns a new dictionary with namespaces for XPath selection.

        The result contains the default namespaces of the XPath 2.0 parser,
        updated with the provided (or the component's) namespaces. The `''`
        key is set to *xpath_default_namespace* when not otherwise provided.

        :param namespaces: an optional map from namespace prefix to namespace URI. \
        If this argument is not provided the schema's namespaces are used. \
        The argument is never modified.
        """
        if namespaces is None:
            # The default namespace of the component's map is not used for XPath:
            # it's replaced by the XPath default namespace below.
            namespaces = {k: v for k, v in self.namespaces.items() if k}

        xpath_namespaces: Dict[str, str] = XPath2Parser.DEFAULT_NAMESPACES.copy()
        xpath_namespaces.update(namespaces)
        if '' not in namespaces:
            xpath_namespaces[''] = self.xpath_default_namespace
        return xpath_namespaces

    def _parse_schema_path(self, path: str,
                           namespaces: Optional[NamespacesType] = None) -> Any:
        """
        Parses *path* with a non-strict XPath 2.0 parser and returns the root token.

        Positional predicates that directly follow a name (e.g. `[1]`) are
        stripped from the path before parsing.
        """
        path = _REGEX_TAG_POSITION.sub('', path.strip())  # Strips tags positions from path
        parser = XPath2Parser(self._get_xpath_namespaces(namespaces), strict=False)
        return parser.parse(path)

    def is_matching(self, name: Optional[str], default_namespace: Optional[str] = None) -> bool:
        """
        Returns `True` if the component's name matches *name*. An unqualified
        *name* is qualified with *default_namespace*, if one is provided.
        """
        if not name or name[0] == '{' or not default_namespace:
            return self.name == name
        else:
            return self.name == '{%s}%s' % (default_namespace, name)

    def find(self, path: str, namespaces: Optional[NamespacesType] = None) -> Optional[E]:
        """
        Finds the first XSD subelement matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: the first matching XSD subelement or ``None`` if there is no match.
        """
        root_token = self._parse_schema_path(path, namespaces)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]
        return cast(Optional[E], next(root_token.select_results(context), None))

    def findall(self, path: str, namespaces: Optional[NamespacesType] = None) -> List[E]:
        """
        Finds all XSD subelements matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: a list containing all matching XSD subelements in document order, an empty \
        list is returned if there is no match.
        """
        root_token = self._parse_schema_path(path, namespaces)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]
        return cast(List[E], root_token.get_results(context))

    def iterfind(self, path: str, namespaces: Optional[NamespacesType] = None) -> Iterator[E]:
        """
        Creates an iterator for all XSD subelements matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: an iterable yielding all matching XSD subelements in document order.
        """
        root_token = self._parse_schema_path(path, namespaces)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]
        return cast(Iterator[E], root_token.select_results(context))

    def iter(self, tag: Optional[str] = None) -> Iterator[E]:
        """
        Creates an iterator for the XSD element and its subelements. If tag is not `None` or '*',
        only XSD elements whose name matches tag are returned from the iterator. Local elements
        are expanded without repetitions. Element references are not expanded because the global
        elements are not descendants of other elements.
        """
        def safe_iter(elem: Any) -> Iterator[E]:
            if tag is None or elem.is_matching(tag):
                yield elem
            for child in elem:
                if child.parent is None:
                    # Children without a parent (presumably global components
                    # of a schema) are expanded without being tracked.
                    yield from safe_iter(child)
                elif getattr(child, 'ref', None) is not None:
                    # References are yielded but never expanded.
                    if tag is None or child.is_matching(tag):
                        yield child
                elif child not in local_elements:
                    # Tracking local elements avoids repetitions and endless recursion.
                    local_elements.add(child)
                    yield from safe_iter(child)

        if tag == '*':
            tag = None
        local_elements: Set[E] = set()
        return safe_iter(self)

    def iterchildren(self, tag: Optional[str] = None) -> Iterator[E]:
        """
        Creates an iterator for the child elements of the XSD component. If *tag* is not `None`
        or '*', only XSD elements whose name matches tag are returned from the iterator.
        """
        if tag == '*':
            tag = None
        for child in self:
            if tag is None or child.is_matching(tag):
                yield child
335
336
class XPathElement(ElementPathMixin['XPathElement']):
    """
    An element node for making XPath operations on schema types.

    :param name: the name of the element node.
    :param xsd_type: the XSD type of the node. Unless the type has simple \
    content, the elements of its content are the children of the node.
    """
    name: str
    parent = None

    def __init__(self, name: str, xsd_type: BaseXsdType) -> None:
        self.name, self.type = name, xsd_type
        # Types without an `attributes` member get a fresh empty map.
        self.attributes = getattr(xsd_type, 'attributes', {})

    def __iter__(self) -> Iterator['XPathElement']:
        if self.type.has_simple_content():
            return  # simple content: no child elements
        yield from self.type.content.iter_elements()  # type: ignore[union-attr]

    @property
    def xpath_proxy(self) -> XMLSchemaProxy:
        """Returns an XPath proxy bound to the schema of the type, based on this node."""
        schema = cast(XMLSchemaProtocol, self.schema)
        base_element = cast(ElementProtocol, self)
        return XMLSchemaProxy(schema, base_element)

    @property
    def schema(self) -> SchemaType:
        """The schema of the XSD type."""
        return self.type.schema

    @property
    def target_namespace(self) -> str:
        """The target namespace of the schema of the XSD type."""
        xsd_schema = self.type.schema
        return xsd_schema.target_namespace

    @property
    def namespaces(self) -> NamespacesType:
        """The namespace map of the schema of the XSD type."""
        xsd_schema = self.type.schema
        return xsd_schema.namespaces

    @property
    def local_name(self) -> str:
        """The name of the node without its namespace part."""
        return local_name(self.name)

    @property
    def qualified_name(self) -> str:
        """The name of the node qualified with the target namespace."""
        return get_qname(self.target_namespace, self.name)

    @property
    def prefixed_name(self) -> str:
        """The name of the node in prefixed form, using the schema's namespace map."""
        return get_prefixed_qname(self.name, self.namespaces)
381