1import logging
2import sys
3import weakref
4
5import webob
6
7from wsme.exc import ClientSideError, UnknownFunction
8from wsme.protocol import getprotocol
9from wsme.rest import scan_api
10import wsme.api
11import wsme.types
12
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Minimal HTML page skeleton. It is filled in with the ``%`` operator and a
# mapping that provides the ``css`` and ``content`` keys (see
# ``WSRoot._html_format``).
html_body = """
<html>
<head>
  <style type='text/css'>
    %(css)s
  </style>
</head>
<body>
%(content)s
</body>
</html>
"""
27
28
def default_prepare_response_body(request, results):
    """Join the encoded call results into a single response body.

    The results are joined with a newline. The separator type (text or
    bytes) is chosen from the type of the first result, so all results are
    expected to be of that same kind.

    :param request: The current request (unused by this default
                    implementation).
    :param results: An iterable of encoded results (``str`` or bytes-like).
    :returns: The joined body, or ``None`` when ``results`` is empty.
    :raises TypeError: If text and bytes results are mixed.
    """
    # Materialize first: ``results`` may be a one-shot generator, and we
    # need to look at the first element to pick the separator type.
    parts = list(results)
    if not parts:
        return None
    sep = '\n' if isinstance(parts[0], str) else b'\n'
    # A single join is linear, unlike repeated ``+=`` concatenation.
    return sep.join(parts)
44
45
class DummyTransaction:
    """No-op object exposing ``commit`` and ``abort`` methods."""

    def commit(self):
        """Do nothing."""
        return None

    def abort(self):
        """Do nothing."""
        return None
53
class WSRoot(object):
    """
    Root controller for webservices.

    :param protocols: A list of protocols to enable (see :meth:`addprotocol`)
    :param webpath: The web path where the webservice is published.

    :type  transaction: A `transaction
                        <http://pypi.python.org/pypi/transaction>`_-like
                        object or ``True``.
    :param transaction: If specified, a transaction will be created and
                        handled on a per-call base.

                        This option *can* be enabled along with `repoze.tm2
                        <http://pypi.python.org/pypi/repoze.tm2>`_
                        (it will only make it void).

                        If ``True``, the default :mod:`transaction`
                        module will be imported and used.

    """
    __registry__ = wsme.types.registry

    def __init__(self, protocols=(), webpath='', transaction=None,
                 scan_api=scan_api):
        # ``protocols`` is only iterated, so an immutable empty default is
        # equivalent to the former ``[]`` without the shared-mutable pitfall.
        self._debug = True
        self._webpath = webpath
        self.protocols = []
        self._scan_api = scan_api

        self._transaction = transaction
        if self._transaction is True:
            # ``True`` means "use the default ``transaction`` module".
            import transaction
            self._transaction = transaction

        for protocol in protocols:
            self.addprotocol(protocol)

        # Lazily populated by getapi().
        self._api = None

    def wsgiapp(self):
        """Returns a wsgi application"""
        from webob.dec import wsgify
        return wsgify(self._handle_request)

    def begin(self):
        """
        Start a transaction for one call.

        :returns: The result of ``begin()`` on the configured transaction
                  object, or a :class:`DummyTransaction` when none is
                  configured.
        """
        if self._transaction:
            return self._transaction.begin()
        return DummyTransaction()

    def addprotocol(self, protocol, **options):
        """
        Enable a new protocol on the controller.

        :param protocol: A registered protocol name or an instance
                         of a protocol.
        :param options: Passed to :func:`getprotocol` when ``protocol`` is
                        given as a name; ignored otherwise.
        """
        if isinstance(protocol, str):
            protocol = getprotocol(protocol, **options)
        self.protocols.append(protocol)
        # Weak proxy: the protocol must not keep the root alive.
        protocol.root = weakref.proxy(self)

    def getapi(self):
        """
        Returns the api description.

        The api is scanned once, on first use, and cached afterwards.

        :rtype: list of (path, :class:`FunctionDefinition`)
        """
        if self._api is None:
            self._api = [
                (path, f, f._wsme_definition, args)
                for path, f, args in self._scan_api(self)
            ]
            for path, f, fdef, args in self._api:
                fdef.resolve_types(self.__registry__)
        return [
            (path, fdef)
            for path, f, fdef, args in self._api
        ]

    def _get_protocol(self, name):
        """Return the enabled protocol called ``name``, or ``None``."""
        for protocol in self.protocols:
            if protocol.name == name:
                return protocol
        return None

    def _select_protocol(self, request):
        """
        Pick the protocol that will handle ``request``.

        If the request has a ``wsmeproto`` parameter, the protocol with that
        name is returned (``None`` if no such protocol is enabled).
        Otherwise the first enabled protocol whose ``accept(request)`` is
        true is returned.

        :raises ClientSideError: When no protocol accepts the request: the
            last :class:`ClientSideError` raised by an ``accept`` call, or a
            default one with status code 406.
        """
        # Log at most the first 512 bytes of the body.
        if request.content_length:
            body_excerpt = request.body
            if request.content_length > 512:
                body_excerpt = body_excerpt[:512]
            body_excerpt = body_excerpt or ''
        else:
            body_excerpt = ''
        log.debug("Selecting a protocol for the following request :\n"
                  "headers: %s\nbody: %s", request.headers.items(),
                  body_excerpt)

        # NOTE(review): this check is stripped under ``python -O``; kept
        # as-is because callers currently see an AssertionError here.
        assert str(request.path).startswith(self._webpath)

        if 'wsmeproto' in request.params:
            return self._get_protocol(request.params['wsmeproto'])

        error = ClientSideError(status_code=406)
        for p in self.protocols:
            try:
                if p.accept(request):
                    return p
            except ClientSideError as e:
                error = e
        # If we could not select a protocol, we raise the last exception
        # that we got, or the default one.
        raise error

    def _do_call(self, protocol, context):
        """
        Execute one call and return its encoded outcome.

        The call context is appended to ``request.calls``. The target
        function is looked up from ``context.path`` (extracted by the
        protocol if not already set), its arguments are read by the
        protocol, and the function is invoked inside a transaction obtained
        from :meth:`begin`.

        :returns: ``protocol.encode_result(context, result)`` on success,
                  ``protocol.encode_error(context, infos)`` on any exception.
                  The request's client/server error counters are updated in
                  the error case.
        """
        request = context.request
        request.calls.append(context)
        try:
            if context.path is None:
                context.path = protocol.extract_path(context)

            if context.path is None:
                raise ClientSideError(
                    'The %s protocol was unable to extract a function '
                    'path from the request' % protocol.name)

            context.func, context.funcdef, args = \
                self._lookup_function(context.path)
            kw = protocol.read_arguments(context)
            args = list(args)

            txn = self.begin()
            try:
                result = context.func(*args, **kw)
                txn.commit()
            except Exception:
                # Roll back, then let the outer handler encode the error.
                txn.abort()
                raise

            else:
                # TODO make sure result type == a._wsme_definition.return_type
                return protocol.encode_result(context, result)

        except Exception as e:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            if isinstance(e, ClientSideError):
                request.client_errorcount += 1
                request.client_last_status_code = e.code
            else:
                request.server_errorcount += 1
            return protocol.encode_error(context, infos)

    def find_route(self, path):
        """
        Find a protocol-provided route matching ``path``.

        :returns: ``(routepath, func)`` for the first route whose path is a
                  prefix of ``path``, or ``(None, None)`` if there is none.
        """
        for p in self.protocols:
            for routepath, func in p.iter_routes():
                if path.startswith(routepath):
                    return routepath, func
        return None, None

    def _handle_request(self, request):
        """
        Handle a request and build the response.

        Steps, in order:

        1. If a protocol route matches the path, its function's output is
           returned directly.
        2. A protocol is selected; on failure a ``text/plain`` error
           response is returned.
        3. Every call yielded by ``protocol.iter_calls(request)`` is run
           through :meth:`_do_call` and the results are assembled into the
           response body.
        4. The response status and content type are determined.

        :param request: The incoming request object.
        :returns: A :class:`webob.Response`.
        """
        res = webob.Response()
        res_content_type = None

        path = request.path
        if path.startswith(self._webpath):
            path = path[len(self._webpath):]
        routepath, func = self.find_route(path)
        if routepath:
            content = func()
            if isinstance(content, str):
                res.text = content
            elif isinstance(content, bytes):
                res.body = content
            res.content_type = func._cfg['content-type']
            return res

        msg = None
        error_status = 500
        try:
            protocol = self._select_protocol(request)
        except ClientSideError as e:
            error_status = e.code
            msg = e.faultstring
            protocol = None
        except Exception as e:
            msg = ("Unexpected error while selecting protocol: %s" % str(e))
            log.exception(msg)
            protocol = None
            error_status = 500

        if protocol is None:
            if not msg:
                msg = ("None of the following protocols can handle this "
                       "request : %s" % ','.join((
                           p.name for p in self.protocols)))
            res.status = error_status
            res.content_type = 'text/plain'
            res.text = str(msg)
            log.error(msg)
            return res

        # Per-request bookkeeping updated by _do_call().
        request.calls = []
        request.client_errorcount = 0
        request.client_last_status_code = None
        request.server_errorcount = 0

        # The generator expression below has its own ``context`` loop
        # variable, so this outer name stays ``None``; it is what the
        # generic error handler passes to ``encode_error``.
        context = None
        try:
            if hasattr(protocol, 'prepare_response_body'):
                prepare_response_body = protocol.prepare_response_body
            else:
                prepare_response_body = default_prepare_response_body

            body = prepare_response_body(request, (
                self._do_call(protocol, context)
                for context in protocol.iter_calls(request)))

            if isinstance(body, str):
                res.text = body
            else:
                res.body = body

            if len(request.calls) == 1:
                if hasattr(protocol, 'get_response_status'):
                    res.status = protocol.get_response_status(request)
                elif request.client_errorcount == 1:
                    res.status = request.client_last_status_code
                elif request.client_errorcount:
                    res.status = 400
                elif request.server_errorcount:
                    res.status = 500
                else:
                    res.status = 200
            else:
                res.status = protocol.get_response_status(request)
                res_content_type = protocol.get_response_contenttype(request)
        except ClientSideError as e:
            request.server_errorcount += 1
            res.status = e.code
            res.text = e.faultstring
        except Exception:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            request.server_errorcount += 1
            res.text = protocol.encode_error(context, infos)
            res.status = 500

        if res_content_type is None:
            # Attempt to correctly guess what content-type we should return.
            ctypes = [ct for ct in protocol.content_types if ct]
            if ctypes:
                try:
                    offers = request.accept.acceptable_offers(ctypes)
                    res_content_type = offers[0][0]
                except IndexError:
                    # No acceptable offer.
                    res_content_type = None

        # If not we will attempt to convert the body to an accepted
        # output format.
        if res_content_type is None:
            if "text/html" in request.accept:
                res.text = self._html_format(res.body, protocol.content_types)
                res_content_type = "text/html"

        # TODO should we consider the encoding asked by
        # the web browser ?
        # NOTE(review): when no content type could be determined above,
        # this emits the literal "None; charset=UTF-8" -- confirm the
        # intended fallback before changing it.
        res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type

        return res

    def _lookup_function(self, path):
        """
        Find the exposed function registered under ``path``.

        :returns: A ``(function, function definition, args)`` tuple.
        :raises UnknownFunction: If no api entry has this path.
        """
        if not self._api:
            self.getapi()

        for fpath, f, fdef, args in self._api:
            if path == fpath:
                return f, fdef, args
        raise UnknownFunction('/'.join(path))

    def _html_format(self, content, content_types):
        """
        Render ``content`` as an HTML page.

        Syntax highlighting with pygments is attempted, using the first of
        ``content_types`` for which a lexer can be found. If pygments is
        missing, no lexer is found, or highlighting fails, the content is
        HTML-escaped and wrapped in a ``<pre>`` element instead.

        :param content: The body to render, as ``bytes`` or ``str``.
        :param content_types: Mimetypes to try when looking for a lexer.
        :returns: The HTML page as ``str``.
        """
        from html import escape

        # Work in text throughout: interpolating bytes into the ``str``
        # template would render their ``b'...'`` repr into the page.
        if isinstance(content, bytes):
            text = content.decode('utf-8', 'replace')
        else:
            text = content

        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_for_mimetype
            from pygments.formatters import HtmlFormatter

            lexer = None
            for ct in content_types:
                try:
                    lexer = get_lexer_for_mimetype(ct)
                    break
                except Exception:
                    # Best effort: try the next content type.
                    continue

            if lexer is None:
                raise ValueError("No lexer found")
            formatter = HtmlFormatter()
            return html_body % dict(
                css=formatter.get_style_defs(),
                content=highlight(text, lexer, formatter))
        except Exception as e:
            # Deliberately broad: highlighting is optional, so any failure
            # falls back to plain escaped output.
            log.warning(
                "Could not pygment the content because of the following "
                "error :\n%s", e)
            return html_body % dict(
                css='',
                content='<pre>%s</pre>' % escape(text, quote=False))
368