1
2#
3# spyne - Copyright (C) Spyne contributors.
4#
5# This library is free software; you can redistribute it and/or
6# modify it under the terms of the GNU Lesser General Public
7# License as published by the Free Software Foundation; either
8# version 2.1 of the License, or (at your option) any later version.
9#
10# This library is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13# Lesser General Public License for more details.
14#
15# You should have received a copy of the GNU Lesser General Public
16# License along with this library; if not, write to the Free Software
17# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
18#
19
20import logging
21logger = logging.getLogger(__name__)
22
23from collections import deque, defaultdict
24
25import spyne.interface
26
27from spyne import EventManager, MethodDescriptor
28from spyne.util import six
29from spyne.model import ModelBase, Array, Iterable, ComplexModelBase
30from spyne.model.complex import XmlModifier
31from spyne.const import xml as namespace
32
33
34class Interface(object):
35    """The ``Interface`` class holds all information needed to build an
36    interface document.
37
38    :param app: A :class:`spyne.application.Application` instance.
39    """
40
41    def __init__(self, app=None, import_base_namespaces=False):
42        self.__ns_counter = 0
43        self.__app = None
44        self.url = None
45        self.classes = {}
46        self.imports = {}
47        self.service_method_map = {}
48        self.method_id_map = {}
49        self.nsmap = {}
50        self.prefmap = {}
51        self.member_methods = deque()
52        self.method_descriptor_id_to_key = {}
53        self.service_attrs = defaultdict(dict)
54
55        self.import_base_namespaces = import_base_namespaces
56        self.app = app
57
58    def set_app(self, value):
59        assert self.__app is None, "One interface instance can belong to only " \
60                                   "one application instance."
61
62        self.__app = value
63        self.reset_interface()
64        self.populate_interface()
65
66    def get_app(self):
67        return self.__app
68
69    app = property(get_app, set_app)
70
71    @property
72    def services(self):
73        if self.__app:
74            return self.__app.services
75        return []
76
77    def reset_interface(self):
78        self.classes = {}
79        self.imports = {self.get_tns(): set()}
80        self.service_method_map = {}
81        self.method_id_map = {}
82        self.nsmap = dict(namespace.NSMAP)
83        self.prefmap = dict(namespace.PREFMAP)
84        self.member_methods = deque()
85
86        self.nsmap['tns'] = self.get_tns()
87        self.prefmap[self.get_tns()] = 'tns'
88        self.deps = defaultdict(set)
89
90    def has_class(self, cls):
91        """Returns true if the given class is already included in the interface
92        object somewhere."""
93
94        ns = cls.get_namespace()
95        tn = cls.get_type_name()
96
97        key = '{%s}%s' % (ns, tn)
98        c = self.classes.get(key)
99        if c is None:
100            return False
101
102        if issubclass(c, ComplexModelBase) and \
103                                              issubclass(cls, ComplexModelBase):
104            o1 = getattr(cls, '__orig__', None) or cls
105            o2 = getattr(c, '__orig__', None) or c
106
107            if o1 is o2:
108                return True
109
110            # So that "Array"s and "Iterable"s don't conflict.
111            if set((o1, o2)) == set((Array, Iterable)):
112                return True
113
114            raise ValueError("classes %r and %r have conflicting names: '%s'" %
115                                                                  (cls, c, key))
116        return True
117
118    def get_class(self, key):
119        """Returns the class definition that corresponds to the given key.
120        Keys are in '{namespace}class_name' form.
121
122        Not meant to be overridden.
123        """
124        return self.classes[key]
125
126    def get_class_instance(self, key):
127        """Returns the default class instance that corresponds to the given key.
128        Keys are in '{namespace}class_name' form, a.k.a. XML QName format.
129        Classes should not enforce arguments to the constructor.
130
131        Not meant to be overridden.
132        """
133        return self.classes[key]()
134
135    def get_name(self):
136        """Returns service name that is seen in the name attribute of the
137        definitions tag.
138
139        Not meant to be overridden.
140        """
141
142        if self.app:
143            return self.app.name
144
145    def get_tns(self):
146        """Returns default namespace that is seen in the targetNamespace
147        attribute of the definitions tag.
148
149        Not meant to be overridden.
150        """
151        if self.app:
152            return self.app.tns
153
154    def add_method(self, method):
155        """Generator method that adds the given method descriptor to the
156        interface. Also extracts and yields all the types found in there.
157
158        :param method: A :class:`MethodDescriptor` instance
159        :returns: Sequence of :class:`spyne.model.ModelBase` subclasses.
160        """
161
162        if not (method.in_header is None):
163            if not isinstance(method.in_header, (list, tuple)):
164                method.in_header = (method.in_header,)
165
166            for in_header in method.in_header:
167                in_header.resolve_namespace(in_header, self.get_tns())
168                if method.aux is None:
169                    yield in_header
170                in_header_ns = in_header.get_namespace()
171                if in_header_ns != self.get_tns() and \
172                                             self.is_valid_import(in_header_ns):
173                    self.imports[self.get_tns()].add(in_header_ns)
174
175        if not (method.out_header is None):
176            if not isinstance(method.out_header, (list, tuple)):
177                method.out_header = (method.out_header,)
178
179            for out_header in method.out_header:
180                out_header.resolve_namespace(out_header, self.get_tns())
181                if method.aux is None:
182                    yield out_header
183                out_header_ns = out_header.get_namespace()
184                if out_header_ns != self.get_tns() and \
185                                            self.is_valid_import(out_header_ns):
186                    self.imports[self.get_tns()].add(out_header_ns)
187
188        if method.faults is None:
189            method.faults = []
190        elif not (isinstance(method.faults, (list, tuple))):
191            method.faults = (method.faults,)
192
193        for fault in method.faults:
194            fault.__namespace__ = self.get_tns()
195            fault.resolve_namespace(fault, self.get_tns())
196            if method.aux is None:
197                yield fault
198
199        method.in_message.resolve_namespace(method.in_message, self.get_tns())
200        in_message_ns = method.in_message.get_namespace()
201        if in_message_ns != self.get_tns() and \
202                                            self.is_valid_import(in_message_ns):
203            self.imports[self.get_tns()].add(method.in_message.get_namespace())
204
205        if method.aux is None:
206            yield method.in_message
207
208        method.out_message.resolve_namespace(method.out_message, self.get_tns())
209        assert not method.out_message.get_type_name() is method.out_message.Empty
210
211        out_message_ns = method.out_message.get_namespace()
212        if out_message_ns != self.get_tns() and \
213                                           self.is_valid_import(out_message_ns):
214            self.imports[self.get_tns()].add(out_message_ns)
215
216        if method.aux is None:
217            yield method.out_message
218
219        for p in method.patterns:
220            p.endpoint = method
221
222    def process_method(self, s, method):
223        assert isinstance(method, MethodDescriptor)
224
225        method_key = u'{%s}%s' % (self.app.tns, method.name)
226
227        if issubclass(s, ComplexModelBase):
228            method_object_name = method.name.split('.', 1)[0]
229            if s.get_type_name() != method_object_name:
230                method_key = u'{%s}%s.%s' % (self.app.tns, s.get_type_name(),
231                                                                    method.name)
232
233        key = method.gen_interface_key(s)
234        if key in self.method_id_map:
235            c = self.method_id_map[key].parent_class
236            if c is None:
237                pass
238
239            elif c is s:
240                pass
241
242            elif c.__orig__ is None:
243                assert c is s.__orig__, "%r.%s conflicts with %r.%s" % \
244                                        (c, key, s.__orig__, key)
245            elif s.__orig__ is None:
246                assert c.__orig__ is s, "%r.%s conflicts with %r.%s" % \
247                                        (c.__orig__, key, s, key)
248            else:
249                assert c.__orig__ is s.__orig__, "%r.%s conflicts with %r.%s" % \
250                                        (c.__orig__, key, s.__orig__, key)
251            return
252
253        logger.debug('  adding method %s.%s to match %r tag.',
254               method.get_owner_name(s), six.get_function_name(method.function),
255                                                                     method_key)
256
257        self.method_id_map[key] = method
258
259        val = self.service_method_map.get(method_key, None)
260        if val is None:
261            val = self.service_method_map[method_key] = []
262
263        if len(val) == 0:
264            val.append(method)
265
266        elif method.aux is not None:
267            val.append(method)
268
269        elif val[0].aux is not None:
270            val.insert(method, 0)
271
272        else:
273            om = val[0]
274            os = om.service_class
275            if os is None:
276                os = om.parent_class
277            raise ValueError("\nThe message %r defined in both '%s.%s'"
278                                                         " and '%s.%s'"
279                                    % (method.name, s.__module__,  s.__name__,
280                                                   os.__module__, os.__name__))
281
282    def check_method(self, method):
283        """Override this if you need to cherry-pick methods added to the
284        interface document."""
285
286        return True
287
288    def populate_interface(self, types=None):
289        """Harvests the information stored in individual classes' _type_info
290        dictionaries. It starts from function definitions and includes only
291        the used objects.
292        """
293
294        # populate types
295        for s in self.services:
296            logger.debug("populating %s types...", s.get_internal_key())
297
298            for method in s.public_methods.values():
299                if method.in_header is None:
300                    method.in_header = s.__in_header__
301
302                if method.out_header is None:
303                    method.out_header = s.__out_header__
304
305                if method.aux is None:
306                    method.aux = s.__aux__
307
308                if method.aux is not None:
309                    method.aux.methods.append(method.gen_interface_key(s))
310
311                if not self.check_method(method):
312                    logger.debug("method %s' discarded by check_method",
313                                                               method.class_key)
314                    continue
315
316                logger.debug("  enumerating classes for method '%s'",
317                                                               method.class_key)
318                for cls in self.add_method(method):
319                    self.add_class(cls)
320
321        # populate additional types
322        for c in self.app.classes:
323            self.add_class(c)
324
325        # populate call routes for service methods
326        for s in self.services:
327            self.service_attrs[s]['tns'] = self.get_tns()
328            logger.debug("populating '%s.%s' routes...", s.__module__,
329                                                                     s.__name__)
330            for method in s.public_methods.values():
331                self.process_method(s, method)
332
333        # populate call routes for member methods
334        for cls, method in self.member_methods:
335            should_we = True
336            if method.static_when is not None:
337                should_we = method.static_when(self.app)
338                logger.debug("static_when returned %r for %s "
339                     "while populating methods", should_we, method.internal_key)
340
341            if should_we:
342                s = method.service_class
343                if s is not None:
344                    if method.in_header is None:
345                        method.in_header = s.__in_header__
346
347                    if method.out_header is None:
348                        method.out_header = s.__out_header__
349
350                    # FIXME: There's no need to process aux info here as it's
351                    # not currently known how to write aux member methods in the
352                    # first place.
353
354                self.process_method(cls.__orig__ or cls, method)
355
356        # populate method descriptor id to method key map
357        self.method_descriptor_id_to_key = dict(((id(v[0]), k)
358                                    for k,v in self.service_method_map.items()))
359
360        logger.debug("From this point on, you're not supposed to make any "
361                     "changes to the class and method structure of the exposed "
362                     "services.")
363
364    tns = property(get_tns)
365
366    def get_namespace_prefix(self, ns):
367        """Returns the namespace prefix for the given namespace. Creates a new
368        one automatically if it doesn't exist.
369
370        Not meant to be overridden.
371        """
372
373        if not (isinstance(ns, str) or isinstance(ns, six.text_type)):
374            raise TypeError(ns)
375
376        if not (ns in self.prefmap):
377            pref = "s%d" % self.__ns_counter
378            while pref in self.nsmap:
379                self.__ns_counter += 1
380                pref = "s%d" % self.__ns_counter
381
382            self.prefmap[ns] = pref
383            self.nsmap[pref] = ns
384
385            self.__ns_counter += 1
386
387        else:
388            pref = self.prefmap[ns]
389
390        return pref
391
392    def add_class(self, cls, add_parent=True):
393        if self.has_class(cls):
394            return
395
396        ns = cls.get_namespace()
397        tn = cls.get_type_name()
398
399        assert ns is not None, ('either assign a namespace to the class or call'
400                        ' cls.resolve_namespace(cls, "some_default_ns") on it.')
401
402        if not (ns in self.imports) and self.is_valid_import(ns):
403            self.imports[ns] = set()
404
405        class_key = '{%s}%s' % (ns, tn)
406        logger.debug('    adding class %r for %r', repr(cls), class_key)
407
408        assert class_key not in self.classes, ("Somehow, you're trying to "
409            "overwrite %r by %r for class key %r." %
410                                      (self.classes[class_key], cls, class_key))
411
412        assert not (cls.get_type_name() is cls.Empty), cls
413
414        self.deps[cls]  # despite the appearances, this is not totally useless.
415        self.classes[class_key] = cls
416        if ns == self.get_tns():
417            self.classes[tn] = cls
418
419        # add parent class
420        extends = getattr(cls, '__extends__', None)
421        while extends is not None and \
422                                   (extends.get_type_name() is ModelBase.Empty):
423            extends = getattr(extends, '__extends__', None)
424
425        if add_parent and extends is not None:
426            assert issubclass(extends, ModelBase)
427            self.deps[cls].add(extends)
428            self.add_class(extends)
429            parent_ns = extends.get_namespace()
430            if parent_ns != ns and not parent_ns in self.imports[ns] and \
431                                                self.is_valid_import(parent_ns):
432                self.imports[ns].add(parent_ns)
433                logger.debug("    importing %r to %r because %r extends %r",
434                                            parent_ns, ns, cls.get_type_name(),
435                                            extends.get_type_name())
436
437        # add fields
438        if issubclass(cls, ComplexModelBase):
439            for k, v in cls._type_info.items():
440                if v is None:
441                    continue
442
443                self.deps[cls].add(v)
444
445                logger.debug("    adding %s.%s = %r", cls.get_type_name(), k, v)
446                if v.get_namespace() is None:
447                    v.resolve_namespace(v, ns)
448
449                self.add_class(v)
450
451                if v.get_namespace() is None and cls.get_namespace() is not None:
452                    v.resolve_namespace(v, cls.get_namespace())
453
454                child_ns = v.get_namespace()
455                if child_ns != ns and not child_ns in self.imports[ns] and \
456                                                 self.is_valid_import(child_ns):
457                    self.imports[ns].add(child_ns)
458                    logger.debug("    importing %r to %r for %s.%s(%r)",
459                                       child_ns, ns, cls.get_type_name(), k, v)
460
461                if issubclass(v, XmlModifier):
462                    self.add_class(v.type)
463
464                    child_ns = v.type.get_namespace()
465                    if child_ns != ns and not child_ns in self.imports[ns] and \
466                                                 self.is_valid_import(child_ns):
467                        self.imports[ns].add(child_ns)
468                        logger.debug("    importing %r to %r for %s.%s(%r)",
469                                    child_ns, ns, v.get_type_name(), k, v.type)
470
471            if cls.Attributes.methods is not None:
472                logger.debug("    populating member methods for '%s.%s'...",
473                                       cls.get_namespace(), cls.get_type_name())
474
475                for method_key, descriptor in cls.Attributes.methods.items():
476                    assert hasattr(cls, method_key)
477
478                    should_we = True
479                    if descriptor.static_when is not None:
480                        should_we = descriptor.static_when(self.app)
481                        logger.debug("static_when returned %r for %s "
482                            "while populating classes",
483                                             should_we, descriptor.internal_key)
484
485                    if should_we:
486                        self.member_methods.append((cls, descriptor))
487                        for c in self.add_method(descriptor):
488                            self.add_class(c)
489
490            if cls.Attributes._subclasses is not None:
491                logger.debug("    adding subclasses of '%s.%s'...",
492                                       cls.get_namespace(), cls.get_type_name())
493
494                for c in cls.Attributes._subclasses:
495                    c.resolve_namespace(c, ns)
496
497                    child_ns = c.get_namespace()
498                    if child_ns == ns:
499                        if not self.has_class(c):
500                            self.add_class(c, add_parent=False)
501                            self.deps[c].add(cls)
502                    else:
503                        logger.debug("    not adding %r to %r because it would "
504                            "cause circular imports because %r extends %r and "
505                            "they don't have the same namespace", child_ns,
506                                     ns, c.get_type_name(), cls.get_type_name())
507
508    def is_valid_import(self, ns):
509        """This will return False for base namespaces unless told otherwise."""
510
511        if ns is None:
512            raise ValueError(ns)
513
514        return self.import_base_namespaces or not (ns in namespace.PREFMAP)
515
516
class AllYourInterfaceDocuments(object): # AreBelongToUs
    """Bundle of the interface documents generated for one interface.

    :param interface: The interface the documents describe.
    :param wsdl11: A ready-made WSDL 1.1 document object. When omitted and
        ``spyne.interface.HAS_WSDL`` is true, a
        :class:`spyne.interface.wsdl.Wsdl11` is built for ``interface``;
        otherwise the ``wsdl11`` attribute stays ``None``.
    """

    def __init__(self, interface, wsdl11=None):
        must_build = wsdl11 is None and spyne.interface.HAS_WSDL
        if must_build:
            # Imported lazily, only when a WSDL document is actually needed.
            from spyne.interface.wsdl import Wsdl11
            wsdl11 = Wsdl11(interface)

        self.wsdl11 = wsdl11
523
524
class InterfaceDocumentBase(object):
    """Common base of every interface document implementation.

    Subclasses provide :meth:`build_interface_document` and
    :meth:`get_interface_document`. Each instance owns an
    :class:`EventManager` (``event_manager``) created with itself as parent.

    :param interface: The interface instance to document (see
        :class:`Interface`).
    """

    def __init__(self, interface):
        self.interface = interface
        self.event_manager = EventManager(self)

    def build_interface_document(self):
        """Build the interface document and cache it for later retrieval.

        Meant to run exactly once, as late as possible during process
        start-up. Overrides must not delegate to the overridden method, since
        that can make the same event fire more than once.
        """
        raise NotImplementedError("Extend and override.")

    def get_interface_document(self):
        """Return the interface document cached earlier.

        Server transports call this when a client asks for the interface
        document, so it should do no building work of its own.
        """
        raise NotImplementedError("Extend and override.")
551