1
2#
3# spyne - Copyright (C) Spyne contributors.
4#
5# This library is free software; you can redistribute it and/or
6# modify it under the terms of the GNU Lesser General Public
7# License as published by the Free Software Foundation; either
8# version 2.1 of the License, or (at your option) any later version.
9#
10# This library is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13# Lesser General Public License for more details.
14#
15# You should have received a copy of the GNU Lesser General Public
16# License along with this library; if not, write to the Free Software
17# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301
18#
19
20
21"""
22A server that uses http as transport via wsgi. It doesn't contain any server
23logic.
24"""
25
26import logging
27logger = logging.getLogger(__name__)
28
29import cgi
30import threading
31
32from inspect import isgenerator
33from itertools import chain
34
35from spyne import Address
36from spyne.util.six.moves.http_cookies import SimpleCookie
37from spyne.util.six.moves.urllib.parse import unquote, quote
38
39from spyne import File, Fault
40from spyne.application import get_fault_string_from_exception
41from spyne.auxproc import process_contexts
42from spyne.error import RequestTooLongError
43from spyne.protocol.http import HttpRpc
44from spyne.server.http import HttpBase, HttpMethodContext, HttpTransportContext
45from spyne.util.odict import odict
46from spyne.util.address import address_parser
47
48from spyne.const.ansi_color import LIGHT_GREEN
49from spyne.const.ansi_color import END_COLOR
50from spyne.const.http import HTTP_200
51from spyne.const.http import HTTP_404
52from spyne.const.http import HTTP_500
53
54
try:
    from spyne.protocol.soap.mime import apply_mtom
except ImportError as _import_error_1:
    # Python 3 unbinds the ``as`` target as soon as the ``except`` clause
    # ends, so a second module-level name keeps the exception reachable for
    # the stub below.
    _local_import_error_1 = _import_error_1

    def apply_mtom(*args, **kwargs):
        """Stand-in used when the MTOM helper could not be imported.

        Accepts any arguments and re-raises the import error that was caught
        at module load time, so the failure only surfaces when MTOM is
        actually used.
        """
        raise _local_import_error_1
61
try:
    from werkzeug.formparser import parse_form_data
except ImportError as _import_error_2:
    # Python 3 unbinds the ``as`` target as soon as the ``except`` clause
    # ends, so a second module-level name keeps the exception reachable for
    # the stub below.
    _local_import_error_2 = _import_error_2

    def parse_form_data(*args, **kwargs):
        """Stand-in used when werkzeug's form parser could not be imported.

        Accepts any arguments and re-raises the import error that was caught
        at module load time, so the failure only surfaces when form parsing
        is actually requested.
        """
        raise _local_import_error_2
68
69
def _reconstruct_url(environ, protocol=True, server_name=True, path=True,
                     query_string=True):
    """Rebuild the calling url from values found in the WSGI environment.

    This algorithm was found via PEP 333, the wsgi spec and
    contributed by Ian Bicking.

    :param environ: The WSGI environment dict of the request.
    :param protocol: When true, emit the ``scheme://`` prefix.
    :param server_name: When true, emit the host. ``HTTP_HOST`` is preferred;
        otherwise ``SERVER_NAME`` is used, followed by ``SERVER_PORT`` when it
        is not the default port of the scheme.
    :param path: When true, emit the url-quoted ``SCRIPT_NAME`` and
        ``PATH_INFO``.
    :param query_string: When true, emit ``?`` plus ``QUERY_STRING`` if the
        latter is non-empty.
    :returns: The reconstructed url as a string.
    :raises KeyError: When a key that is read unconditionally
        (``wsgi.url_scheme``, ``SERVER_NAME``, ``SERVER_PORT``) is missing
        from ``environ`` while the corresponding part is requested.
    """

    parts = []

    if protocol:
        parts.append(environ['wsgi.url_scheme'] + '://')

    if server_name:
        host = environ.get('HTTP_HOST')
        if host:
            parts.append(host)

        else:
            parts.append(environ['SERVER_NAME'])

            # Only spell out the port when it differs from the scheme default.
            if environ['wsgi.url_scheme'] == 'https':
                default_port = '443'
            else:
                default_port = '80'

            port = environ['SERVER_PORT']
            if port != default_port:
                parts.append(':' + port)

    if path:
        script_name = quote(environ.get('SCRIPT_NAME', ''))
        path_info = quote(environ.get('PATH_INFO', ''))

        # ``startswith`` instead of ``[0]``: PATH_INFO may legitimately be
        # empty (or absent), and indexing an empty string raises IndexError.
        if script_name == '/' and path_info.startswith('/'):
            # skip SCRIPT_NAME if it is only a slash, PATH_INFO supplies it.
            pass

        elif script_name.startswith('//'):
            parts.append(script_name[1:])

        else:
            parts.append(script_name)

        parts.append(path_info)

    if query_string:
        qs = environ.get('QUERY_STRING')
        if qs:
            parts.append('?' + qs)

    return ''.join(parts)
117
def _parse_qs(qs):
    """Parse a url query string into an ordered mapping of value lists.

    Fields are separated by ``&`` or ``;``. Both names and values have ``+``
    turned into a space and are then percent-decoded. A field without an
    equal sign gets ``None`` as its value. Repeated names accumulate their
    values in order of appearance.

    :param qs: The raw query string.
    :returns: An ``odict`` that maps every field name to a list of values.
    """

    parsed = odict()

    for amp_chunk in qs.split('&'):
        for field in amp_chunk.split(';'):
            if not field:
                # Empty fragments come from doubled or trailing separators.
                continue

            raw_name, equal_sign, raw_value = field.partition('=')
            name = unquote(raw_name.replace('+', ' '))

            if equal_sign:
                value = unquote(raw_value.replace('+', ' '))
            else:
                # A control name with no equal sign carries no value.
                value = None

            values = parsed.get(name, None)
            if values is None:
                values = parsed[name] = []
            values.append(value)

    return parsed
143
144
def _get_http_headers(req_env):
    """Collect the request headers found in a WSGI environment.

    Every key that starts with ``HTTP_`` is taken, the prefix is dropped and
    the remainder is lowercased. Each header value is wrapped in a
    single-element list.

    :param req_env: The WSGI environment dict of the request.
    :returns: A dict that maps header names to one-element value lists.
    """

    prefix = "HTTP_"
    headers = {}

    for env_key, env_value in req_env.items():
        if not env_key.startswith(prefix):
            continue

        header_name = env_key[len(prefix):].lower()
        header_values = [env_value]
        headers[header_name] = header_values
        logger.debug("Add http header %r = %r", header_name, header_values)

    return headers
156
157
def _gen_http_headers(headers):
    """Flatten a header mapping into a list of ``(name, value)`` pairs.

    A list or tuple value contributes one pair per element; any other value
    contributes exactly one pair.

    :param headers: A mapping of header names to a value or a list/tuple of
        values.
    :returns: A list of ``(name, value)`` tuples.
    """

    pairs = []

    for name, value in headers.items():
        if not isinstance(value, (list, tuple)):
            pairs.append((name, value))
            continue

        pairs.extend((name, item) for item in value)

    return pairs
169
170
class WsgiTransportContext(HttpTransportContext):
    """The class that is used in the transport attribute of the
    :class:`WsgiMethodContext` class."""

    def __init__(self, parent, transport, req_env, content_type):
        """Store the WSGI environment and derive the request headers from it.

        All four arguments are forwarded unchanged to the parent class
        constructor.

        :param req_env: The WSGI environment dict of the request.
        """

        super(WsgiTransportContext, self).__init__(parent, transport,
                                                   req_env, content_type)

        # ``self.req`` is presumably set to ``req_env`` by the parent
        # constructor -- confirm against HttpTransportContext.
        self.req_env = self.req
        """WSGI Request environment"""

        self.req_method = req_env.get('REQUEST_METHOD', None)
        """HTTP Request verb, as a convenience to users."""

        self.headers = _get_http_headers(self.req_env)

    def get_path(self):
        """Return the raw ``PATH_INFO`` of the request.

        An empty string is returned when the server omitted the key, which is
        the same default ``get_path_and_qs`` uses.
        """

        return self.req_env.get('PATH_INFO', '')

    def get_path_and_qs(self):
        """Return the url-quoted ``PATH_INFO``, followed by ``?`` and the
        query string whenever ``QUERY_STRING`` is present in the environment.
        """

        retval = quote(self.req_env.get('PATH_INFO', ''))
        qs = self.req_env.get('QUERY_STRING', None)
        if qs is not None:
            retval += '?' + qs
        return retval

    def get_cookie(self, key):
        """Return the value of the request cookie named ``key``.

        :returns: The cookie value, or ``None`` when the request carries no
            ``Cookie`` header or the header has no cookie with that name.
        """

        cookie_string = self.req_env.get('HTTP_COOKIE', None)
        if cookie_string is None:
            return None

        cookie = SimpleCookie()
        cookie.load(cookie_string)

        # A missing cookie must yield None rather than blow up with an
        # AttributeError on ``None.value``.
        morsel = cookie.get(key, None)
        if morsel is None:
            return None

        return morsel.value

    def get_request_method(self):
        """Return the upper-cased HTTP verb of the request."""

        return self.req['REQUEST_METHOD'].upper()

    def get_request_content_type(self):
        """Return the ``CONTENT_TYPE`` of the request, or ``None`` when it is
        not set."""

        return self.req.get("CONTENT_TYPE", None)

    def get_peer(self):
        """Return the peer of the request as an :class:`Address`.

        :returns: A TCP4 or TCP6 ``Address``, or ``None`` when the ip found
            by ``address_parser`` is neither a valid IPv4 nor a valid IPv6
            address.
        """

        addr = address_parser.get_ip(self.req)
        port = address_parser.get_port(self.req)

        if address_parser.is_valid_ipv4(addr):
            return Address(type=Address.TCP4, host=addr, port=port)

        if address_parser.is_valid_ipv6(addr):
            return Address(type=Address.TCP6, host=addr, port=port)

        return None
222
223
class WsgiMethodContext(HttpMethodContext):
    """Method context flavour used by the WSGI transport.

    Everything that is specific to WSGI lives in the ``transport`` attribute,
    which is an instance of :class:`WsgiTransportContext`.
    """

    # The transport context class to be used for http requests.
    HttpTransportContext = WsgiTransportContext

    # NOTE(review): presumably unset so that only the http-specific class
    # above gets used -- confirm against HttpMethodContext.
    TransportContext = None
231
232
class WsgiApplication(HttpBase):
    """A `PEP-3333 <http://www.python.org/dev/peps/pep-3333>`_
    compliant callable class.

    If you want to have a hard-coded URL in the wsdl document, this is how to do
    it: ::

        wsgi_app = WsgiApplication(...)
        wsgi_app.doc.wsdl11.build_interface_document("http://example.com")

    This is not strictly necessary -- if you don't do this, Spyne will get the
    URL from the first request, build the wsdl on-the-fly and cache it as a
    string in memory for later requests. However, if you want to make sure
    you only have this url on the WSDL, this is how to do it. Note that if
    your client takes the information in the Wsdl document seriously (not all
    do), all requests will go to the designated url above even when you get the
    Wsdl from another location, which can make testing a bit difficult. Use in
    moderation.

    Supported events:
        * ``wsdl``
            Called right before the wsdl data is returned to the client.

        * ``wsdl_exception``
            Called right after an exception is thrown during wsdl generation.
            The exception object is stored in ctx.transport.wsdl_error
            attribute.

        * ``wsgi_call``
            Called first when the incoming http request is identified as a rpc
            request.

        * ``wsgi_return``
            Called right before the output stream is returned to the WSGI
            handler.

        * ``wsgi_exception``
            Called right before returning the exception to the client.

        * ``wsgi_close``
            Called after the whole data has been returned to the client. It's
            called both from success and error cases.
    """

    def __init__(self, app, chunked=True, max_content_length=2 * 1024 * 1024,
                 block_length=8 * 1024):
        """Wrap a spyne application in a WSGI callable.

        All arguments are forwarded unchanged to the parent constructor.

        :param app: The application to expose.
        :param chunked: When true, the response is handed to the WSGI server
            as an iterable; otherwise it is joined into one string first.
        :param max_content_length: Upper limit for the size of a request
            body, in bytes.
        :param block_length: How many bytes to read from ``wsgi.input`` in
            one go.
        """

        super(WsgiApplication, self).__init__(app, chunked, max_content_length,
                                              block_length)

        # Serializes the on-demand generation of the interface document.
        self._mtx_build_interface_document = threading.Lock()

        self._wsdl = None
        if self.doc.wsdl11 is not None:
            self._wsdl = self.doc.wsdl11.get_interface_document()

    def __call__(self, req_env, start_response, wsgi_url=None):
        """This method conforms to the WSGI spec for callable wsgi applications
        (PEP 333). It looks in environ['wsgi.input'] for a fully formed rpc
        message envelope, will deserialize the request parameters and call the
        method on the object returned by the get_handler() method.

        :param req_env: The WSGI environment dict of the request.
        :param start_response: The WSGI ``start_response`` callable.
        :param wsgi_url: Overrides the url that would otherwise be
            reconstructed from ``req_env``. Only used for wsdl requests.
        :returns: An iterable that yields the response body.
        """

        if not self.is_wsdl_request(req_env):
            return self.handle_rpc(req_env, start_response)

        # The url is only needed as the service location inside the wsdl, so
        # it is not reconstructed for plain rpc requests.
        url = wsgi_url
        if url is None:
            url = _reconstruct_url(req_env)

        # Format the url for location
        url = url.split('?')[0].split('.wsdl')[0]
        return self.handle_wsdl_request(req_env, start_response, url)

    def is_wsdl_request(self, req_env):
        """Tell whether the request asks for the wsdl document.

        That is the case for a GET request whose query string starts with a
        ``wsdl`` key or whose path ends with ``.wsdl``.
        """

        # Get the wsdl for the service. Assume path_info matches pattern:
        # /stuff/stuff/stuff/serviceName.wsdl or
        # /stuff/stuff/stuff/serviceName/?wsdl

        if req_env['REQUEST_METHOD'].upper() != 'GET':
            return False

        # PEP 3333 lets servers omit QUERY_STRING and PATH_INFO when they are
        # empty, hence the defaults.
        first_key = req_env.get('QUERY_STRING', '').split('=')[0]
        if first_key.lower() == 'wsdl':
            return True

        return req_env.get('PATH_INFO', '').endswith('.wsdl')

    def handle_wsdl_request(self, req_env, start_response, url):
        """Return the wsdl document, building and caching it on first use.

        :param req_env: The WSGI environment dict of the request.
        :param start_response: The WSGI ``start_response`` callable.
        :param url: The service location to put into a freshly built wsdl.
        :returns: A one-element list: the wsdl document on success, the 404
            status line when wsdl generation is not available, or the 500
            status line when building the document failed.
        """

        ctx = WsgiMethodContext(self, req_env, 'text/xml; charset=utf-8')

        if self.doc.wsdl11 is None:
            start_response(HTTP_404,
                           _gen_http_headers(ctx.transport.resp_headers))
            # Error paths release the context just like the success path.
            ctx.close()
            return [HTTP_404]

        if self._wsdl is None:
            self._wsdl = self.doc.wsdl11.get_interface_document()

        ctx.transport.wsdl = self._wsdl

        if ctx.transport.wsdl is None:
            # ``with`` guarantees the lock is only released when it was
            # actually acquired.
            with self._mtx_build_interface_document:
                try:
                    # Another thread may have built the document while this
                    # one was waiting for the lock.
                    ctx.transport.wsdl = self._wsdl

                    if ctx.transport.wsdl is None:
                        self.doc.wsdl11.build_interface_document(url)
                        ctx.transport.wsdl = self._wsdl = \
                            self.doc.wsdl11.get_interface_document()

                except Exception as e:
                    logger.exception(e)
                    ctx.transport.wsdl_error = e

                    self.event_manager.fire_event('wsdl_exception', ctx)

                    start_response(HTTP_500,
                                   _gen_http_headers(ctx.transport.resp_headers))

                    ctx.close()

                    return [HTTP_500]

        self.event_manager.fire_event('wsdl', ctx)

        ctx.transport.resp_headers['Content-Length'] = \
            str(len(ctx.transport.wsdl))
        start_response(HTTP_200, _gen_http_headers(ctx.transport.resp_headers))

        retval = ctx.transport.wsdl

        ctx.close()

        return [retval]

    def handle_error(self, p_ctx, others, error, start_response):
        """Serialize errors to an iterable of strings and return them.

        :param p_ctx: Primary (non-aux) context.
        :param others: List if auxiliary contexts (can be empty).
        :param error: One of ctx.{in,out}_error.
        :param start_response: See the WSGI spec for more info.
        :returns: An iterable that yields the serialized error and then
            finalizes ``p_ctx``.
        """

        if p_ctx.transport.resp_code is None:
            p_ctx.transport.resp_code = \
                p_ctx.out_protocol.fault_to_http_response_code(error)

        self.get_out_string(p_ctx)

        # consume the generator to get the length
        p_ctx.out_string = list(p_ctx.out_string)

        p_ctx.transport.resp_headers['Content-Length'] = \
            str(sum(len(s) for s in p_ctx.out_string))
        self.event_manager.fire_event('wsgi_exception', p_ctx)

        start_response(p_ctx.transport.resp_code,
                       _gen_http_headers(p_ctx.transport.resp_headers))

        try:
            process_contexts(self, others, p_ctx, error=error)
        except Exception as e:
            # Report but ignore any exceptions from auxiliary methods.
            logger.exception(e)

        return chain(p_ctx.out_string, self.__finalize(p_ctx))

    def handle_rpc(self, req_env, start_response):
        """Process a rpc request and return the serialized response.

        Fires ``wsgi_call`` first and ``wsgi_return`` right before
        ``start_response`` is called. Any input, output or serialization
        error is delegated to :func:`handle_error`.

        :param req_env: The WSGI environment dict of the request.
        :param start_response: The WSGI ``start_response`` callable.
        :returns: An iterable that yields the response body and then
            finalizes the primary context.
        """

        initial_ctx = WsgiMethodContext(self, req_env,
                                        self.app.out_protocol.mime_type)

        self.event_manager.fire_event('wsgi_call', initial_ctx)
        initial_ctx.in_string, in_string_charset = \
            self.__reconstruct_wsgi_request(req_env)

        contexts = self.generate_contexts(initial_ctx, in_string_charset)
        p_ctx, others = contexts[0], contexts[1:]

        # TODO: rate limiting
        p_ctx.active = True

        if p_ctx.in_error:
            return self.handle_error(p_ctx, others, p_ctx.in_error,
                                     start_response)

        self.get_in_object(p_ctx)
        if p_ctx.in_error:
            logger.error(p_ctx.in_error)
            return self.handle_error(p_ctx, others, p_ctx.in_error,
                                     start_response)

        self.get_out_object(p_ctx)
        if p_ctx.out_error:
            return self.handle_error(p_ctx, others, p_ctx.out_error,
                                     start_response)

        assert p_ctx.out_object is not None
        # The default keeps an empty out_object from leaking StopIteration.
        g = next(iter(p_ctx.out_object), None)
        is_generator = len(p_ctx.out_object) == 1 and isgenerator(g)

        # if the out_object is a generator function, this hack makes the user
        # code run until first yield, which lets it set response headers and
        # whatnot before calling start_response. It's important to run this
        # here before serialization as the user function can also set output
        # protocol. Is there a better way?
        if is_generator:
            try:
                first_obj = next(g)
            except StopIteration:
                # The user generator yielded nothing. It is exhausted now, so
                # out_object already holds an empty iterable.
                pass
            else:
                p_ctx.out_object = (chain((first_obj,), g),)

        if p_ctx.transport.resp_code is None:
            p_ctx.transport.resp_code = HTTP_200

        try:
            self.get_out_string(p_ctx)

        except Exception as e:
            logger.exception(e)
            p_ctx.out_error = Fault('Server', get_fault_string_from_exception(e))
            return self.handle_error(p_ctx, others, p_ctx.out_error,
                                     start_response)

        if isinstance(p_ctx.out_protocol, HttpRpc) and \
                p_ctx.out_header_doc is not None:
            p_ctx.transport.resp_headers.update(p_ctx.out_header_doc)

        if p_ctx.descriptor and p_ctx.descriptor.mtom:
            # when there is more than one return type, the result is
            # encapsulated inside a list. when there's just one, the result
            # is returned in a non-encapsulated form. the apply_mtom always
            # expects the objects to be inside an iterable, hence the
            # following test.
            out_type_info = p_ctx.descriptor.out_message._type_info
            if len(out_type_info) == 1:
                p_ctx.out_object = [p_ctx.out_object]

            p_ctx.transport.resp_headers, p_ctx.out_string = apply_mtom(
                p_ctx.transport.resp_headers, p_ctx.out_string,
                out_type_info.values(),
                p_ctx.out_object,
            )

        self.event_manager.fire_event('wsgi_return', p_ctx)

        if self.chunked:
            # the user has not set a content-length, so we delete it as the
            # input is just an iterable.
            if 'Content-Length' in p_ctx.transport.resp_headers:
                del p_ctx.transport.resp_headers['Content-Length']
        else:
            # The chunks are byte strings (their summed lengths are sent as
            # Content-Length), so they need a bytes separator on Python 3.
            p_ctx.out_string = [b''.join(p_ctx.out_string)]

        try:
            # Raises TypeError for lazy iterables, whose size is unknown.
            len(p_ctx.out_string)

            p_ctx.transport.resp_headers['Content-Length'] = \
                str(sum([len(a) for a in p_ctx.out_string]))
        except TypeError:
            pass

        start_response(p_ctx.transport.resp_code,
                       _gen_http_headers(p_ctx.transport.resp_headers))

        retval = chain(p_ctx.out_string, self.__finalize(p_ctx))

        try:
            process_contexts(self, others, p_ctx, error=None)
        except Exception as e:
            # Report but ignore any exceptions from auxiliary methods.
            logger.exception(e)

        return retval

    def __finalize(self, p_ctx):
        """Close ``p_ctx`` and fire ``wsgi_close`` once the response body has
        been consumed.

        This is a generator that yields nothing. Chaining it after the
        response body defers the cleanup until the WSGI server has exhausted
        the body, instead of running it while the response iterable is still
        being put together.

        NOTE(review): when the server abandons the iterable before
        exhausting it, this cleanup does not run.
        """

        p_ctx.close()
        self.event_manager.fire_event('wsgi_close', p_ctx)

        return
        yield  # never reached; its presence makes this a generator function

    def __reconstruct_wsgi_request(self, http_env):
        """Reconstruct http payload using information in the http header.

        :returns: A ``(body_iterable, charset)`` tuple. ``charset`` is
            ``None`` when the request has no content type or the content type
            has no ``charset`` parameter.
        """

        content_type = http_env.get("CONTENT_TYPE")
        charset = None
        if content_type is not None:
            # fyi, here's what the parse_header function returns:
            # >>> import cgi; cgi.parse_header("text/xml; charset=utf-8")
            # ('text/xml', {'charset': 'utf-8'})
            content_type = cgi.parse_header(content_type)
            charset = content_type[1].get('charset', None)

        return self.__wsgi_input_to_iterable(http_env), charset

    def __wsgi_input_to_iterable(self, http_env):
        """Yield the request body from ``wsgi.input`` in blocks of at most
        ``self.block_length`` bytes.

        At most ``CONTENT_LENGTH`` bytes are read. When that header is
        missing, up to ``self.max_content_length`` bytes are read. Reading
        stops early when the stream runs dry.

        As this is a generator, nothing below runs (and nothing is raised)
        before the first block is requested.

        :raises RequestTooLongError: When the announced content length
            exceeds ``self.max_content_length``.
        """

        istream = http_env.get('wsgi.input')

        length = str(http_env.get('CONTENT_LENGTH', self.max_content_length))
        if len(length) == 0:
            length = 0
        else:
            length = int(length)

        if length > self.max_content_length:
            raise RequestTooLongError()

        # No further size check is needed inside the loop: it never reads
        # past ``length``, which is known to be within the limit.
        bytes_read = 0

        while bytes_read < length:
            bytes_to_read = min(self.block_length, length - bytes_read)

            data = istream.read(bytes_to_read)
            if data is None or len(data) == 0:
                break

            bytes_read += len(data)

            yield data

    def decompose_incoming_envelope(self, prot, ctx, message):
        """This function is only called by the HttpRpc protocol to have the wsgi
        environment parsed into ``ctx.in_body_doc`` and ``ctx.in_header_doc``.

        The body document is assembled from the query string, the url pattern
        match (if the application has patterns) and, for POST, PUT and PATCH
        requests, the form fields and uploaded files of the request body.

        :param prot: The calling protocol instance.
        :param ctx: The method context. Its ``in_document`` attribute is the
            WSGI environment.
        :param message: Not used by this implementation.
        """

        params = {}
        wsgi_env = ctx.in_document

        # PEP 3333 lets servers omit these two keys when they are empty.
        path_info = wsgi_env.get('PATH_INFO', '')
        query_string = wsgi_env.get('QUERY_STRING', '')

        if self.has_patterns:
            # http://legacy.python.org/dev/peps/pep-0333/#url-reconstruction
            domain = wsgi_env.get('HTTP_HOST', None)
            if domain is None:
                domain = wsgi_env['SERVER_NAME']
            else:
                domain = domain.partition(':')[0]  # strip port info

            params = self.match_pattern(ctx,
                    wsgi_env.get('REQUEST_METHOD', ''),
                    path_info,
                    domain,
                )

        if ctx.method_request_string is None:
            # Fall back to the last path segment as the method name.
            ctx.method_request_string = '{%s}%s' % (
                prot.app.interface.get_tns(),
                path_info.split('/')[-1])

        logger.debug("%sMethod name: %r%s", LIGHT_GREEN,
                     ctx.method_request_string, END_COLOR)

        ctx.in_header_doc = ctx.transport.headers
        ctx.in_body_doc = _parse_qs(query_string)

        for k, v in params.items():
            if k in ctx.in_body_doc:
                ctx.in_body_doc[k].extend(v)
            else:
                ctx.in_body_doc[k] = list(v)

        verb = wsgi_env['REQUEST_METHOD'].upper()
        if verb in ('POST', 'PUT', 'PATCH'):
            _, form, files = parse_form_data(wsgi_env,
                                             stream_factory=prot.stream_factory)

            for k, v in form.lists():
                val = ctx.in_body_doc.get(k, [])
                val.extend(v)
                ctx.in_body_doc[k] = val

            for k, v in files.items():
                val = ctx.in_body_doc.get(k, [])

                mime_type = v.headers.get('Content-Type',
                                          'application/octet-stream')

                # A stream without a ``name`` attribute is treated as an
                # in-memory buffer; one with a name is passed on by path
                # and handle.
                path = getattr(v.stream, 'name', None)
                if path is None:
                    val.append(File.Value(name=v.filename, type=mime_type,
                                          data=[v.stream.getvalue()]))
                else:
                    v.stream.seek(0)
                    val.append(File.Value(name=v.filename, type=mime_type,
                                          path=path, handle=v.stream))

                ctx.in_body_doc[k] = val

            # A lone empty string stands for "no value".
            for k, v in ctx.in_body_doc.items():
                if v == ['']:
                    ctx.in_body_doc[k] = [None]
628