1"""\
2A library of useful helper classes to the SAX classes, for the
3convenience of application and driver writers.
4"""
5
6import os, urllib.parse, urllib.request
7import io
8import codecs
9from . import handler
10from . import xmlreader
11
def __dict_replace(s, d):
    """Apply every key -> value substitution in *d* to the string *s*.

    Substitutions are applied one after another in the dictionary's
    iteration order, so a later replacement also sees the text produced
    by earlier ones.  The resulting string is returned.
    """
    result = s
    for needle, replacement in d.items():
        result = result.replace(needle, replacement)
    return result
17
def escape(data, entities={}):
    """Return *data* with &, < and > replaced by XML entity references.

    Additional replacements can be requested through the optional
    *entities* dictionary.  Its keys and values must all be strings;
    every occurrence of a key is replaced with the matching value after
    the three standard characters have been handled.
    """
    # The ampersand has to go first, otherwise the "&" introduced by
    # "&gt;" and "&lt;" would be escaped a second time.
    escaped = (data.replace("&", "&amp;")
                   .replace(">", "&gt;")
                   .replace("<", "&lt;"))
    return __dict_replace(escaped, entities) if entities else escaped
33
def unescape(data, entities={}):
    """Return *data* with &amp;, &lt; and &gt; turned back into characters.

    Additional replacements can be requested through the optional
    *entities* dictionary.  Its keys and values must all be strings;
    every occurrence of a key is replaced with the matching value.
    """
    text = data.replace("&lt;", "<").replace("&gt;", ">")
    if entities:
        text = __dict_replace(text, entities)
    # The ampersand has to go last, otherwise input such as "&amp;lt;"
    # would be unescaped twice and end up as "<".
    return text.replace("&amp;", "&")
47
def quoteattr(data, entities={}):
    """Escape *data* and wrap it in quotes for use as an attribute value.

    &, < and > are escaped, newline, carriage return and tab are written
    as numeric character references, and the result is surrounded by
    quote characters.  Double quotes are preferred; single quotes are
    used when the value contains a double quote but no single quote, and
    when it contains both, the double quotes are written as &quot;.

    Additional replacements can be requested through the optional
    *entities* dictionary.  Its keys and values must all be strings;
    every occurrence of a key is replaced with the matching value.
    """
    whitespace_refs = {'\n': '&#10;', '\r': '&#13;', '\t': '&#9;'}
    # The whitespace references come last so they win over caller entries.
    data = escape(data, {**entities, **whitespace_refs})

    if '"' not in data:
        return '"%s"' % data
    if "'" not in data:
        return "'%s'" % data
    # Both quote styles occur: keep double quotes and escape the inner ones.
    return '"%s"' % data.replace('"', "&quot;")
69
70
def _gettextwriter(out, encoding):
    """Return an object with a text-mode write() for the target *out*.

    None selects sys.stdout.  Text streams and codecs stream writers are
    returned unchanged.  Anything else is treated as a binary sink and
    wrapped in an io.TextIOWrapper that encodes with *encoding*, writes
    unencodable characters as XML character references, does not
    translate newlines and passes every write straight through.
    """
    if out is None:
        import sys
        return sys.stdout

    # Objects that already accept str are used as they are.
    text_writers = (io.TextIOBase, codecs.StreamWriter,
                    codecs.StreamReaderWriter)
    if isinstance(out, text_writers):
        return out

    if not isinstance(out, io.RawIOBase):
        # The object is not a raw stream; it may not be part of the
        # IOBase hierarchy at all and is only required to have a
        # write() method.
        binary = io.BufferedIOBase()
        binary.writable = lambda: True
        binary.write = out.write
        try:
            # TextIOWrapper uses these methods to determine
            # if a BOM (for UTF-16, etc) should be added
            binary.seekable = out.seekable
            binary.tell = out.tell
        except AttributeError:
            pass
    else:
        # Proxy the raw stream so that the original file stays open
        # when the TextIOWrapper is destroyed.
        class _wrapper:
            __class__ = out.__class__
            def __getattr__(self, name):
                return getattr(out, name)
        binary = _wrapper()
        binary.close = lambda: None

    return io.TextIOWrapper(binary,
                            encoding=encoding,
                            errors='xmlcharrefreplace',
                            newline='\n',
                            write_through=True)
111
class XMLGenerator(handler.ContentHandler):
    """ContentHandler that writes the SAX events it receives as XML text.

    The text is written to the writer obtained from _gettextwriter() for
    the *out* and *encoding* arguments given to the constructor.
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        """Set up the generator.

        out -- output target handed to _gettextwriter().
        encoding -- encoding handed to _gettextwriter(); also the name
            written in the XML declaration and used to decode non-str
            character data.
        short_empty_elements -- when true, an element that receives no
            content before its end event is written as <name/> instead
            of <name></name>.
        """
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}] # contains uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs announced by startPrefixMapping() that have
        # not been written as xmlns attributes yet
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing '>'
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair

        Raises KeyError when the namespace URI is non-empty, is not the
        XML namespace and has no entry in the current context.
        """
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace.  It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self,endElement=False):
        """Write the '>' of a start tag that was left open, if any.

        The endElement argument is accepted but not used.
        """
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        """Write the XML declaration, naming the configured encoding."""
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                        self._encoding)

    def endDocument(self):
        """Flush the underlying writer."""
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        """Record the uri -> prefix binding and queue its declaration.

        A copy of the context as it was before the binding is pushed so
        that endPrefixMapping() can restore it.
        """
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        """Restore the context saved by the matching startPrefixMapping()."""
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        """Write the start tag of *name* with the attributes in *attrs*.

        Attribute values are escaped and quoted with quoteattr().
        """
        self._finish_pending_start_element()
        self._write('<' + name)
        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        if self._short_empty_elements:
            # Leave the tag open; it is finished as '>' or '/>' later.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        """Write the end tag of *name*, or '/>' if the start tag is still open."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        """Write the start tag for the (uri, localname) pair *name*.

        Prefix mappings queued since the previous start tag are written
        as xmlns attributes first, followed by the attributes in *attrs*,
        whose keys are (uri, localname) pairs as well.  The *qname*
        argument is not used; names are built with _qname().
        """
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))

        for prefix, uri in self._undeclared_ns_maps:
            # The URI ends up inside an attribute value, so it has to be
            # escaped and quoted like one: a raw '&', '<' or '"' in it
            # would make the output ill-formed.
            if prefix:
                self._write(' xmlns:%s=%s' % (prefix, quoteattr(uri)))
            else:
                self._write(' xmlns=%s' % quoteattr(uri))
        self._undeclared_ns_maps = []

        for (attr_name, value) in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        if self._short_empty_elements:
            # Leave the tag open; it is finished as '>' or '/>' later.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        """Write the end tag for the (uri, localname) pair *name*,
        or '/>' if the start tag is still open."""
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        """Write *content* with &, < and > escaped; empty content is ignored.

        Content that is not a str is decoded with the configured encoding.
        """
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        """Write *content* without escaping; empty content is ignored.

        Content that is not a str is decoded with the configured encoding.
        """
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        """Write a processing instruction of the form <?target data?>."""
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))
225
226
class XMLFilterBase(xmlreader.XMLReader):
    """Base class for filters placed between an XMLReader and the
    client application's event handlers.

    Every request is passed up to the parent reader and every event is
    passed on to the registered handlers without modification.
    Subclasses override individual methods to change the event stream
    or the configuration requests as they pass through.
    """

    def __init__(self, parent=None):
        """Create a filter whose requests go to *parent* (may be None)."""
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # ErrorHandler methods

    def error(self, exception):
        """Pass *exception* to the error handler's error()."""
        self._err_handler.error(exception)

    def fatalError(self, exception):
        """Pass *exception* to the error handler's fatalError()."""
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        """Pass *exception* to the error handler's warning()."""
        self._err_handler.warning(exception)

    # ContentHandler methods

    def setDocumentLocator(self, locator):
        """Pass the locator on to the content handler."""
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        """Pass the start-of-document event on to the content handler."""
        self._cont_handler.startDocument()

    def endDocument(self):
        """Pass the end-of-document event on to the content handler."""
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        """Pass the start of a prefix mapping on to the content handler."""
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        """Pass the end of a prefix mapping on to the content handler."""
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        """Pass the start-of-element event on to the content handler."""
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        """Pass the end-of-element event on to the content handler."""
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        """Pass the namespace-aware start-of-element event on to the
        content handler."""
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        """Pass the namespace-aware end-of-element event on to the
        content handler."""
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        """Pass character data on to the content handler."""
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        """Pass ignorable whitespace on to the content handler."""
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        """Pass a processing instruction on to the content handler."""
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        """Pass a skipped-entity notification on to the content handler."""
        self._cont_handler.skippedEntity(name)

    # DTDHandler methods

    def notationDecl(self, name, publicId, systemId):
        """Pass a notation declaration on to the DTD handler."""
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        """Pass an unparsed entity declaration on to the DTD handler."""
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # EntityResolver methods

    def resolveEntity(self, publicId, systemId):
        """Return whatever the entity resolver returns for this entity."""
        return self._ent_handler.resolveEntity(publicId, systemId)

    # XMLReader methods

    def parse(self, source):
        """Register this filter as every handler of the parent reader,
        then have the parent parse *source*."""
        self._parent.setContentHandler(self)
        self._parent.setErrorHandler(self)
        self._parent.setEntityResolver(self)
        self._parent.setDTDHandler(self)
        self._parent.parse(source)

    def setLocale(self, locale):
        """Pass the locale setting up to the parent reader."""
        self._parent.setLocale(locale)

    def getFeature(self, name):
        """Return the parent reader's state for feature *name*."""
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        """Set feature *name* to *state* on the parent reader."""
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        """Return the parent reader's value for property *name*."""
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        """Set property *name* to *value* on the parent reader."""
        self._parent.setProperty(name, value)

    # XMLFilter methods

    def getParent(self):
        """Return the parent reader."""
        return self._parent

    def setParent(self, parent):
        """Replace the parent reader with *parent*."""
        self._parent = parent
335
336# --- Utility functions
337
def prepare_input_source(source, base=""):
    """Turn *source* into an InputSource that is ready for reading.

    *source* may be an InputSource, a string or path-like object used as
    the system identifier, or an object with a read() method.  When the
    resulting InputSource has neither a character stream nor a byte
    stream, its system identifier is resolved against the optional
    *base* URL and opened: as a local file if one exists at that
    location, otherwise as a URL.  The InputSource is returned.
    """
    if isinstance(source, os.PathLike):
        source = os.fspath(source)

    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        stream = source
        source = xmlreader.InputSource()
        # A zero-length read tells text streams apart from byte streams.
        if isinstance(stream.read(0), str):
            source.setCharacterStream(stream)
        else:
            source.setByteStream(stream)
        if hasattr(stream, "name") and isinstance(stream.name, str):
            source.setSystemId(stream.name)

    # Nothing left to do when a stream is already attached.
    if (source.getCharacterStream() is not None
            or source.getByteStream() is not None):
        return source

    system_id = source.getSystemId()
    base_dir = os.path.dirname(os.path.normpath(base))
    local_path = os.path.join(base_dir, system_id)
    if os.path.isfile(local_path):
        source.setSystemId(local_path)
        stream = open(local_path, "rb")
    else:
        source.setSystemId(urllib.parse.urljoin(base, system_id))
        stream = urllib.request.urlopen(source.getSystemId())
    source.setByteStream(stream)

    return source
370