#
# Copyright (c), 2016-2020, SISSA (International School for Advanced Studies).
# All rights reserved.
# This file is distributed under the terms of the MIT License.
# See the file 'LICENSE' in the root directory of the present
# distribution, or http://opensource.org/licenses/MIT.
#
# @author Davide Brunato <brunato@sissa.it>
#
"""
This module defines a proxy class and a mixin class for enabling XPath on schemas.
"""
import sys
from abc import abstractmethod
from typing import cast, overload, Any, Dict, Iterator, List, Optional, \
    Sequence, Set, TypeVar, Union
import re

from elementpath import AttributeNode, TypedElement, XPath2Parser, \
    XPathSchemaContext, AbstractSchemaProxy, protocols

from .exceptions import XMLSchemaValueError, XMLSchemaTypeError
from .names import XSD_NAMESPACE
from .aliases import NamespacesType, SchemaType, BaseXsdType, XPathElementType
from .helpers import get_qname, local_name, get_prefixed_qname

if sys.version_info < (3, 8):
    # typing.Protocol is not available before Python 3.8: fall back to the
    # concrete type aliases for static typing.
    XMLSchemaProtocol = SchemaType
    ElementProtocol = XPathElementType
    XsdTypeProtocol = BaseXsdType
else:
    from typing import runtime_checkable, Protocol

    XsdTypeProtocol = protocols.XsdTypeProtocol

    class XMLSchemaProtocol(protocols.XMLSchemaProtocol, Protocol):
        attributes: Dict[str, Any]

    @runtime_checkable
    class ElementProtocol(protocols.ElementProtocol, Protocol):
        schema: XMLSchemaProtocol
        attributes: Dict[str, Any]


# Matches a positional predicate such as "[1]" following a name in a path.
# find(), findall() and iterfind() strip these before parsing the path.
_REGEX_TAG_POSITION = re.compile(r'\b\[\d+]')


def iter_schema_nodes(root: Union[XMLSchemaProtocol, ElementProtocol],
                      with_root: bool = True,
                      with_attributes: bool = False) \
        -> Iterator[Union[XMLSchemaProtocol, ElementProtocol, AttributeNode]]:
    """
    Iteration function for schema nodes. It doesn't yield text nodes,
    that are always `None` for schema elements, and detects visited
    elements in order to skip already visited nodes.

    :param root: schema or schema element.
    :param with_root: if `True` yields initial element.
    :param with_attributes: if `True` yields also attribute nodes.
    """
    def attribute_node(x: Any) -> AttributeNode:
        return AttributeNode(*x)

    def _iter_schema_nodes(elem: Any) -> Iterator[Any]:
        for child in elem:
            if child in nodes:
                continue
            elif child.ref is not None:
                # Element reference: yield the reference, then (only once) the
                # referenced element followed by its own subtree.
                nodes.add(child)
                yield child
                if child.ref not in nodes:
                    nodes.add(child.ref)
                    yield child.ref
                    if with_attributes:
                        yield from map(attribute_node, child.attributes.items())
                    yield from _iter_schema_nodes(child.ref)
            else:
                nodes.add(child)
                yield child
                if with_attributes:
                    yield from map(attribute_node, child.attributes.items())
                yield from _iter_schema_nodes(child)

    if isinstance(root, TypedElement):
        root = cast(ElementProtocol, root.elem)

    # Set of already visited nodes, shared with the nested generator so that
    # each element is yielded at most once even in recursive models.
    nodes = {root}
    if with_root:
        yield root
        if with_attributes:
            yield from map(attribute_node, root.attributes.items())
    yield from _iter_schema_nodes(root)


class XMLSchemaContext(XPathSchemaContext):
    """XPath dynamic schema context for the *xmlschema* library."""
    _iter_nodes = staticmethod(iter_schema_nodes)


class XMLSchemaProxy(AbstractSchemaProxy):
    """XPath schema proxy for the *xmlschema* library."""
    _schema: SchemaType  # type: ignore[assignment]

    def __init__(self, schema: Optional[XMLSchemaProtocol] = None,
                 base_element: Optional[ElementProtocol] = None) -> None:
        """
        :param schema: the schema to bind; when `None` the *meta_schema* of \
        :class:`xmlschema.XMLSchema` is used, if available.
        :param base_element: an optional element of *schema* used as the \
        initial context item.
        :raises XMLSchemaTypeError: if *base_element* has no *schema* attribute.
        :raises XMLSchemaValueError: if *base_element* doesn't belong to *schema*.
        """
        if schema is None:
            from xmlschema import XMLSchema
            schema = getattr(XMLSchema, 'meta_schema', None)
        super().__init__(schema, base_element)

        if base_element is not None:
            # Keep the try body minimal: only the attribute access can signal
            # that the argument is not an XSD element.
            try:
                element_schema = base_element.schema
            except AttributeError:
                raise XMLSchemaTypeError("%r is not an XsdElement" % base_element) from None

            if element_schema is not schema:
                raise XMLSchemaValueError("%r is not an element of %r" % (base_element, schema))

    def bind_parser(self, parser: XPath2Parser) -> None:
        """
        Binds *parser* to this proxy, giving it a private copy of its class
        symbol table extended with constructors for the schema's atomic types.
        """
        parser.schema = self
        parser.symbol_table = dict(parser.__class__.symbol_table)

        # The constructor tokens are built once and cached on the schema,
        # under the schema lock.
        with self._schema.lock:
            if self._schema.xpath_tokens is None:
                self._schema.xpath_tokens = {
                    xsd_type.name: parser.schema_constructor(xsd_type.name)
                    for xsd_type in self.iter_atomic_types() if xsd_type.name
                }

        parser.symbol_table.update(self._schema.xpath_tokens)

    def get_context(self) -> XMLSchemaContext:
        """Returns a new XPath schema context rooted at the bound schema."""
        return XMLSchemaContext(
            root=self._schema,  # type: ignore[arg-type]
            namespaces=dict(self._schema.namespaces),
            item=self._base_element
        )

    def is_instance(self, obj: Any, type_qname: str) -> bool:
        """
        Returns `True` if *obj* can be encoded with the XSD type *type_qname*.

        :raises XMLSchemaNotBuiltError: if the XSD type is not built yet.
        """
        # FIXME: use elementpath.datatypes for checking atomic datatypes
        xsd_type = self._schema.maps.types[type_qname]
        if isinstance(xsd_type, tuple):
            from .validators import XMLSchemaNotBuiltError
            raise XMLSchemaNotBuiltError(xsd_type[1], f"XSD type {type_qname} is not built")

        try:
            xsd_type.encode(obj)
        except ValueError:
            return False
        else:
            return True

    def cast_as(self, obj: Any, type_qname: str) -> Any:
        """
        Decodes *obj* with the XSD type *type_qname* and returns the result.

        :raises XMLSchemaNotBuiltError: if the XSD type is not built yet.
        """
        xsd_type = self._schema.maps.types[type_qname]
        if isinstance(xsd_type, tuple):
            from .validators import XMLSchemaNotBuiltError
            raise XMLSchemaNotBuiltError(xsd_type[1], f"XSD type {type_qname} is not built")
        return xsd_type.decode(obj)

    def iter_atomic_types(self) -> Iterator[XsdTypeProtocol]:
        """
        Yields the built types of the schema maps that are not in the XSD
        namespace and have a *primitive_type* attribute.
        """
        for xsd_type in self._schema.maps.types.values():
            if not isinstance(xsd_type, tuple) and \
                    xsd_type.target_namespace != XSD_NAMESPACE and \
                    hasattr(xsd_type, 'primitive_type'):
                yield cast(XsdTypeProtocol, xsd_type)

    def get_primitive_type(self, xsd_type: XsdTypeProtocol) -> XsdTypeProtocol:
        """Returns the *root_type* of the given XSD type."""
        primitive_type = cast(BaseXsdType, xsd_type).root_type
        return cast(XsdTypeProtocol, primitive_type)


E = TypeVar('E', bound='ElementPathMixin[Any]')


class ElementPathMixin(Sequence[E]):
    """
    Mixin abstract class for enabling ElementTree and XPath 2.0 API on XSD components.

    :cvar text: the Element text, for compatibility with the ElementTree API.
    :cvar tail: the Element tail, for compatibility with the ElementTree API.
    """
    text: Optional[str] = None
    tail: Optional[str] = None
    name: Optional[str] = None
    attributes: Any = {}
    namespaces: Any = {}
    xpath_default_namespace = ''

    @abstractmethod
    def __iter__(self) -> Iterator[E]:
        raise NotImplementedError

    @overload
    def __getitem__(self, i: int) -> E: ...

    @overload
    def __getitem__(self, s: slice) -> Sequence[E]: ...

    def __getitem__(self, i: Union[int, slice]) -> Union[E, Sequence[E]]:
        """Returns a child (or a list of children for a slice) by position."""
        try:
            return list(self)[i]
        except IndexError:
            raise IndexError('child index out of range') from None

    def __reversed__(self) -> Iterator[E]:
        return reversed(list(self))

    def __len__(self) -> int:
        return len(list(self))

    @property
    def tag(self) -> str:
        """Alias of the *name* attribute. For compatibility with the ElementTree API."""
        return self.name or ''

    @property
    def attrib(self) -> Any:
        """Returns the Element attributes. For compatibility with the ElementTree API."""
        return self.attributes

    def get(self, key: str, default: Any = None) -> Any:
        """Gets an Element attribute. For compatibility with the ElementTree API."""
        return self.attributes.get(key, default)

    @property
    def xpath_proxy(self) -> XMLSchemaProxy:
        """Returns an XPath proxy instance bound with the schema."""
        raise NotImplementedError

    def _get_xpath_namespaces(self, namespaces: Optional[NamespacesType] = None) \
            -> Dict[str, str]:
        """
        Returns a dictionary with namespaces for XPath selection. The mapping
        provided by the caller is never modified.

        :param namespaces: an optional map from namespace prefix to namespace URI. \
        If this argument is not provided the schema's namespaces are used.
        """
        if namespaces is None:
            namespaces = {k: v for k, v in self.namespaces.items() if k}
            namespaces[''] = self.xpath_default_namespace
        elif '' not in namespaces:
            # Build a new mapping instead of adding the default namespace to
            # the caller's one (that was an unexpected side effect).
            namespaces = {**namespaces, '': self.xpath_default_namespace}

        xpath_namespaces: Dict[str, str] = XPath2Parser.DEFAULT_NAMESPACES.copy()
        xpath_namespaces.update(namespaces)
        return xpath_namespaces

    def is_matching(self, name: Optional[str], default_namespace: Optional[str] = None) -> bool:
        """
        Returns `True` if the component's name matches *name*. An unqualified
        *name* is expanded with *default_namespace*, when this is provided.
        """
        if not name or name[0] == '{' or not default_namespace:
            return self.name == name
        else:
            return self.name == '{%s}%s' % (default_namespace, name)

    def find(self, path: str, namespaces: Optional[NamespacesType] = None) -> Optional[E]:
        """
        Finds the first XSD subelement matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: the first matching XSD subelement or ``None`` if there is no match.
        """
        path = _REGEX_TAG_POSITION.sub('', path.strip())  # Strips tags positions from path
        namespaces = self._get_xpath_namespaces(namespaces)
        parser = XPath2Parser(namespaces, strict=False)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]

        return cast(Optional[E], next(parser.parse(path).select_results(context), None))

    def findall(self, path: str, namespaces: Optional[NamespacesType] = None) -> List[E]:
        """
        Finds all XSD subelements matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: a list containing all matching XSD subelements in document order, an empty \
        list is returned if there is no match.
        """
        path = _REGEX_TAG_POSITION.sub('', path.strip())  # Strips tags positions from path
        namespaces = self._get_xpath_namespaces(namespaces)
        parser = XPath2Parser(namespaces, strict=False)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]

        return cast(List[E], parser.parse(path).get_results(context))

    def iterfind(self, path: str, namespaces: Optional[NamespacesType] = None) -> Iterator[E]:
        """
        Creates an iterator for all XSD subelements matching the path.

        :param path: an XPath expression that considers the XSD component as the root element.
        :param namespaces: an optional mapping from namespace prefix to namespace URI.
        :return: an iterable yielding all matching XSD subelements in document order.
        """
        path = _REGEX_TAG_POSITION.sub('', path.strip())  # Strips tags positions from path
        namespaces = self._get_xpath_namespaces(namespaces)
        parser = XPath2Parser(namespaces, strict=False)
        context = XMLSchemaContext(self)  # type: ignore[arg-type]

        return cast(Iterator[E], parser.parse(path).select_results(context))

    def iter(self, tag: Optional[str] = None) -> Iterator[E]:
        """
        Creates an iterator for the XSD element and its subelements. If tag is not `None` or '*',
        only XSD elements whose name matches tag are returned from the iterator. Local elements
        are expanded without repetitions. Element references are not expanded because the
        global elements are not descendants of other elements.
        """
        def safe_iter(elem: Any) -> Iterator[E]:
            if tag is None or elem.is_matching(tag):
                yield elem
            for child in elem:
                if child.parent is None:
                    yield from safe_iter(child)
                elif getattr(child, 'ref', None) is not None:
                    # References are yielded but never descended into.
                    if tag is None or child.is_matching(tag):
                        yield child
                elif child not in local_elements:
                    local_elements.add(child)
                    yield from safe_iter(child)

        if tag == '*':
            tag = None
        local_elements: Set[E] = set()
        return safe_iter(self)

    def iterchildren(self, tag: Optional[str] = None) -> Iterator[E]:
        """
        Creates an iterator for the child elements of the XSD component. If *tag* is not `None`
        or '*', only XSD elements whose name matches tag are returned from the iterator.
        """
        if tag == '*':
            tag = None
        for child in self:
            if tag is None or child.is_matching(tag):
                yield child


class XPathElement(ElementPathMixin['XPathElement']):
    """An element node for making XPath operations on schema types."""
    name: str
    parent = None

    def __init__(self, name: str, xsd_type: BaseXsdType) -> None:
        self.name = name
        self.type = xsd_type
        self.attributes = getattr(xsd_type, 'attributes', {})

    def __iter__(self) -> Iterator['XPathElement']:
        # Types with simple content have no child elements.
        if not self.type.has_simple_content():
            yield from self.type.content.iter_elements()  # type: ignore[union-attr]

    @property
    def xpath_proxy(self) -> XMLSchemaProxy:
        """Returns an XPath proxy bound to the type's schema with this element as base."""
        return XMLSchemaProxy(
            cast(XMLSchemaProtocol, self.schema),
            cast(ElementProtocol, self)
        )

    @property
    def schema(self) -> SchemaType:
        """The schema of the element's type."""
        return self.type.schema

    @property
    def target_namespace(self) -> str:
        """The target namespace of the type's schema."""
        return self.type.schema.target_namespace

    @property
    def namespaces(self) -> NamespacesType:
        """The namespace map of the type's schema."""
        return self.type.schema.namespaces

    @property
    def local_name(self) -> str:
        """The element name without its namespace part."""
        return local_name(self.name)

    @property
    def qualified_name(self) -> str:
        """The element name qualified with the target namespace."""
        return get_qname(self.target_namespace, self.name)

    @property
    def prefixed_name(self) -> str:
        """The element name in prefixed form, using the schema's namespace map."""
        return get_prefixed_qname(self.name, self.namespaces)