1#!/usr/bin/python
2# -*- coding: utf-8 -*-
3# This program is free software; you can redistribute it and/or modify
4# it under the terms of the GNU Lesser General Public License as published by the
5# Free Software Foundation; either version 3, or (at your option) any later
6# version.
7#
8# This program is distributed in the hope that it will be useful, but
9# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTIBILITY
10# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
11# for more details.
12
13"""Simple XML manipulation"""
14
15
16from __future__ import unicode_literals
17import sys
if sys.version_info[0] >= 3:
    # Python 3 has a single text type: alias the Python 2 names
    # (basestring/unicode) that the rest of this module relies on.
    basestring = str
    unicode = str
21
22import logging
23import re
24import time
25import xml.dom.minidom
26
27from . import __author__, __copyright__, __license__, __version__
28
29# Utility functions used for marshalling, moved aside for readability
30from .helpers import TYPE_MAP, TYPE_MARSHAL_FN, TYPE_UNMARSHAL_FN, \
31                     REVERSE_TYPE_MAP, Struct, Date, Decimal
32
# Module-level logger, named after this module (used for parse errors and
# unknown namespace alias warnings).
log = logging.getLogger(__name__)
34
35
class SimpleXMLElement(object):
    """Simple XML manipulation (simil PHP)

    Wraps one or more ``xml.dom.minidom`` elements of a shared document.
    Child tags are reached as attributes (``node.child``) or by calling the
    node (``node('child')``); XML attributes are reached with item access
    (``node['attr']``).  Most operations act on the first wrapped element.
    """

    def __init__(self, text=None, elements=None, document=None,
                 namespace=None, prefix=None, namespaces_map=None, jetty=False):
        """
        :param text: XML source to parse; when given, the new object wraps the
          parsed document element and ``elements``/``document`` are ignored.
        :param elements: DOM elements to wrap (used when ``text`` is None).
        :param document: DOM document owning ``elements``.
        :param namespace: namespace URI used when adding and searching children
          (``False`` disables namespaced children in ``add_child``).
        :param prefix: namespace prefix given to children created in ``namespace``.
        :param namespaces_map: How to map our namespace prefix to that given by the client;
          {prefix: received_prefix}.  Defaults to an empty mapping.
        :param jetty: enable Jetty-style list support in ``unmarshall``.
        """
        # None instead of a literal {} default: a literal default would be the
        # very same dict object shared by every instance.
        if namespaces_map is None:
            namespaces_map = {}
        self.__namespaces_map = namespaces_map
        # {'external': 'ext', 'model': 'mod'} -> ^(external|model):.*$ finds our
        # prefixes in node names (i.e. <model:code>1</model:code>) so that
        # _update_ns can later change them to <mod:code>1</mod:code>.
        # Keys are escaped so they match literally; with an empty map there is
        # nothing to rename, so no expression is built at all.
        if namespaces_map:
            alternatives = "|".join(re.escape(key) for key in namespaces_map)
            self.__ns_rx = re.compile(r"^(%s):.*$" % alternatives)
        else:
            self.__ns_rx = None
        self.__ns = namespace
        self.__prefix = prefix
        self.__jetty = jetty  # special list support

        if text is not None:
            try:
                self.__document = xml.dom.minidom.parseString(text)
            except Exception:
                # log the offending payload, then let the caller see the error
                log.error(text)
                raise
            self.__elements = [self.__document.documentElement]
        else:
            self.__elements = elements
            self.__document = document

    def _wrap(self, elements):
        """Return a new SimpleXMLElement over ``elements`` that shares this
        node's document, namespace, prefix, jetty flag and namespaces map."""
        return SimpleXMLElement(
            elements=elements,
            document=self.__document,
            namespace=self.__ns,
            prefix=self.__prefix,
            jetty=self.__jetty,
            namespaces_map=self.__namespaces_map
        )

    def add_child(self, name, text=None, ns=True):
        """Adding a child tag to a node

        :param name: tag name of the new child.
        :param text: text content (a string, or a ``CDATASection`` whose data
          is copied into a new CDATA node); None leaves the tag empty.
        :param ns: False skips namespaces; a non-empty string is set as the
          child's ``xmlns`` attribute; any other true value creates the child
          in this node's namespace (with its prefix, if any).  Children are
          never namespaced when this node's namespace is False.
        :return: a SimpleXMLElement wrapping the new child.
        """
        if not ns or self.__ns is False:
            element = self.__document.createElement(name)
        elif isinstance(ns, basestring):
            # explicit namespace uri, declared as a default-namespace attribute
            element = self.__document.createElement(name)
            element.setAttribute("xmlns", ns)
        elif self.__prefix:
            element = self.__document.createElementNS(
                self.__ns, "%s:%s" % (self.__prefix, name))
        else:
            element = self.__document.createElementNS(self.__ns, name)
        # don't append null tags!
        if text is not None:
            if isinstance(text, xml.dom.minidom.CDATASection):
                element.appendChild(self.__document.createCDATASection(text.data))
            else:
                element.appendChild(self.__document.createTextNode(text))
        self._element.appendChild(element)
        return self._wrap([element])

    def __setattr__(self, tag, text):
        """Add text child tag node (short form)

        Names starting with ``_`` are stored as regular Python attributes.
        """
        if tag.startswith("_"):
            object.__setattr__(self, tag, text)
        else:
            self.add_child(tag, text)

    def __delattr__(self, tag):
        """Remove the direct children named ``tag`` (non recursive!)

        A child matches when its tag name or its local name equals ``tag``.
        Names starting with ``_`` are regular Python attributes, mirroring
        ``__setattr__``.
        """
        if tag.startswith("_"):
            object.__delattr__(self, tag)
            return
        doomed = [node for node in self._element.childNodes
                  if node.nodeType == node.ELEMENT_NODE and
                  tag in (node.tagName, node.localName)]
        for node in doomed:
            self._element.removeChild(node)

    def add_comment(self, data):
        """Add an xml comment to this child"""
        comment = self.__document.createComment(data)
        self._element.appendChild(comment)

    def as_xml(self, filename=None, pretty=False):
        """Return the XML representation of the document

        The whole document is serialized (not only this node), encoded as
        UTF-8.  ``filename`` is accepted for compatibility but not used.
        """
        if pretty:
            return self.__document.toprettyxml(encoding='UTF-8')
        return self.__document.toxml('UTF-8')

    if sys.version_info[0] >= 3:
        def __repr__(self):
            """Return the XML representation of this tag"""
            return self._element.toxml()
    else:
        def __repr__(self):
            """Return the XML representation of this tag"""
            # NOTE: do not use self.as_xml('UTF-8') as it returns the whole xml doc
            return self._element.toxml('UTF-8')

    def get_name(self):
        """Return the tag name of this node (including any prefix)"""
        return self._element.tagName

    def get_local_name(self):
        """Return the tag local name (the name without its prefix) of this node"""
        return self._element.localName

    def get_prefix(self):
        """Return the namespace prefix of this node"""
        return self._element.prefix

    def get_namespace_uri(self, ns):
        """Return the namespace uri for a prefix

        Walks up from this node through its ancestors looking for an
        ``xmlns:<ns>`` declaration; returns None when none is found.
        """
        element = self._element
        while element is not None and element.attributes is not None:
            try:
                return element.attributes['xmlns:%s' % ns].value
            except KeyError:
                element = element.parentNode
        return None

    def attributes(self):
        """Return the (dict-like) attribute map of this tag"""
        #TODO: use slice syntax [:]?
        return self._element.attributes

    def __getitem__(self, item):
        """Return xml tag attribute value or a slice of attributes (iter)

        - string: the attribute value, or None if the attribute is missing;
        - slice: a list of (name, value) attribute pairs;
        - anything else: the wrapped element at that index.
        """
        if isinstance(item, basestring):
            if self._element.hasAttribute(item):
                return self._element.attributes[item].value
            return None
        if isinstance(item, slice):
            # return a list with name:values
            return list(self._element.attributes.items())[item]
        # return element by index (position)
        return self._wrap([self.__elements[item]])

    def add_attribute(self, name, value):
        """Set an attribute value from a string"""
        self._element.setAttribute(name, value)

    def __setitem__(self, item, value):
        """Set an attribute value (or several, from a dict, with a slice key)"""
        if isinstance(item, basestring):
            self.add_attribute(item, value)
        elif isinstance(item, slice):
            # set multiple attributes at once
            for k, v in value.items():
                self.add_attribute(k, v)

    def __delitem__(self, item):
        "Remove an attribute"
        self._element.removeAttribute(item)

    def __call__(self, tag=None, ns=None, children=False, root=False,
                 error=True, ):
        """Search (even in child nodes) and return a child tag by name

        :param tag: tag name searched among descendants, or an int index into
          the wrapped elements; None iterates over the wrapped elements.
        :param ns: namespace uri (or tuple/list of uris) tried first.
        :param children: return ``children()`` instead of searching.
        :param root: return the document element instead of searching.
        :param error: when nothing is found, raise AttributeError (True) or
          return None (False).
        """
        try:
            if root:
                # return entire document
                return self._wrap([self.__document.documentElement])
            if tag is None:
                # if no name given, iterate over siblings (same level)
                return self.__iter__()
            if children:
                # future: filter children? by ns?
                return self.children()
            elements = None
            if isinstance(tag, int):
                # return tag by index
                elements = [self.__elements[tag]]
            if ns and not elements:
                ns_uris = ns if isinstance(ns, (tuple, list)) else (ns, )
                for ns_uri in ns_uris:
                    elements = self._element.getElementsByTagNameNS(ns_uri, tag)
                    if elements:
                        break
            if self.__ns and not elements:
                elements = self._element.getElementsByTagNameNS(self.__ns, tag)
            if not elements:
                elements = self._element.getElementsByTagName(tag)
            if not elements:
                if error:
                    raise AttributeError("No elements found")
                return None
            return self._wrap(elements)
        except AttributeError as e:
            raise AttributeError("Tag not found: %s (%s)" % (tag, e))

    def __getattr__(self, tag):
        """Shortcut for __call__ (child tag lookup by attribute access)"""
        # Only reached when normal attribute lookup fails.  Python-internal
        # names (dunders probed by copy/pickle, or this class's own private
        # attributes when __init__ has not run) are not XML tags: searching
        # for them would recurse through _element forever.
        if tag.startswith("_SimpleXMLElement__") or (
                tag.startswith("__") and tag.endswith("__")):
            raise AttributeError(tag)
        return self.__call__(tag)

    def __iter__(self):
        """Iterate over xml tags at this level"""
        for element in self.__elements:
            yield self._wrap([element])

    def __dir__(self):
        """List xml children tags names"""
        # only element nodes have a tagName (comments, CDATA, ... do not)
        return [node.tagName for node
                in self._element.childNodes
                if node.nodeType == node.ELEMENT_NODE]

    def children(self):
        """Return xml children tags element (None when there are none)"""
        elements = [node for node in self._element.childNodes
                    if node.nodeType == node.ELEMENT_NODE]
        if not elements:
            return None
        return self._wrap(elements)

    def __len__(self):
        """Return element count"""
        return len(self.__elements)

    def __contains__(self, item):
        """Search for a tag name in this element or child nodes"""
        return self._element.getElementsByTagName(item)

    def __unicode__(self):
        """Returns the unicode text nodes of the current element"""
        return ''.join(
            node.data for node in self._element.childNodes
            if node.nodeType in (node.TEXT_NODE, node.CDATA_SECTION_NODE))

    if sys.version_info[0] >= 3:
        __str__ = __unicode__
    else:
        def __str__(self):
            return self.__unicode__().encode('utf-8')

    def __int__(self):
        """Returns the integer value of the current element"""
        return int(self.__str__())

    def __float__(self):
        """Returns the float value of the current element

        :raises IndexError: (carrying the element's XML) if the text is not a
          valid float.
        """
        try:
            return float(self.__str__())
        except (ValueError, TypeError):
            raise IndexError(self._element.toxml())

    # the first wrapped DOM element; most operations act on it
    _element = property(lambda self: self.__elements[0])

    def unmarshall(self, types, strict=True):
        """Convert to python values the current serialized xml element

        :param types: dict of {tag name: conversion function} (nested dicts,
          lists and tuples describe structures); a non-dict value is used as
          the conversion for every tag.
        :param strict: when False, tags not described by ``types`` (nor by an
          ``xsi:type`` attribute) fall back to ``str`` instead of raising.
        :return: dict of {tag local name: python value}.
        :raises TypeError: unknown tag in strict mode.
        :raises ValueError: a conversion function rejected a tag's text.

        example: types={'p': {'a': int,'b': int}, 'c': [{'d':str}]}
          expected xml: <p><a>1</a><b>2</b></p><c><d>hola</d><d>chau</d>
          returned value: {'p': {'a':1,'b':2}, `'c':[{'d':'hola'},{'d':'chau'}]}
        """
        d = {}
        for node in self():
            name = str(node.get_local_name())
            # handle multirefs: href="#id0" points to a <multiRef id="id0">
            # element elsewhere in the document, which is unmarshalled instead
            if 'href' in node.attributes().keys():
                href = node['href'][1:]
                for ref_node in self(root=True)("multiRef"):
                    if ref_node['id'] == href:
                        node = ref_node
                        break

            try:
                if isinstance(types, dict):
                    fn = types[name]
                    # custom array only in the response (not defined in the WSDL):
                    # <results soapenc:arrayType="xsd:string[199]>
                    if (any('arrayType' in key for key, _ in node[:])
                            and not isinstance(fn, list)):
                        fn = [fn]
                else:
                    fn = types
            except KeyError:
                xmlns = node['xmlns'] or node.get_namespace_uri(node.get_prefix())
                if 'xsi:type' in node.attributes().keys():
                    # keep only the local part of the QName (xsd:int -> int);
                    # [-1] also copes with an unprefixed type name
                    xsd_type = node['xsi:type'].split(":")[-1]
                    try:
                        # get fn type from SOAP-ENC:arrayType="xsd:string[28]"
                        if xsd_type == 'Array':
                            array_type = [k for k, v in node[:] if 'arrayType' in k][0]
                            xsd_type = node[array_type].split(":")[-1]
                            if "[" in xsd_type:
                                xsd_type = xsd_type[:xsd_type.index("[")]
                            fn = [REVERSE_TYPE_MAP[xsd_type]]
                        else:
                            fn = REVERSE_TYPE_MAP[xsd_type]
                    except Exception:
                        fn = None  # ignore multirefs!
                elif xmlns == "http://www.w3.org/2001/XMLSchema":
                    # self-defined schema, return the SimpleXMLElement
                    # TODO: parse to python types if <s:element ref="s:schema"/>
                    fn = None
                elif None in types:
                    # <s:any/>, return the SimpleXMLElement
                    # TODO: check position of None if inside <s:sequence>
                    fn = None
                elif strict:
                    raise TypeError("Tag: %s invalid (type not found)" % (name,))
                else:
                    # if not strict, use default type conversion
                    fn = str

            if isinstance(fn, list):
                # append to existing list (if any) - unnested dict arrays -
                value = d.setdefault(name, [])
                children = node.children()
                if fn and not isinstance(fn[0], dict):
                    # simple arrays []
                    for child in (children or []):
                        tmp_dict = child.unmarshall(fn[0], strict)
                        value.extend(tmp_dict.values())
                elif self.__jetty and len(fn[0]) > 1:
                    # Jetty array style support [{k, v}]
                    for parent in node:
                        tmp_dict = {}    # unmarshall each value & mix
                        for child in (parent.children() or []):
                            tmp_dict.update(child.unmarshall(fn[0], strict))
                        value.append(tmp_dict)
                else:  # .Net / Java
                    for child in (children or []):
                        value.append(child.unmarshall(fn[0], strict))

            elif isinstance(fn, tuple):
                value = []
                _d = {}
                children = node.children()
                as_dict = len(fn) == 1 and isinstance(fn[0], dict)

                for child in (children() if children else []):
                    if as_dict:
                        _d.update(child.unmarshall(fn[0], strict))  # Merging pairs
                    else:
                        value.append(child.unmarshall(fn[0], strict))
                if as_dict:
                    value.append(_d)

                # extend (rather than replace) a tuple already seen for this tag
                if name in d:
                    merged = list(d[name])
                    merged.extend(value)
                    value = tuple(merged)
                else:
                    value = tuple(value)

            elif isinstance(fn, dict):
                children = node.children()
                value = children and children.unmarshall(fn, strict)

            elif fn is None:  # xsd:anyType not unmarshalled
                value = node

            else:
                text = unicode(node)
                if text:
                    try:
                        # get special deserialization function (if any)
                        fn = TYPE_UNMARSHAL_FN.get(fn, fn)
                        if fn == str:
                            # always return an unicode object:
                            # (avoid encoding errors in py<3!)
                            value = text
                        else:
                            value = fn(text)
                    except (ValueError, TypeError) as e:
                        raise ValueError("Tag: %s: %s" % (name, e))
                else:
                    # empty text content
                    value = None
            d[name] = value
        return d

    def _update_ns(self, name):
        """Replace the defined namespace alias with those used by the client.

        Only the leading ``alias:`` part of ``name`` is rewritten, e.g. with
        {'model': 'mod'}: 'model:modelName' -> 'mod:modelName'.
        """
        if self.__ns_rx is None:
            return name
        match = self.__ns_rx.search(name)
        if match:
            pref = match.group(1)
            try:
                # swap just the prefix, not every occurrence of the alias text
                name = self.__namespaces_map[pref] + name[len(pref):]
            except KeyError:
                log.warning('Unknown namespace alias %s', name)
        return name

    def marshall(self, name, value, add_child=True, add_comments=False,
                 ns=False, add_children_ns=True):
        """Analyze python value and add the serialized XML element using tag name

        :param name: tag name (namespace aliases are mapped by ``_update_ns``).
        :param value: dict / tuple of pairs / list / string / CDATA / None /
          a type from TYPE_MAP (emits a placeholder comment) / other scalar.
        :param add_child: for dicts and tuples, False writes the members
          directly into this node instead of a new ``name`` child.
        :param add_comments: add explanatory comments (e.g. for arrays).
        :param ns: namespace spec for the new element (see ``add_child``).
        :param add_children_ns: False creates nested members without namespace.
        """
        # Change node name to that used by a client
        name = self._update_ns(name)

        if isinstance(value, dict):  # serialize dict (<key>value</key>)
            # for the first parent node, use the document target namespace
            # (ns==True) or use the namespace string uri if passed (elements)
            child = self.add_child(name, ns=ns) if add_child else self
            for k, v in value.items():
                if not add_children_ns:
                    child_ns = False
                elif hasattr(value, 'namespaces'):
                    # for children, use the wsdl element target namespace:
                    child_ns = value.namespaces.get(k)
                else:
                    # simple type
                    child_ns = None
                child.marshall(k, v, add_comments=add_comments, ns=child_ns)
        elif isinstance(value, tuple):  # serialize tuple (<key>value</key>)
            child = self.add_child(name, ns=ns) if add_child else self
            child_ns = ns if add_children_ns else False
            # write the pairs into the element just created (or into self),
            # not into whatever descendant named `name` is found first
            for k, v in value:
                child.marshall(k, v, add_comments=add_comments, ns=child_ns)
        elif isinstance(value, list):  # serialize lists
            child = self.add_child(name, ns=ns)
            child_ns = ns if add_children_ns else False
            if add_comments:
                child.add_comment("Repetitive array of:")
            for item in value:
                child.marshall(name, item, False, add_comments=add_comments, ns=child_ns)
        elif isinstance(value, (xml.dom.minidom.CDATASection, basestring)):  # do not convert strings or unicodes
            self.add_child(name, value, ns=ns)
        elif value is None:  # sent a empty tag?
            self.add_child(name, ns=ns)
        elif value in TYPE_MAP.keys():
            # add commented placeholders for simple types (for examples/help only)
            child = self.add_child(name, ns=ns)
            child.add_comment(TYPE_MAP[value])
        else:  # the rest of object types are converted to string
            # get special serialization function (if any)
            fn = TYPE_MARSHAL_FN.get(type(value), str)
            self.add_child(name, fn(value), ns=ns)

    def import_node(self, other):
        """Append a deep copy of ``other``'s first element to this node."""
        x = self.__document.importNode(other._element, True)  # deep copy
        self._element.appendChild(x)

    def write_c14n(self, output=None, exclusive=True):
        "Generate the canonical version of the XML node"
        from . import c14n
        canonical = c14n.Canonicalize(self._element, output,
                                      unsuppressedPrefixes=[] if exclusive else None)
        return canonical
522