1import os.path
2import typing
3from urllib.parse import urljoin, urlparse, urlunparse
4
5from lxml import etree
6from lxml.etree import Resolver, XMLParser, XMLSyntaxError, fromstring
7
8from zeep.exceptions import DTDForbidden, EntitiesForbidden, XMLSyntaxError
9from zeep.settings import Settings
10
11
class ImportResolver(Resolver):
    """lxml resolver that fetches http(s) documents through the transport."""

    def __init__(self, transport):
        # Object whose ``load(url)`` method is used to fetch remote content.
        self.transport = transport

    def resolve(self, url, pubid, context):
        """Resolve ``url`` via the transport when it is an http(s) URL.

        :param url: The location lxml asks to resolve
        :param pubid: The public id of the resource (unused here)
        :param context: The opaque lxml resolver context
        :returns: The result of ``resolve_string`` for http(s) URLs,
          ``None`` for every other scheme.

        """
        scheme = urlparse(url).scheme
        if scheme not in ("http", "https"):
            # Not handled by this resolver.
            return None

        document = self.transport.load(url)
        return self.resolve_string(document, context)
22
23
def parse_xml(content: str, transport, base_url=None, settings=None):
    """Parse an XML string and return the root Element.

    The parser strips comments, never resolves entities and runs in
    recovery mode when ``settings.strict`` is disabled. Imported documents
    with an http(s) location are fetched through ``transport``.

    :param content: The XML string
    :type content: str
    :param transport: The transport instance to load imported documents
    :type transport: zeep.transports.Transport
    :param base_url: The base url of the document, used to make relative
      lookups absolute.
    :type base_url: str
    :param settings: A zeep.settings.Settings object containing parse settings.
      A default ``Settings()`` is used when omitted.
    :type settings: zeep.settings.Settings
    :returns: The document root
    :rtype: lxml.etree._Element
    :raises DTDForbidden: When the document declares a doctype and
      ``settings.forbid_dtd`` is enabled.
    :raises EntitiesForbidden: When a DTD of the document declares an entity
      and ``settings.forbid_entities`` is enabled.
    :raises XMLSyntaxError: When lxml cannot parse the content, or when no
      root element could be obtained from it.

    """
    settings = settings or Settings()
    parser = XMLParser(
        remove_comments=True,
        resolve_entities=False,
        recover=not settings.strict,
        huge_tree=settings.xml_huge_tree,
    )
    parser.resolvers.add(ImportResolver(transport))
    try:
        elementtree = fromstring(content, parser=parser, base_url=base_url)
        if elementtree is None:
            # A recovering parser can hand back None instead of raising when
            # nothing usable is found; report that as a syntax error rather
            # than failing below with an AttributeError.
            raise XMLSyntaxError(
                "Invalid XML content received (no root element found)",
                content=content,
            )

        docinfo = elementtree.getroottree().docinfo
        if docinfo.doctype and settings.forbid_dtd:
            raise DTDForbidden(
                docinfo.doctype, docinfo.system_url, docinfo.public_id
            )
        if settings.forbid_entities:
            for dtd in docinfo.internalDTD, docinfo.externalDTD:
                if dtd is None:
                    continue
                # The first declared entity is enough to reject the document.
                for entity in dtd.iterentities():
                    raise EntitiesForbidden(entity.name, entity.content)

        return elementtree
    except etree.XMLSyntaxError as exc:
        # Translate the lxml error into the package's own exception while
        # keeping the original as the explicit cause.
        raise XMLSyntaxError(
            "Invalid XML content received (%s)" % exc.msg, content=content
        ) from exc
69
70
def load_external(
    url: typing.Union[str, typing.IO], transport, base_url=None, settings=None
):
    """Load an external XML document.

    :param url: Either a file-like object (anything with a ``read``
      attribute) whose content is parsed directly, or a location that is
      loaded through the transport.
    :param transport: The transport used to load the location and any
      documents imported while parsing.
    :param base_url: Optional base location; when given, a non file-like
      ``url`` is made absolute against it before loading. It is also passed
      on to the parser as the document base url.
    :param settings: A zeep.settings.Settings object containing parse settings.
    :type settings: zeep.settings.Settings
    :returns: The root element as returned by ``parse_xml``

    """
    settings = settings or Settings()
    if hasattr(url, "read"):
        # File-like input: consume it directly, no transport round-trip.
        content = url.read()
    else:
        location = absolute_location(url, base_url) if base_url else url
        content = transport.load(location)
    return parse_xml(content, transport, base_url, settings=settings)
89
90
async def load_external_async(
    url: typing.Union[str, typing.IO], transport, base_url=None, settings=None
):
    """Load an external XML document using an awaitable transport.

    :param url: Either a file-like object (anything with a ``read``
      attribute) whose content is read synchronously and parsed directly, or
      a location that is loaded by awaiting ``transport.load``.
    :param transport: The transport used to load the location and any
      documents imported while parsing.
    :param base_url: Optional base location; when given, a non file-like
      ``url`` is made absolute against it before loading. It is also passed
      on to the parser as the document base url.
    :param settings: A zeep.settings.Settings object containing parse settings.
    :type settings: zeep.settings.Settings
    :returns: The root element as returned by ``parse_xml``

    """
    settings = settings or Settings()
    if hasattr(url, "read"):
        # File-like input: consume it directly, no transport round-trip.
        content = url.read()
    else:
        location = absolute_location(url, base_url) if base_url else url
        content = await transport.load(location)
    return parse_xml(content, transport, base_url, settings=settings)
109
110
def normalize_location(settings, url, base_url):
    """Return a 'normalized' url for the given url.

    Without a ``base_url`` the url is returned untouched. Otherwise it is
    made absolute against ``base_url`` and, when ``settings.force_https`` is
    enabled, its scheme is replaced by ``https`` if it points to the same
    network location as ``base_url`` but uses a different scheme.

    :param settings: Object exposing a ``force_https`` attribute
    :param url: The (possibly relative) url to normalize
    :param base_url: The location the url is relative to
    :returns: The normalized url

    """
    if not base_url:
        return url

    url = absolute_location(url, base_url)
    if settings.force_https:
        base_parts = urlparse(base_url)
        target_parts = urlparse(url)
        same_host = base_parts.netloc == target_parts.netloc
        if same_host and base_parts.scheme != target_parts.scheme:
            # Keep everything but the scheme of the resolved url.
            url = urlunparse(("https",) + target_parts[1:])
    return url
130
131
def absolute_location(location, base):
    """Make a location absolute, using the given base where needed.

    A location that equals the base or already carries an http, https or
    file scheme is returned as is. A relative location is joined onto a
    base with one of those schemes as a URL; for any other base it is
    treated as a filesystem path relative to the directory of the base.

    :param location: The (relative) url
    :type location: str
    :param base: The base location
    :type base: str
    :returns: An absolute URL
    :rtype: str

    """
    url_schemes = ("http", "https", "file")

    if location == base or urlparse(location).scheme in url_schemes:
        return location

    if base and urlparse(base).scheme in url_schemes:
        return urljoin(base, location)

    # Filesystem handling: absolute paths, or no base to resolve against.
    if os.path.isabs(location) or not base:
        return location

    base_dir = os.path.dirname(base)
    return os.path.realpath(os.path.join(base_dir, location))
157
158
def is_relative_path(value):
    """Tell whether the given value is a relative path.

    A value with an http, https or file scheme, or one that is an absolute
    filesystem path, is not relative.

    :param value: The value
    :type value: str
    :returns: True when the value is relative, False when it is absolute.
    :rtype: boolean

    """
    has_url_scheme = urlparse(value).scheme in ("http", "https", "file")
    return not (has_url_scheme or os.path.isabs(value))
172