1# This program is free software; you can redistribute it and/or modify
2# it under the terms of the (LGPL) GNU Lesser General Public License as
3# published by the Free Software Foundation; either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
9# GNU Library Lesser General Public License for more details at
10# ( http://www.gnu.org/licenses/lgpl.html ).
11#
12# You should have received a copy of the GNU Lesser General Public License
13# along with this program; if not, write to the Free Software
14# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
15# written by: Jeff Ortel ( jortel@redhat.com )
16
17"""
18The I{sxbasic} module provides classes that represent
19I{basic} schema objects.
20"""
21
22from suds import *
23from suds.xsd import *
24from suds.xsd.sxbase import *
25from suds.xsd.query import *
26from suds.sax import Namespace
27from suds.transport import TransportError
28from suds.reader import DocumentReader
29from urllib.parse import urljoin
30
31from logging import getLogger
32log = getLogger(__name__)
33
34
class RestrictionMatcher:
    """
    Matcher for use with L{NodeFinder}; accepts only L{Restriction} nodes.
    """

    def match(self, n):
        """
        Check whether a node is a restriction.
        @param n: The candidate node.
        @return: True when I{n} is a L{Restriction} instance.
        @rtype: boolean
        """
        is_restriction = isinstance(n, Restriction)
        return is_restriction
41
42
class TypedContent(Content):
    """
    Represents any I{typed} content.
    @ivar resolved_cache: Results of L{resolve}, keyed by the I{nobuiltin}
        flag the result was computed for.
    @type resolved_cache: dict
    """

    def __init__(self, *args, **kwargs):
        Content.__init__(self, *args, **kwargs)
        self.resolved_cache = {}

    def resolve(self, nobuiltin=False):
        """
        Resolve the node's type reference and return the referenced type node.

        The result is self when the node carries no type reference (see
        L{qref}); otherwise it is the node found by the type query. Results
        are cached per I{nobuiltin} value.
        @param nobuiltin: Flag indicating whether resolving to XSD builtin
            types should not be allowed.
        @return: The resolved (true) type.
        @rtype: L{SchemaObject}
        """
        target = self.resolved_cache.get(nobuiltin)
        if target is None:
            # Cache miss: do the real lookup once and remember the outcome.
            target = self.__resolve_type(nobuiltin)
            self.resolved_cache[nobuiltin] = target
        return target

    def __resolve_type(self, nobuiltin=False):
        """
        Uncached worker behind L{resolve}.
        @param nobuiltin: Flag indicating whether resolving to XSD builtin
            types should not be allowed.
        @return: The resolved (true) type.
        @rtype: L{SchemaObject}
        @raise TypeNotFound: When the type query finds nothing.

        Implementation note:
          A single lookup is sufficient, no recursion is needed: a node can
        reference an external type node but there is no way using WSDL to
        then make that type node actually be a reference to a different
        type node.
        """
        ref = self.qref()
        if ref is None:
            return self
        query = TypeQuery(ref)
        query.history = [self]
        log.debug('%s, resolving: %s\n using:%s', self.id, ref, query)
        target = query.execute(self.schema)
        if target is None:
            log.debug(self.schema)
            raise TypeNotFound(ref)
        # A builtin target is replaced by self when builtins are disallowed.
        return self if (target.builtin() and nobuiltin) else target

    def qref(self):
        """
        Get the I{type} qualified reference to the referenced XSD type.

        When no type is set and self is simple (len=0), a restriction child
        is searched for and, if found, its I{ref} is used instead.
        @return: The I{type} qualified reference.
        @rtype: qref
        """
        declared = self.type
        if declared is not None or len(self) != 0:
            return declared
        restrictions = []
        matcher = RestrictionMatcher()
        finder = NodeFinder(matcher, 1)
        finder.find(self, restrictions)
        if restrictions:
            return restrictions[0].ref
        return declared
116
117
class Complex(SchemaObject):
    """
    Represents an (XSD) schema <xs:complexType/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return (
            'attribute',
            'attributeGroup',
            'sequence',
            'all',
            'choice',
            'complexContent',
            'simpleContent',
            'any',
            'group',
        )

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name',)

    def extension(self):
        """
        @return: True when any raw child reports being an extension.
        @rtype: boolean
        """
        return any(child.extension() for child in self.rawchildren)

    def mixed(self):
        """
        @return: True when any L{SimpleContent} raw child reports mixed.
        @rtype: boolean
        """
        return any(
            isinstance(child, SimpleContent) and child.mixed()
            for child in self.rawchildren
        )
143
144
class Group(SchemaObject):
    """
    Represents an (XSD) schema <xs:group/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('sequence', 'all', 'choice')

    def dependencies(self):
        """
        Look up the group referenced by I{ref}, if any.
        @return: (0, [group]) when I{ref} is set, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None, []
        referenced = GroupQuery(self.ref).execute(self.schema)
        if referenced is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        return 0, [referenced]

    def merge(self, other):
        """
        Merge I{other} into self and adopt its raw children.
        @param other: The node to merge from.
        @type other: L{SchemaObject}
        """
        SchemaObject.merge(self, other)
        # The other node's child list is taken over as-is, not copied.
        self.rawchildren = other.rawchildren

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name', 'ref')
174
175
class AttributeGroup(SchemaObject):
    """
    Represents an (XSD) schema <xs:attributeGroup/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('attribute', 'attributeGroup')

    def dependencies(self):
        """
        Look up the attribute group referenced by I{ref}, if any.
        @return: (0, [attribute group]) when I{ref} is set, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None, []
        referenced = AttrGroupQuery(self.ref).execute(self.schema)
        if referenced is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        return 0, [referenced]

    def merge(self, other):
        """
        Merge I{other} into self and adopt its raw children.
        @param other: The node to merge from.
        @type other: L{SchemaObject}
        """
        SchemaObject.merge(self, other)
        # The other node's child list is taken over as-is, not copied.
        self.rawchildren = other.rawchildren

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name', 'ref')
205
206
class Simple(SchemaObject):
    """
    Represents an (XSD) schema <xs:simpleType/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('restriction', 'any', 'list')

    def enum(self):
        """
        @return: True when any child is an L{Enumeration}.
        @rtype: boolean
        """
        return any(
            isinstance(child, Enumeration)
            for child, _ancestry in self.children()
        )

    def mixed(self):
        """
        @return: The value of len(self); truthy when it is non-zero.
        @rtype: int
        """
        return len(self)

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name',)

    def extension(self):
        """
        @return: True when any raw child reports being an extension.
        @rtype: boolean
        """
        return any(child.extension() for child in self.rawchildren)

    def restriction(self):
        """
        @return: True when any raw child reports being a restriction.
        @rtype: boolean
        """
        return any(child.restriction() for child in self.rawchildren)
238
239
class List(SchemaObject):
    """
    Represents an (XSD) schema <xs:list/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names (none for a list).
        @rtype: (I{str},...)
        """
        no_children = ()
        return no_children

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name',)

    def xslist(self):
        """
        @return: Always True; this node is an <xs:list/>.
        @rtype: boolean
        """
        return True
253
254
class Restriction(SchemaObject):
    """
    Represents an (XSD) schema <xs:restriction/> node.
    @ivar ref: The value of the node's I{base} attribute.
    """

    def __init__(self, schema, root):
        SchemaObject.__init__(self, schema, root)
        base_ref = root.get('base')
        self.ref = base_ref

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('enumeration', 'attribute', 'attributeGroup')

    def dependencies(self):
        """
        Look up the base type referenced by I{ref}, if any.

        A base type that reports itself as builtin is not a dependency.
        @return: (0, [base type]) for a non-builtin base, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None, []
        base = TypeQuery(self.ref).execute(self.schema)
        if base is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        if base.builtin():
            return None, []
        return 0, [base]

    def restriction(self):
        """
        @return: Always True; this node is a restriction.
        @rtype: boolean
        """
        return True

    def merge(self, other):
        """
        Merge I{other} into self, prepending its raw children to ours.
        @param other: The node to merge from.
        @type other: L{SchemaObject}
        """
        SchemaObject.merge(self, other)
        # The filter is built from our raw children before anything is
        # prepended to them.
        child_filter = Filter(False, self.rawchildren)
        self.prepend(self.rawchildren, other.rawchildren, child_filter)

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('ref',)
291
292
class Collection(SchemaObject):
    """
    Represents an (XSD) schema collection node:
        - sequence
        - choice
        - all
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return (
            'element',
            'sequence',
            'all',
            'choice',
            'any',
            'group',
        )
303
304
class Sequence(Collection):
    """
    Represents an (XSD) schema <xs:sequence/> node.
    """

    def sequence(self):
        """
        @return: Always True; this collection is a sequence.
        @rtype: boolean
        """
        return True
311
312
class All(Collection):
    """
    Represents an (XSD) schema <xs:all/> node.
    """

    def all(self):
        """
        @return: Always True; this collection is an <xs:all/>.
        @rtype: boolean
        """
        return True
319
320
class Choice(Collection):
    """
    Represents an (XSD) schema <xs:choice/> node.
    """

    def choice(self):
        """
        @return: Always True; this collection is a choice.
        @rtype: boolean
        """
        return True
327
328
class ComplexContent(SchemaObject):
    """
    Represents an (XSD) schema <xs:complexContent/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('attribute', 'attributeGroup', 'extension', 'restriction')

    def extension(self):
        """
        @return: True when any raw child reports being an extension.
        @rtype: boolean
        """
        return any(child.extension() for child in self.rawchildren)

    def restriction(self):
        """
        @return: True when any raw child reports being a restriction.
        @rtype: boolean
        """
        return any(child.restriction() for child in self.rawchildren)
348
349
class SimpleContent(SchemaObject):
    """
    Represents an (XSD) schema <xs:simpleContent/> node.
    """

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('extension', 'restriction')

    def extension(self):
        """
        @return: True when any raw child reports being an extension.
        @rtype: boolean
        """
        return any(child.extension() for child in self.rawchildren)

    def restriction(self):
        """
        @return: True when any raw child reports being a restriction.
        @rtype: boolean
        """
        return any(child.restriction() for child in self.rawchildren)

    def mixed(self):
        """
        @return: The value of len(self); truthy when it is non-zero.
        @rtype: int
        """
        return len(self)
372
373
class Enumeration(Content):
    """
    Represents an (XSD) schema <xs:enumeration/> node.
    @ivar name: The value of the node's I{value} attribute.
    """

    def __init__(self, schema, root):
        Content.__init__(self, schema, root)
        # An enumeration is named after its literal value.
        literal = root.get('value')
        self.name = literal

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name',)

    def enum(self):
        """
        @return: Always True; this node is an enumeration.
        @rtype: boolean
        """
        return True
388
389
class Element(TypedContent):
    """
    Represents an (XSD) schema <xs:element/> node.
    """

    def __init__(self, schema, root):
        TypedContent.__init__(self, schema, root)
        # The two flags below are assigned only when the XML attribute is
        # present; otherwise they are left untouched.
        form = root.get('form')
        if form is not None:
            self.form_qualified = (form == 'qualified')
        nillable = self.root.get('nillable')
        if nillable is not None:
            self.nillable = (nillable in ('1', 'true'))
        self.implany()

    def implany(self):
        """
        Set the type as any when implicit.
        An implicit <xs:any/> is when an element has no type, no ref and
        an empty root node.
        @return: self
        @rtype: L{Element}
        """
        untyped = self.type is None and self.ref is None
        if untyped and self.root.isempty():
            self.type = self.anytype()
        return self

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('attribute', 'simpleType', 'complexType', 'any')

    def extension(self):
        """
        @return: True when any raw child reports being an extension.
        @rtype: boolean
        """
        return any(child.extension() for child in self.rawchildren)

    def restriction(self):
        """
        @return: True when any raw child reports being a restriction.
        @rtype: boolean
        """
        return any(child.restriction() for child in self.rawchildren)

    def dependencies(self):
        """
        Look up the element referenced by I{ref}, if any.
        @return: (0, [element]) when I{ref} is set, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        referenced = self.__deref()
        if referenced is None:
            return None, []
        return 0, [referenced]

    def merge(self, other):
        """
        Merge I{other} into self and adopt its raw children.
        @param other: The node to merge from.
        @type other: L{SchemaObject}
        """
        SchemaObject.merge(self, other)
        # The other node's child list is taken over as-is, not copied.
        self.rawchildren = other.rawchildren

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name', 'ref', 'type')

    def anytype(self):
        """
        Create an xsd:anyType reference.

        Uses the prefix the root node already has for the XSD namespace URI;
        when there is none, the default XSD prefix is added to the root node
        and used.
        @return: The reference, as I{prefix}:anyType.
        @rtype: str
        """
        default_prefix, uri = Namespace.xsdns
        prefix = self.root.findPrefix(uri)
        if prefix is None:
            prefix = default_prefix
            self.root.addPrefix(default_prefix, uri)
        return ':'.join((prefix, 'anyType'))

    def namespace(self, prefix=None):
        """
        Get this schema element's target namespace.

        In case of reference elements, the target namespace is defined by the
        referenced and not the referencing element node.

        @param prefix: The default prefix.
        @type prefix: str
        @return: The schema element's target namespace
        @rtype: (I{prefix},I{URI})
        """
        referenced = self.__deref()
        if referenced is None:
            # NOTE(review): prefix is not forwarded on this path, only on
            # the referenced-element path below - confirm this is intended.
            return super(Element, self).namespace()
        return referenced.namespace(prefix)

    def __deref(self):
        """
        Find the element referenced by I{ref}.
        @return: The referenced element, or None when I{ref} is not set.
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None
        referenced = ElementQuery(self.ref).execute(self.schema)
        if referenced is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        return referenced
485
486
class Extension(SchemaObject):
    """
    Represents an (XSD) schema <xs:extension/> node.
    @ivar ref: The value of the node's I{base} attribute.
    """

    def __init__(self, schema, root):
        SchemaObject.__init__(self, schema, root)
        base_ref = root.get('base')
        self.ref = base_ref

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return (
            'attribute',
            'attributeGroup',
            'sequence',
            'all',
            'choice',
            'group',
        )

    def dependencies(self):
        """
        Look up the base type referenced by I{ref}, if any.

        A base type that reports itself as builtin is not a dependency.
        @return: (0, [base type]) for a non-builtin base, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None, []
        base = TypeQuery(self.ref).execute(self.schema)
        if base is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        if base.builtin():
            return None, []
        return 0, [base]

    def merge(self, other):
        """
        Merge I{other} into self, prepending its raw children to ours.
        @param other: The node to merge from.
        @type other: L{SchemaObject}
        """
        SchemaObject.merge(self, other)
        # The filter is built from our raw children before anything is
        # prepended to them.
        child_filter = Filter(False, self.rawchildren)
        self.prepend(self.rawchildren, other.rawchildren, child_filter)

    def extension(self):
        """
        @return: True when a I{base} reference is present.
        @rtype: boolean
        """
        has_base = self.ref is not None
        return has_base

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('ref',)
524
525
class Import(SchemaObject):
    """
    Represents an (XSD) schema <xs:import/> node.
    @cvar locations: A dictionary of namespace locations.
    @type locations: dict
    @ivar ns: The imported namespace, as (None, I{URI}).
    @type ns: tuple
    @ivar location: The (optional) location.
    @type location: namespace-uri
    @ivar opened: Opened and I{imported} flag.
    @type opened: boolean
    """

    locations = {}

    @classmethod
    def bind(cls, ns, location=None):
        """
        Bind a namespace to a schema location (URI).
        This is used for imports that don't specify a schemaLocation.
        @param ns: A namespace-uri.
        @type ns: str
        @param location: The (optional) schema location for the
            namespace.  (default=ns).
        @type location: str
        """
        cls.locations[ns] = ns if location is None else location

    def __init__(self, schema, root):
        SchemaObject.__init__(self, schema, root)
        namespace_uri = root.get('namespace')
        self.ns = (None, namespace_uri)
        location = root.get('schemaLocation')
        if location is None:
            # No schemaLocation given: fall back on a bound location.
            location = self.locations.get(namespace_uri)
        self.location = location
        self.opened = False

    def open(self, options):
        """
        Open and import the referenced schema.

        Only the first call does any work; later calls return None. The
        schema is first looked for locally and downloaded only when that
        fails and a location is known.
        @param options: An options dictionary.
        @type options: L{options.Options}
        @return: The referenced schema, or None when it was not found.
        @rtype: L{Schema}
        """
        if self.opened:
            return None
        self.opened = True
        log.debug('%s, importing ns="%s", location="%s"', self.id, self.ns[1], self.location)
        imported = self.locate()
        if imported is None:
            if self.location is not None:
                imported = self.download(options)
            else:
                log.debug('imported schema (%s) not-found', self.ns[1])
        log.debug('imported:\n%s', imported)
        return imported

    def locate(self):
        """
        Find the schema locally.
        @return: The result of the owning schema's lookup, or None when the
            imported namespace equals the owning schema's target namespace.
        """
        if self.ns[1] == self.schema.tns[1]:
            return None
        return self.schema.locate(self.ns)

    def download(self, options):
        """
        Download the schema.

        A location without '://' is joined onto the owning schema's base URL.
        @param options: An options dictionary.
        @type options: L{options.Options}
        @return: A schema instance built from the downloaded document.
        @raise Exception: When the transport fails.
        """
        url = self.location
        try:
            if '://' not in url:
                url = urljoin(self.schema.baseurl, url)
            document = DocumentReader(options).open(url)
            schema_root = document.root()
            schema_root.set('url', url)
            return self.schema.instance(schema_root, url, options)
        except TransportError:
            msg = 'imported schema (%s) at (%s), failed' % (self.ns[1], url)
            log.error('%s, %s', self.id, msg, exc_info=True)
            raise Exception(msg)

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('ns', 'location')
608
609
class Include(SchemaObject):
    """
    Represents an (XSD) schema <xs:include/> node.
    @cvar locations: A dictionary consulted for a fallback location when
        the node has no I{schemaLocation} attribute.
    @type locations: dict
    @ivar location: The (optional) location.
    @type location: namespace-uri
    @ivar opened: Opened and I{included} flag.
    @type opened: boolean
    """

    locations = {}

    def __init__(self, schema, root):
        SchemaObject.__init__(self, schema, root)
        self.location = root.get('schemaLocation')
        if self.location is None:
            # NOTE(review): self.ns is not assigned anywhere in this class;
            # this lookup relies on the base class providing it - confirm.
            self.location = self.locations.get(self.ns[1])
        self.opened = False

    def open(self, options):
        """
        Open and include the referenced schema.

        Only the first call does any work; later calls return None.
        @param options: An options dictionary.
        @type options: L{options.Options}
        @return: The referenced schema.
        @rtype: L{Schema}
        """
        if self.opened:
            return
        self.opened = True
        log.debug('%s, including location="%s"', self.id, self.location)
        result = self.download(options)
        log.debug('included:\n%s', result)
        return result

    def download(self, options):
        """
        Download the schema.

        A location without '://' is joined onto the owning schema's base URL.
        @param options: An options dictionary.
        @type options: L{options.Options}
        @return: A schema instance built from the downloaded document.
        @raise Exception: When the transport fails, or when the included
            schema declares a different target namespace.
        """
        url = self.location
        try:
            if '://' not in url:
                url = urljoin(self.schema.baseurl, url)
            reader = DocumentReader(options)
            d = reader.open(url)
            root = d.root()
            root.set('url', url)
            self.__applytns(root)
            return self.schema.instance(root, url, options)
        except TransportError:
            msg = 'include schema at (%s), failed' % url
            log.error('%s, %s', self.id, msg, exc_info=True)
            raise Exception(msg)

    def __applytns(self, root):
        """
        Make sure the included schema has the same target namespace.

        A root without a targetNamespace is given the owning schema's one.
        @param root: The root node of the included schema document.
        @raise Exception: When the root declares a different targetNamespace.
        """
        TNS = 'targetNamespace'
        tns = root.get(TNS)
        if tns is None:
            tns = self.schema.tns[1]
            root.set(TNS, tns)
        else:
            if self.schema.tns[1] != tns:
                raise Exception('%s mismatch' % TNS)

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        # Must be a tuple like every other description(): a bare string
        # would be iterated character by character.
        return ('location',)
675
676
class Attribute(TypedContent):
    """
    Represents an (XSD) <attribute/> node.
    @ivar use: The value of the node's I{use} attribute ('' when absent).
    @type use: str
    """

    def __init__(self, schema, root):
        TypedContent.__init__(self, schema, root)
        usage = root.get('use', default='')
        self.use = usage

    def childtags(self):
        """
        @return: The valid child node names.
        @rtype: (I{str},...)
        """
        return ('restriction',)

    def isattr(self):
        """
        @return: Always True; this node is an attribute.
        @rtype: boolean
        """
        return True

    def get_default(self):
        """
        Gets the <xs:attribute default=""/> attribute value.
        @return: The default value for the attribute ('' when absent).
        @rtype: str
        """
        default_value = self.root.get('default', default='')
        return default_value

    def optional(self):
        """
        @return: True unless I{use} is 'required'.
        @rtype: boolean
        """
        return self.use != 'required'

    def dependencies(self):
        """
        Look up the attribute referenced by I{ref}, if any.
        @return: (0, [attribute]) when I{ref} is set, else (None, []).
        @rtype: (int|None, list)
        @raise TypeNotFound: When I{ref} is set but the query finds nothing.
        """
        if self.ref is None:
            return None, []
        referenced = AttrQuery(self.ref).execute(self.schema)
        if referenced is None:
            log.debug(self.schema)
            raise TypeNotFound(self.ref)
        return 0, [referenced]

    def description(self):
        """
        @return: The attribute names used to describe this node.
        @rtype: (I{str},...)
        """
        return ('name', 'ref', 'type')
718
719
class Any(Content):
    """
    Represents an (XSD) <any/> node.
    """

    def get_child(self, name):
        """
        Synthesize a child for the requested name.

        The result is a new L{Any} built on a clone of this node's root;
        I{name} itself is not used.
        @param name: The requested child name.
        @return: The synthesized child and an empty ancestry list.
        @rtype: (L{Any}, list)
        """
        cloned = self.root.clone()
        cloned.set('note', 'synthesized (any) child')
        return Any(self.schema, cloned), []

    def get_attribute(self, name):
        """
        Synthesize an attribute for the requested name.

        The result is a new L{Any} built on a clone of this node's root;
        I{name} itself is not used.
        @param name: The requested attribute name.
        @return: The synthesized attribute and an empty ancestry list.
        @rtype: (L{Any}, list)
        """
        cloned = self.root.clone()
        cloned.set('note', 'synthesized (any) attribute')
        return Any(self.schema, cloned), []

    def any(self):
        """
        @return: Always True; this node is an <any/>.
        @rtype: boolean
        """
        return True
739
740
class Factory:
    """
    Creates schema objects from XML nodes.
    @cvar tags: The mapping of XSD tag name to the class (or function)
        used to create the object for that tag.
    @type tags: {tag:fn,}
    """

    tags = {
        'import': Import,
        'include': Include,
        'complexType': Complex,
        'group': Group,
        'attributeGroup': AttributeGroup,
        'simpleType': Simple,
        'list': List,
        'element': Element,
        'attribute': Attribute,
        'sequence': Sequence,
        'all': All,
        'choice': Choice,
        'complexContent': ComplexContent,
        'simpleContent': SimpleContent,
        'restriction': Restriction,
        'enumeration': Enumeration,
        'extension': Extension,
        'any': Any,
    }

    @classmethod
    def maptag(cls, tag, fn):
        """
        Map (override) tag => I{class} mapping.
        @param tag: An XSD tag name.
        @type tag: str
        @param fn: A function or class.
        @type fn: fn|class.
        """
        cls.tags[tag] = fn

    @classmethod
    def create(cls, root, schema):
        """
        Create an object based on the root tag name.
        @param root: An XML root element.
        @type root: L{Element}
        @param schema: A schema object.
        @type schema: L{schema.Schema}
        @return: The created object, or None when the tag is not mapped.
        @rtype: L{SchemaObject}
        """
        maker = cls.tags.get(root.name)
        if maker is None:
            return None
        return maker(schema, root)

    @classmethod
    def build(cls, root, schema, filter=('*',)):
        """
        Build an xsobject representation.

        Child nodes in the XSD namespace are converted recursively; each
        created object's own childtags() is the filter for its children.
        @param root: An schema XML root.
        @type root: L{sax.element.Element}
        @param schema: A schema object.
        @type schema: L{schema.Schema}
        @param filter: A tag filter; '*' accepts every tag.
        @type filter: [str,...]
        @return: A schema object graph.
        @rtype: L{sxbase.SchemaObject}
        """
        built = []
        for node in root.getChildren(ns=Namespace.xsdns):
            if '*' not in filter and node.name not in filter:
                continue
            child = cls.create(node, schema)
            if child is None:
                # Unmapped tag: skipped together with its subtree.
                continue
            built.append(child)
            child.rawchildren = cls.build(node, schema, child.childtags())
        return built

    @classmethod
    def collate(cls, children):
        """
        Sort schema objects into per-kind collections.

        Imports and includes are collected in a list and removed from
        I{children} in place; everything else is indexed by its I{qname}.
        @param children: The schema objects to collate.
        @type children: list
        @return: (children, imports, attributes, elements, types, groups,
            agrps)
        @rtype: tuple
        """
        imports = []
        elements = {}
        attributes = {}
        types = {}
        groups = {}
        agrps = {}
        for child in children:
            if isinstance(child, (Import, Include)):
                imports.append(child)
            elif isinstance(child, Attribute):
                attributes[child.qname] = child
            elif isinstance(child, Element):
                elements[child.qname] = child
            elif isinstance(child, Group):
                groups[child.qname] = child
            elif isinstance(child, AttributeGroup):
                agrps[child.qname] = child
            else:
                types[child.qname] = child
        # Removal happens after the scan so the list is not changed while
        # it is being iterated.
        for imported in imports:
            children.remove(imported)
        return children, imports, attributes, elements, types, groups, agrps
844
845
846#######################################################
847# Static Import Bindings :-(
848#######################################################
# Each call registers the schema location to use for an import of the given
# namespace when the <xs:import/> node carries no schemaLocation.
Import.bind(
    ns='http://schemas.xmlsoap.org/soap/encoding/',
    location='suds://schemas.xmlsoap.org/soap/encoding/',
)
Import.bind(
    ns='http://www.w3.org/XML/1998/namespace',
    location='http://www.w3.org/2001/xml.xsd',
)
Import.bind(
    ns='http://www.w3.org/2001/XMLSchema',
    location='http://www.w3.org/2001/XMLSchema.xsd',
)
858