1#
2# Copyright (c), 2016-2021, SISSA (International School for Advanced Studies).
3# All rights reserved.
4# This file is distributed under the terms of the MIT License.
5# See the file 'LICENSE' in the root directory of the present
6# distribution, or http://opensource.org/licenses/MIT.
7#
8# @author Davide Brunato <brunato@sissa.it>
9#
10import re
11from collections import Counter
12from decimal import Decimal
13from typing import Any, Callable, Iterator, List, MutableMapping, \
14    Optional, Tuple, Union
15from .exceptions import XMLSchemaValueError, XMLSchemaTypeError
16from .names import XSI_SCHEMA_LOCATION, XSI_NONS_SCHEMA_LOCATION
17from .aliases import ElementType, NamespacesType, AtomicValueType, NumericValueType
18
19###
20# Helper functions for QNames
21
# Matches the leading '{uri}' part of an extended QName, capturing the URI.
NAMESPACE_PATTERN = re.compile(r'{([^}]*)}')


def get_namespace(qname: str, namespaces: Optional[NamespacesType] = None) -> str:
    """
    Extracts the namespace URI from a QName. A name that is not in extended
    form is first converted using the namespace map, if one is provided.

    :param qname: an extended QName or a local name or a prefixed QName.
    :param namespaces: optional mapping from prefixes to namespace URIs.
    :return: the namespace URI, or an empty string if the name is empty, \
    if it cannot be resolved or if it has no leading '{uri}' part.
    """
    if not qname:
        return ''

    if qname[0] != '{':
        # Without a namespace map a non-extended name cannot be resolved.
        if namespaces is None:
            return ''
        qname = get_extended_qname(qname, namespaces)

    try:
        matched = NAMESPACE_PATTERN.match(qname)
    except TypeError:
        # The pattern cannot be applied to a non-string argument.
        return ''
    return '' if matched is None else matched.group(1)
44
45
def get_qname(uri: Optional[str], name: str) -> str:
    """
    Builds an expanded QName from a namespace URI and a local part. The *name*
    argument is returned unchanged if the URI or the name has boolean value
    `False`, or if the name starts with one of '{', '.', '/' or '[' (i.e. it
    is already an expanded QName or it looks like a path).

    :param uri: namespace URI
    :param name: local or qualified name
    :return: string or the name argument
    """
    if uri and name and name[0] not in ('{', '.', '/', '['):
        return '{%s}%s' % (uri, name)
    return name
59
60
def local_name(qname: str) -> str:
    """
    Return the local part of an expanded QName or a prefixed name. A name
    without a namespace part or a prefix is returned as is, an empty name
    gives an empty string.

    :param qname: an expanded QName or a prefixed name or a local name.
    :raises XMLSchemaValueError: if the name cannot be split in exactly two \
    parts (e.g. an unclosed '{' or more than one ':').
    :raises XMLSchemaTypeError: if a TypeError is raised while processing the \
    argument (e.g. when it is `None`).
    """
    try:
        if qname[0] == '{':
            pieces = qname.split('}')
        elif ':' in qname:
            pieces = qname.split(':')
        else:
            return qname

        # The unpacking fails with a ValueError if there are not exactly two parts.
        _, local_part = pieces
    except IndexError:
        return ''
    except ValueError:
        raise XMLSchemaValueError("the argument 'qname' has a wrong format: %r" % qname)
    except TypeError:
        raise XMLSchemaTypeError("the argument 'qname' must be a string")

    return local_part
81
82
def get_prefixed_qname(qname: str,
                       namespaces: Optional[MutableMapping[str, str]],
                       use_empty: bool = True) -> str:
    """
    Converts an extended QName to its prefixed form, using a namespace map.
    Names that are not in extended form, or whose namespace is not mapped,
    are returned unchanged. A non-empty prefix is preferred to the empty one.

    :param qname: an extended QName or a local name or a prefixed QName.
    :param namespaces: an optional mapping from prefixes to namespace URIs.
    :param use_empty: if `True` use the empty prefix for mapping.
    """
    if not (namespaces and qname and qname[0] == '{'):
        return qname

    uri = get_namespace(qname)
    matching = [prefix for prefix in namespaces if namespaces[prefix] == uri]
    if not matching:
        return qname

    if matching[0]:
        chosen = matching[0]
    elif len(matching) > 1:
        # The first match is the empty prefix: fall back on the next one.
        chosen = matching[1]
    elif use_empty:
        return qname.split('}', 1)[1]
    else:
        return qname

    return '%s:%s' % (chosen, qname.split('}', 1)[1])
109
110
def get_extended_qname(qname: str, namespaces: Optional[MutableMapping[str, str]]) -> str:
    """
    Converts a QName to its extended form, using a namespace map. Local names
    are mapped to the default namespace, if the map defines a non-empty one.
    Names with an unmapped prefix are returned unchanged, names whose prefix
    is mapped to an empty URI are reduced to their local part.

    :param qname: a prefixed QName or a local name or an extended QName.
    :param namespaces: an optional mapping from prefixes to namespace URIs.
    """
    if not namespaces:
        return qname

    try:
        already_extended = qname[0] == '{'
    except IndexError:
        return qname  # empty name
    if already_extended:
        return qname

    prefix, colon, name = qname.partition(':')
    if not colon:
        # Local name: use the default namespace, if any.
        default_uri = namespaces.get('')
        return '{%s}%s' % (default_uri, qname) if default_uri else qname

    try:
        uri = namespaces[prefix]
    except KeyError:
        return qname
    return '{%s}%s' % (uri, name) if uri else name
142
143
144###
145# Helper functions for ElementTree structures
146
def is_etree_element(obj: object) -> bool:
    """
    A checker for valid ElementTree elements that excludes XsdElement objects.
    The check is based on the presence of the attributes 'append', 'tag'
    and 'attrib'.
    """
    return all(hasattr(obj, name) for name in ('append', 'tag', 'attrib'))
150
151
def is_etree_document(obj: object) -> bool:
    """
    A checker for valid ElementTree objects. The check is based on the
    presence of the attributes 'getroot', 'parse' and 'iter'.
    """
    return all(hasattr(obj, name) for name in ('getroot', 'parse', 'iter'))
155
156
def etree_iterpath(elem: ElementType,
                   tag: Optional[str] = None,
                   path: str = '.',
                   namespaces: Optional[NamespacesType] = None,
                   add_position: bool = False) -> Iterator[Tuple[ElementType, str]]:
    """
    Creates an iterator for the element and its subelements that yields couples
    of element and path. If tag is not `None` or '*', only the elements whose
    tag matches it are yielded by the iterator.

    :param elem: the element to iterate.
    :param tag: tag filtering.
    :param path: the current path, '.' for default.
    :param namespaces: is an optional mapping from namespace prefix to URI.
    :param add_position: add context position to child elements that appear multiple times.
    """
    wanted = None if tag == "*" else tag
    path = path or '.'

    if wanted is None or elem.tag == wanted:
        yield elem, path

    # Next position for each child tag that occurs more than once, starting from 1.
    positions = Counter()
    if add_position:
        tag_counts = Counter(e.tag for e in elem)
        positions.update(t for t, total in tag_counts.items() if total > 1)

    for child in elem:
        if callable(child.tag):
            continue  # Skip lxml comments

        if namespaces is None:
            child_name = child.tag
        else:
            child_name = get_prefixed_qname(child.tag, namespaces)

        child_path = '/%s' % child_name if path == '/' else path + '/' + child_name

        if child.tag in positions:
            child_path += '[%d]' % positions[child.tag]
            positions[child.tag] += 1

        yield from etree_iterpath(child, wanted, child_path, namespaces, add_position)
200
201
def etree_getpath(elem: ElementType,
                  root: ElementType,
                  namespaces: Optional[NamespacesType] = None,
                  relative: bool = True,
                  add_position: bool = False,
                  parent_path: bool = False) -> Optional[str]:
    """
    Returns the XPath path from *root* to descendant *elem* element.

    :param elem: the descendant element.
    :param root: the root element.
    :param namespaces: an optional mapping from namespace prefix to URI.
    :param relative: returns a relative path.
    :param add_position: add context position to child elements that appear multiple times.
    :param parent_path: if set to `True` returns the parent path. Default is `False`.
    :return: An XPath expression or `None` if *elem* is not a descendant of *root*.
    """
    if relative:
        start = '.'
    elif namespaces:
        start = '/%s' % get_prefixed_qname(root.tag, namespaces)
    else:
        start = '/%s' % root.tag

    if parent_path:
        # Look for the first visited element that has *elem* among its children.
        candidates = etree_iterpath(root, None, start, namespaces, add_position)
        return next((p for node, p in candidates if elem in node), None)

    # Restrict the visit to the elements with the same tag of *elem*.
    candidates = etree_iterpath(root, elem.tag, start, namespaces, add_position)
    return next((p for node, p in candidates if node is elem), None)
235
236
def etree_iter_location_hints(elem: ElementType) -> Iterator[Tuple[Any, Any]]:
    """
    Yields schema location hints contained in the attributes of an element.
    Each hint is a couple of namespace and URL. The whitespace-separated items
    of the XSI_SCHEMA_LOCATION attribute are taken in pairs, the items of the
    XSI_NONS_SCHEMA_LOCATION attribute are associated with the empty namespace.
    """
    if XSI_SCHEMA_LOCATION in elem.attrib:
        tokens = elem.attrib[XSI_SCHEMA_LOCATION].split()
        # zip() ignores a trailing namespace that has no associated URL.
        yield from zip(tokens[0::2], tokens[1::2])

    if XSI_NONS_SCHEMA_LOCATION in elem.attrib:
        yield from (('', url) for url in elem.attrib[XSI_NONS_SCHEMA_LOCATION].split())
247
248
def prune_etree(root: ElementType, selector: Callable[[ElementType], bool]) \
        -> Optional[bool]:
    """
    Removes from a tree structure the elements that verify the selector
    function. For each visited node the children are checked first, removing
    the ones that are selected, then the visit descends into the children that
    are left. If the root itself is selected all its children are removed.

    :param root: the root element of the tree.
    :param selector: the single argument function to apply on each visited node.
    :return: `True` if the root node verify the selector function, `None` otherwise.
    """
    def _prune_subtree(node: ElementType) -> None:
        # Iterate over a copy because children are removed during the loop.
        for candidate in list(node):
            if selector(candidate):
                node.remove(candidate)

        for survivor in node:
            _prune_subtree(survivor)

    if not selector(root):
        _prune_subtree(root)
        return None

    del root[:]
    return True
273
274
def count_digits(number: NumericValueType) -> Tuple[int, int]:
    """
    Counts the digits of a number.

    Strings and bytes are parsed with `Decimal`, other values are converted
    with `str()`. The sign, the leading zeros of the integer part and the
    trailing zeros of the decimal part are not counted. A representation in
    scientific notation is counted on the value that it denotes, e.g.
    '1.5E+3' (1500) has 4 integer digits and no decimal digits.

    :param number: an int or a float or a Decimal or a string representing a number.
    :return: a couple with the number of digits of the integer part and \
    the number of digits of the decimal part.
    """
    if isinstance(number, str):
        literal = str(Decimal(number)).lstrip('-+')
    elif isinstance(number, bytes):
        literal = str(Decimal(number.decode())).lstrip('-+')
    else:
        literal = str(number).lstrip('-+')

    if 'E' in literal:
        significand, _, raw_exponent = literal.partition('E')
    elif 'e' in literal:
        significand, _, raw_exponent = literal.partition('e')
    else:
        # Plain notation: if there is no '.' the decimal part is empty.
        integer_part, _, decimal_part = literal.partition('.')
        return len(integer_part.lstrip('0')), len(decimal_part.rstrip('0'))

    exponent = int(raw_exponent)
    integer_part, _, decimal_part = significand.partition('.')
    decimal_part = decimal_part.rstrip('0')

    digits = (integer_part + decimal_part).lstrip('0')
    if not digits:
        return 0, 0  # a zero has no digits, whatever the exponent is

    # Express the value as coefficient * 10**scale, where the coefficient
    # is an integer without leading or trailing zeros.
    coefficient = digits.rstrip('0')
    scale = exponent - len(decimal_part) + len(digits) - len(coefficient)

    if scale >= 0:
        # Integer value: the coefficient followed by *scale* zeros.
        return len(coefficient) + scale, 0

    # The last -scale digits of the coefficient are after the decimal point.
    return max(0, len(coefficient) + scale), -scale
308
309
def strictly_equal(obj1: object, obj2: object) -> bool:
    """
    Checks if the objects are equal and are of the same type. The types
    are compared by identity, so an instance of a subclass is never
    strictly equal to an instance of its base class.
    """
    equal = obj1 == obj2
    return equal and type(obj1) is type(obj2)
313
314
def raw_xml_encode(value: Union[None, AtomicValueType, List[AtomicValueType],
                                Tuple[AtomicValueType, ...]]) -> Optional[str]:
    """
    Encodes a simple value to XML. Booleans are encoded as 'true' or 'false',
    also when they are items of a list or a tuple. The items of a list or a
    tuple are joined with a single space, other values are converted with `str()`.

    :param value: a simple value or a list/tuple of simple values, or `None`.
    :return: the encoded string, or `None` if the value is `None`.
    """
    def encode_item(item: Any) -> str:
        # str(True) is 'True', that differs from the encoding used for a single boolean.
        if isinstance(item, bool):
            return 'true' if item else 'false'
        return str(item)

    if isinstance(value, (list, tuple)):
        return ' '.join(encode_item(e) for e in value)
    return encode_item(value) if value is not None else None
324