import html
import logging
import sys
import weakref

import webob

from wsme.exc import ClientSideError, UnknownFunction
from wsme.protocol import getprotocol
from wsme.rest import scan_api
import wsme.api
import wsme.types

log = logging.getLogger(__name__)

# Page template used by WSRoot._html_format when a client accepts text/html
# and none of the protocol's content types could be negotiated.
html_body = """
<html>
<head>
  <style type='text/css'>
    %(css)s
  </style>
</head>
<body>
%(content)s
</body>
</html>
"""


def default_prepare_response_body(request, results):
    """Join the encoded results of every call into one response body.

    Used by :meth:`WSRoot._handle_request` when the selected protocol does
    not provide its own ``prepare_response_body``.

    :param request: The incoming request. It is not used here, but is kept
        so the signature matches protocol-specific implementations.
    :param results: An iterable of encoded call results. All of them are
        expected to be ``str``, or all of them ``bytes``.
    :return: The results separated by newlines, with the same type as the
        first result. Returns ``b''`` when *results* is empty.
    """
    results = list(results)
    if not results:
        # _handle_request assigns any non-str body to ``res.body``, so an
        # empty body must be bytes rather than None.
        return b''
    sep = '\n' if isinstance(results[0], str) else b'\n'
    return sep.join(results)


class DummyTransaction:
    """No-op stand-in returned by :meth:`WSRoot.begin` when no transaction
    object was configured."""

    def commit(self):
        pass

    def abort(self):
        pass


class WSRoot(object):
    """
    Root controller for webservices.

    :param protocols: A list of protocols to enable (see :meth:`addprotocol`)
    :param webpath: The web path where the webservice is published.

    :type transaction: A `transaction
                       <http://pypi.python.org/pypi/transaction>`_-like
                       object or ``True``.
    :param transaction: If specified, a transaction will be created and
                        handled on a per-call base.

                        This option *can* be enabled along with `repoze.tm2
                        <http://pypi.python.org/pypi/repoze.tm2>`_
                        (it will only make it void).

                        If ``True``, the default :mod:`transaction`
                        module will be imported and used.

    :param scan_api: Callable invoked as ``scan_api(root)`` by
                     :meth:`getapi`. It must yield ``(path, function, args)``
                     triples. Defaults to :func:`wsme.rest.scan_api`.
    """
    __registry__ = wsme.types.registry

    def __init__(self, protocols=(), webpath='', transaction=None,
                 scan_api=scan_api):
        # Passed to wsme.api.format_exception when encoding errors.
        # NOTE(review): presumably this controls whether tracebacks are
        # exposed to clients -- confirm before enabling in production.
        self._debug = True
        self._webpath = webpath
        self.protocols = []
        self._scan_api = scan_api

        self._transaction = transaction
        if self._transaction is True:
            import transaction
            self._transaction = transaction

        for protocol in protocols:
            self.addprotocol(protocol)

        # Lazily built by getapi(): list of (path, func, fdef, args).
        self._api = None

    def wsgiapp(self):
        """Returns a wsgi application"""
        from webob.dec import wsgify
        return wsgify(self._handle_request)

    def begin(self):
        """Start a transaction for one call.

        :return: ``self._transaction.begin()`` if a transaction object was
                 configured, otherwise a :class:`DummyTransaction`.
        """
        if self._transaction:
            return self._transaction.begin()
        else:
            return DummyTransaction()

    def addprotocol(self, protocol, **options):
        """
        Enable a new protocol on the controller.

        :param protocol: A registered protocol name or an instance
                         of a protocol.
        :param options: Passed to :func:`getprotocol` when *protocol* is
                        given by name.
        """
        if isinstance(protocol, str):
            protocol = getprotocol(protocol, **options)
        self.protocols.append(protocol)
        # A weak proxy avoids a strong root <-> protocol reference cycle.
        protocol.root = weakref.proxy(self)

    def getapi(self):
        """
        Returns the api description.

        The API is scanned once and cached. The types of every function
        definition are resolved against :attr:`__registry__` on each call.

        :rtype: list of (path, :class:`FunctionDefinition`)
        """
        if self._api is None:
            self._api = [
                (path, f, f._wsme_definition, args)
                for path, f, args in self._scan_api(self)
            ]
        for path, f, fdef, args in self._api:
            fdef.resolve_types(self.__registry__)
        return [
            (path, fdef)
            for path, f, fdef, args in self._api
        ]

    def _get_protocol(self, name):
        """Return the enabled protocol called *name*, or None."""
        for protocol in self.protocols:
            if protocol.name == name:
                return protocol

    def _select_protocol(self, request):
        """Pick the protocol that will handle *request*.

        An explicit ``wsmeproto`` request parameter wins. It may yield None
        if no enabled protocol has that name. Otherwise the first protocol
        whose ``accept(request)`` is true is returned.

        :raises ClientSideError: the last error raised by a protocol's
            ``accept``, or a 406 error if no protocol accepted the request.
        """
        if log.isEnabledFor(logging.DEBUG):
            # Only read the request body when it will actually be logged.
            if request.content_length:
                body_preview = request.body[:512]
            else:
                body_preview = ''
            log.debug("Selecting a protocol for the following request :\n"
                      "headers: %s\nbody: %s", request.headers.items(),
                      body_preview)
        # Only the prefix matters here. The function path itself is
        # extracted later by the selected protocol (see _do_call).
        assert str(request.path).startswith(self._webpath)

        if 'wsmeproto' in request.params:
            return self._get_protocol(request.params['wsmeproto'])

        error = ClientSideError(status_code=406)
        for p in self.protocols:
            try:
                if p.accept(request):
                    return p
            except ClientSideError as e:
                error = e
        # If we could not select a protocol, we raise the last exception
        # that we got, or the default one.
        raise error

    def _do_call(self, protocol, context):
        """Execute a single call described by *context*.

        The steps are:

        1. Resolve the function path (via ``protocol.extract_path`` if
           needed) and look up the function.
        2. Read its arguments.
        3. Run it inside a transaction obtained from :meth:`begin`.

        :return: ``protocol.encode_result(...)`` on success. On any
                 exception, ``protocol.encode_error(...)`` is returned and
                 the request's client/server error counters are updated.
        """
        request = context.request
        request.calls.append(context)
        try:
            if context.path is None:
                context.path = protocol.extract_path(context)

            if context.path is None:
                raise ClientSideError(
                    'The %s protocol was unable to extract a function '
                    'path from the request' % protocol.name)

            context.func, context.funcdef, args = \
                self._lookup_function(context.path)
            kw = protocol.read_arguments(context)
            args = list(args)

            txn = self.begin()
            try:
                result = context.func(*args, **kw)
                txn.commit()
            except Exception:
                txn.abort()
                raise

            else:
                # TODO make sure result type == a._wsme_definition.return_type
                return protocol.encode_result(context, result)

        except Exception as e:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            if isinstance(e, ClientSideError):
                request.client_errorcount += 1
                request.client_last_status_code = e.code
            else:
                request.server_errorcount += 1
            return protocol.encode_error(context, infos)

    def find_route(self, path):
        """Return the first ``(routepath, func)`` whose routepath is a prefix
        of *path*, scanning each protocol's ``iter_routes()``.

        :return: ``(None, None)`` when no route matches.
        """
        for p in self.protocols:
            for routepath, func in p.iter_routes():
                if path.startswith(routepath):
                    return routepath, func
        return None, None

    def _handle_request(self, request):
        """WSGI entry point (wrapped by :meth:`wsgiapp`).

        Serves protocol-defined routes directly. Otherwise it selects a
        protocol, runs every call the protocol extracts from the request,
        and builds the response status, body and Content-Type.
        """
        res = webob.Response()
        res_content_type = None

        path = request.path
        if path.startswith(self._webpath):
            path = path[len(self._webpath):]
        routepath, func = self.find_route(path)
        if routepath:
            content = func()
            if isinstance(content, str):
                res.text = content
            elif isinstance(content, bytes):
                res.body = content
            res.content_type = func._cfg['content-type']
            return res

        msg = None
        error_status = 500
        protocol = None
        try:
            protocol = self._select_protocol(request)
        except ClientSideError as e:
            error_status = e.code
            msg = e.faultstring
        except Exception as e:
            msg = ("Unexpected error while selecting protocol: %s" % str(e))
            log.exception(msg)

        if protocol is None:
            if not msg:
                msg = ("None of the following protocols can handle this "
                       "request : %s" % ','.join((
                           p.name for p in self.protocols)))
            res.status = error_status
            res.content_type = 'text/plain'
            res.text = str(msg)
            log.error(msg)
            return res

        # Per-request bookkeeping, updated by _do_call for every call.
        request.calls = []
        request.client_errorcount = 0
        request.client_last_status_code = None
        request.server_errorcount = 0

        # Passed to encode_error if building the response itself fails.
        # It is never rebound: the generator below uses its own loop name.
        context = None
        try:
            if hasattr(protocol, 'prepare_response_body'):
                prepare_response_body = protocol.prepare_response_body
            else:
                prepare_response_body = default_prepare_response_body

            body = prepare_response_body(request, (
                self._do_call(protocol, call_context)
                for call_context in protocol.iter_calls(request)))

            if isinstance(body, str):
                res.text = body
            else:
                res.body = body

            if len(request.calls) == 1:
                if hasattr(protocol, 'get_response_status'):
                    res.status = protocol.get_response_status(request)
                else:
                    if request.client_errorcount == 1:
                        res.status = request.client_last_status_code
                    elif request.client_errorcount:
                        res.status = 400
                    elif request.server_errorcount:
                        res.status = 500
                    else:
                        res.status = 200
            else:
                res.status = protocol.get_response_status(request)
            res_content_type = protocol.get_response_contenttype(request)
        except ClientSideError as e:
            # Count it as a client error, consistent with _do_call.
            request.client_errorcount += 1
            request.client_last_status_code = e.code
            res.status = e.code
            res.text = e.faultstring
        except Exception:
            infos = wsme.api.format_exception(sys.exc_info(), self._debug)
            request.server_errorcount += 1
            res.text = protocol.encode_error(context, infos)
            res.status = 500

        if res_content_type is None:
            # Attempt to correctly guess what content-type we should return.
            ctypes = [ct for ct in protocol.content_types if ct]
            if ctypes:
                offers = request.accept.acceptable_offers(ctypes)
                if offers:
                    res_content_type = offers[0][0]

        # If not we will attempt to convert the body to an accepted
        # output format.
        if res_content_type is None:
            if "text/html" in request.accept:
                res.text = self._html_format(res.body, protocol.content_types)
                res_content_type = "text/html"

        # TODO should we consider the encoding asked by
        # the web browser ?
        # NOTE(review): if negotiation still failed, this emits the literal
        # header "None; charset=UTF-8".
        res.headers['Content-Type'] = "%s; charset=UTF-8" % res_content_type

        return res

    def _lookup_function(self, path):
        """Return ``(func, funcdef, args)`` registered for *path*.

        :raises UnknownFunction: if no API entry matches *path*.
        """
        if self._api is None:
            self.getapi()

        for fpath, f, fdef, args in self._api:
            if path == fpath:
                return f, fdef, args
        raise UnknownFunction('/'.join(path))

    def _html_format(self, content, content_types):
        """Render a response body as an HTML page.

        The body is syntax-highlighted with pygments when a lexer matches
        one of *content_types*. Otherwise it is HTML-escaped inside a
        ``<pre>`` element.

        :param content: The response body. ``bytes`` are decoded as UTF-8,
                        with undecodable bytes replaced.
        :param content_types: Candidate MIME types used to pick a lexer.
        :return: The HTML page as ``str``.
        """
        if isinstance(content, bytes):
            text = content.decode('utf-8', 'replace')
        else:
            text = content
        try:
            from pygments import highlight
            from pygments.lexers import get_lexer_for_mimetype
            from pygments.formatters import HtmlFormatter

            lexer = None
            for ct in content_types:
                try:
                    lexer = get_lexer_for_mimetype(ct)
                    break
                except Exception:
                    pass

            if lexer is None:
                raise ValueError("No lexer found")
            formatter = HtmlFormatter()
            # highlight() already returns str. Encoding it here would embed
            # a b'...' repr in the page.
            return html_body % dict(
                css=formatter.get_style_defs(),
                content=highlight(text, lexer, formatter))
        except Exception as e:
            log.warning(
                "Could not pygment the content because of the following "
                "error :\n%s", e)
            # Escape &, < and > so the body is displayed, not interpreted.
            return html_body % dict(
                css='',
                content='<pre>%s</pre>' % html.escape(text, quote=False))