import collections
import functools
import operator
try:
    import simplejson as json
except ImportError:
    import json

from flask import Blueprint
from flask import Response
from flask import abort
from flask import g
from flask import redirect
from flask import request
from flask import session
from flask import url_for
from peewee import *
from peewee import DJANGO_MAP

from flask_peewee.filters import make_field_tree
from flask_peewee.serializer import Deserializer
from flask_peewee.serializer import Serializer
from flask_peewee.utils import PaginatedQuery
from flask_peewee.utils import get_object_or_404
from flask_peewee.utils import slugify
from flask_peewee._compat import reduce


class Authentication(object):
    """
    Base authentication policy: requests whose HTTP method is listed in
    ``protected_methods`` are rejected, all other requests are allowed.

    Subclasses override ``authorize()`` to admit protected methods when the
    request carries valid credentials.
    """
    def __init__(self, protected_methods=None):
        # By default only the data-modifying methods are protected.
        if protected_methods is None:
            protected_methods = ['POST', 'PUT', 'DELETE']

        self.protected_methods = protected_methods

    def authorize(self):
        """Return True when the current request's method is not protected."""
        return request.method not in self.protected_methods


class APIKeyAuthentication(Authentication):
    """
    Requires a model that has at least two fields, "key" and "secret", which will
    be searched for when authing a request.

    The request must supply parameters literally named ``key`` and ``secret``
    (in the query string, headers, or form data); ``key_field`` and
    ``secret_field`` only name the model fields they are matched against.
    """
    key_field = 'key'
    secret_field = 'secret'

    def __init__(self, model, protected_methods=None):
        super(APIKeyAuthentication, self).__init__(protected_methods)
        self.model = model
        self._key_field = model._meta.fields[self.key_field]
        self._secret_field = model._meta.fields[self.secret_field]

    def get_query(self):
        """Return the base query used to look up API keys."""
        return self.model.select()

    def get_key(self, k, s):
        """Return the key model instance matching key ``k`` and secret ``s``, or None."""
        try:
            return self.get_query().where(
                self._key_field == k,
                self._secret_field == s
            ).get()
        except self.model.DoesNotExist:
            return None

    def get_key_secret(self):
        """
        Return the ``(key, secret)`` pair from the first of query args, headers,
        or form data that contains both; ``(None, None)`` if none does.
        """
        for search in [request.args, request.headers, request.form]:
            if 'key' in search and 'secret' in search:
                return search['key'], search['secret']
        return None, None

    def authorize(self):
        """
        Authorize the request, storing the matched key object on ``g.api_key``.

        Unprotected methods always return True. For protected methods the
        return value is the matching key model instance, or None when no
        credentials were supplied or they did not match.
        """
        g.api_key = None

        if request.method not in self.protected_methods:
            return True

        key, secret = self.get_key_secret()
        if key or secret:
            g.api_key = self.get_key(key, secret)

        return g.api_key


class UserAuthentication(Authentication):
    """
    Authenticates protected methods with HTTP Basic credentials, checked via
    ``auth.authenticate(username, password)``; the result is stored on
    ``g.user``.
    """
    def __init__(self, auth, protected_methods=None):
        super(UserAuthentication, self).__init__(protected_methods)
        self.auth = auth

    def authorize(self):
        """
        Return True for unprotected methods, False when no Basic credentials
        were sent, otherwise whatever ``auth.authenticate`` returned.
        """
        g.user = None

        if request.method not in self.protected_methods:
            return True

        basic_auth = request.authorization
        if not basic_auth:
            return False

        g.user = self.auth.authenticate(basic_auth.username, basic_auth.password)
        return g.user


class AdminAuthentication(UserAuthentication):
    """Like UserAuthentication, but the authenticated user must also pass ``verify_user``."""
    def verify_user(self, user):
        """Return the user's ``admin`` attribute."""
        return user.admin

    def authorize(self):
        res = super(AdminAuthentication, self).authorize()

        # g.user is only set for protected methods; unprotected ones pass through.
        if res and g.user:
            return self.verify_user(g.user)
        return res


class RestResource(object):
    """
    Exposes a peewee model as a JSON REST resource: list/create at ``/`` and
    detail/edit/delete at ``/<pk>/`` and ``/<pk>/delete/``.
    """
    paginate_by = 20
    # Query-string values converted to Python constants before filtering.
    value_transforms = {'False': False, 'false': False,
                        'True': True, 'true': True,
                        'None': None, 'none': None}

    # serializing: dictionary of model -> field names to restrict output
    fields = None
    exclude = None

    # exclude certain fields from being exposed as filters -- for related fields
    # use "__" notation, e.g. user__password
    filter_exclude = None
    filter_fields = None
    filter_recursive = True

    # mapping of field name to resource class
    include_resources = None

    # delete behavior
    delete_recursive = True

    def __init__(self, rest_api, model, authentication, allowed_methods=None):
        self.api = rest_api
        self.model = model
        self.pk = model._meta.primary_key

        self.authentication = authentication
        self.allowed_methods = allowed_methods or ['GET', 'POST', 'PUT', 'DELETE']

        self._fields = {self.model: self.fields or self.model._meta.sorted_field_names}
        if self.exclude:
            self._exclude = {self.model: self.exclude}
        else:
            self._exclude = {}

        # Copy into fresh lists: the include_resources loop below extends
        # these, and without a copy it would mutate the class-level
        # ``filter_fields`` / ``filter_exclude`` lists, accumulating duplicate
        # entries every time the resource is instantiated.
        self._filter_fields = list(self.filter_fields or self.model._meta.sorted_field_names)
        self._filter_exclude = list(self.filter_exclude or [])

        self._resources = {}

        # recurse into nested resources
        if self.include_resources:
            for field_name, resource in self.include_resources.items():
                field_obj = self.model._meta.fields[field_name]
                resource_obj = resource(self.api, field_obj.rel_model, self.authentication, self.allowed_methods)
                self._resources[field_name] = resource_obj
                self._fields.update(resource_obj._fields)
                self._exclude.update(resource_obj._exclude)

                # Nested filters are addressed with "field__subfield" names.
                self._filter_fields.extend(['%s__%s' % (field_name, ff) for ff in resource_obj._filter_fields])
                self._filter_exclude.extend(['%s__%s' % (field_name, ff) for ff in resource_obj._filter_exclude])

            self._include_foreign_keys = False
        else:
            self._include_foreign_keys = True

        self._field_tree = make_field_tree(self.model, self._filter_fields, self._filter_exclude, self.filter_recursive)

    def authorize(self):
        """Delegate to the configured authentication object."""
        return self.authentication.authorize()

    def get_api_name(self):
        """Return the URL-safe name of this resource (slugified model name)."""
        return slugify(self.model.__name__)

    def get_url_name(self, name):
        """Return the blueprint endpoint name for view ``name`` of this resource."""
        return '%s.%s_%s' % (
            self.api.blueprint.name,
            self.get_api_name(),
            name,
        )

    def get_query(self):
        """Return the base query for this resource."""
        return self.model.select()

    def process_query(self, query):
        """
        Apply filters from the query string to ``query``.

        Parameters look like ``field``, ``field__op`` or ``rel__field__op``,
        where ``op`` is a key of peewee's DJANGO_MAP (defaulting to ``eq``);
        a leading ``-`` negates the filter. Only fields reachable in the
        filter field tree are applied.
        """
        raw_filters = {}

        # clean and normalize the request parameters
        for key in request.args:
            orig_key = key
            if key.startswith('-'):
                negated = True
                key = key[1:]
            else:
                negated = False
            if '__' in key:
                expr, op = key.rsplit('__', 1)
                if op not in DJANGO_MAP:
                    # The trailing part is not an operator, so it belongs to the field path.
                    expr = key
                    op = 'eq'
            else:
                expr = key
                op = 'eq'
            raw_filters.setdefault(expr, []).append(
                (op, request.args.getlist(orig_key), negated))

        # do a breadth first search across the field tree created by filter_fields,
        # searching for matching keys in the request parameters -- when found,
        # filter the query accordingly
        queue = collections.deque([(self._field_tree, '')])
        while queue:
            node, prefix = queue.popleft()
            for field in node.fields:
                filter_expr = '%s%s' % (prefix, field.name)
                if filter_expr in raw_filters:
                    for op, arg_list, negated in raw_filters[filter_expr]:
                        clean_args = self.clean_arg_list(arg_list)
                        query = self.apply_filter(query, filter_expr, op, clean_args, negated)

            for child_prefix, child_node in node.children.items():
                queue.append((child_node, prefix + child_prefix + '__'))

        return query

    def clean_arg_list(self, arg_list):
        """Map string values through ``value_transforms`` (e.g. 'true' -> True)."""
        return [self.value_transforms.get(arg, arg) for arg in arg_list]

    def apply_filter(self, query, expr, op, arg_list, negated):
        """
        Filter ``query`` on ``expr`` with operator ``op``.

        Multiple values for the same parameter are OR-ed together; ``negated``
        wraps each clause in a NOT.
        """
        query_expr = '%s__%s' % (expr, op)

        def constructor(kwargs):
            # Explicit conditional instead of ``negated and ~x or x``, which
            # would silently drop the negation if ``~x`` were falsy.
            node = DQ(**kwargs)
            return ~node if negated else node

        if op == 'in':
            # in gives us a string format list '1,2,3,4' -- turn it into a
            # list before passing it to the filter.
            arg_list = [i.strip() for i in arg_list[0].split(',')]
            return query.filter(constructor({query_expr: arg_list}))
        elif len(arg_list) == 1:
            return query.filter(constructor({query_expr: arg_list[0]}))
        else:
            query_clauses = [
                constructor({query_expr: val}) for val in arg_list]
            return query.filter(reduce(operator.or_, query_clauses))

    def get_serializer(self):
        return Serializer()

    def get_deserializer(self):
        return Deserializer()

    def prepare_data(self, obj, data):
        """
        Hook for modifying outgoing data
        """
        return data

    def serialize_object(self, obj):
        """Serialize a single object, passing the result through ``prepare_data``."""
        s = self.get_serializer()
        return self.prepare_data(
            obj, s.serialize_object(obj, self._fields, self._exclude)
        )

    def serialize_query(self, query):
        """Serialize every object in ``query`` with one shared serializer."""
        s = self.get_serializer()
        return [
            self.prepare_data(obj, s.serialize_object(obj, self._fields, self._exclude))
            for obj in query
        ]

    def deserialize_object(self, data, instance):
        """Apply ``data`` to ``instance`` via the deserializer and return its result."""
        d = self.get_deserializer()
        return d.deserialize_object(instance, data)

    def response_forbidden(self):
        return Response('Forbidden', 403)

    def response_bad_method(self):
        return Response('Unsupported method "%s"' % (request.method), 405)

    def response_bad_request(self):
        return Response('Bad request', 400)

    def response(self, data):
        """Return ``data`` encoded as an application/json response."""
        return Response(json.dumps(data), mimetype='application/json')

    def require_method(self, func, methods):
        """Wrap ``func`` so requests using a method outside ``methods`` get a 405."""
        @functools.wraps(func)
        def inner(*args, **kwargs):
            if request.method not in methods:
                return self.response_bad_method()
            return func(*args, **kwargs)
        return inner

    def get_urls(self):
        """Return ``(url, view)`` pairs; view ``__name__`` becomes part of the endpoint."""
        return (
            ('/', self.require_method(self.api_list, ['GET', 'POST'])),
            ('/<pk>/', self.require_method(self.api_detail, ['GET', 'POST', 'PUT', 'DELETE'])),
            ('/<pk>/delete/', self.require_method(self.post_delete, ['POST', 'DELETE'])),
        )

    # Per-method permission hooks; returning a falsy value yields a 403.
    def check_get(self, obj=None):
        return True

    def check_post(self, obj=None):
        return True

    def check_put(self, obj):
        return True

    def check_delete(self, obj):
        return True

    def save_object(self, instance, raw_data):
        """Persist ``instance`` and return it; ``raw_data`` is the decoded request body."""
        instance.save()
        return instance

    def api_list(self):
        """List view: GET lists objects, POST creates one."""
        if not getattr(self, 'check_%s' % request.method.lower())():
            return self.response_forbidden()

        if request.method == 'GET':
            return self.object_list()
        elif request.method == 'POST':
            return self.create()

    def api_detail(self, pk, method=None):
        """
        Detail view for the object with primary key ``pk`` (404 if missing).

        ``method`` overrides ``request.method`` (used by ``post_delete``).
        """
        obj = get_object_or_404(self.get_query(), self.pk == pk)

        method = method or request.method

        if not getattr(self, 'check_%s' % method.lower())(obj):
            return self.response_forbidden()

        if method == 'GET':
            return self.object_detail(obj)
        elif method in ('PUT', 'POST'):
            return self.edit(obj)
        elif method == 'DELETE':
            return self.delete(obj)

    def post_delete(self, pk):
        """Delete via ``/<pk>/delete/`` so clients that can only POST can delete."""
        return self.api_detail(pk, 'DELETE')

    def apply_ordering(self, query):
        """Order by the ``ordering`` query arg (``-`` prefix for descending) if it names a model field."""
        ordering = request.args.get('ordering') or ''
        if ordering:
            desc, column = ordering.startswith('-'), ordering.lstrip('-')
            if column in self.model._meta.fields:
                field = self.model._meta.fields[column]
                query = query.order_by(field.asc() if not desc else field.desc())

        return query

    def get_request_metadata(self, paginated_query):
        """Return the ``meta`` dict: model name, current page, and previous/next page URLs."""
        var = paginated_query.page_var
        request_arguments = request.args.copy()

        current_page = paginated_query.get_page()
        next_url = previous_url = ''

        if current_page > 1:
            request_arguments[var] = current_page - 1
            previous_url = url_for(self.get_url_name('api_list'), **request_arguments)
        if current_page < paginated_query.get_pages():
            request_arguments[var] = current_page + 1
            next_url = url_for(self.get_url_name('api_list'), **request_arguments)

        return {
            'model': self.get_api_name(),
            'page': current_page,
            'previous': previous_url,
            'next': next_url,
        }

    def get_paginate_by(self):
        """
        Return the page size: the ``limit`` query arg, capped at
        ``paginate_by`` when that is set.

        Missing, unparseable or non-positive limits fall back to
        ``paginate_by``; a non-positive limit would otherwise slip under the
        cap (``min(-1, 20) == -1``).
        """
        try:
            paginate_by = int(request.args.get('limit', self.paginate_by))
        except (TypeError, ValueError):
            return self.paginate_by

        if paginate_by < 1:
            return self.paginate_by
        if self.paginate_by:
            paginate_by = min(paginate_by, self.paginate_by)  # restrict
        return paginate_by

    def paginated_object_list(self, filtered_query):
        """Return one page of ``filtered_query`` as ``{'meta': ..., 'objects': [...]}``."""
        paginate_by = self.get_paginate_by()
        pq = PaginatedQuery(filtered_query, paginate_by)
        meta_data = self.get_request_metadata(pq)

        query_dict = self.serialize_query(pq.get_list())

        return self.response({
            'meta': meta_data,
            'objects': query_dict,
        })

    def object_list(self):
        """Return the ordered, filtered list, paginated when paging is enabled or requested."""
        query = self.get_query()
        query = self.apply_ordering(query)

        # process any filters
        query = self.process_query(query)

        if self.paginate_by or 'limit' in request.args:
            return self.paginated_object_list(query)

        return self.response(self.serialize_query(query))

    def object_detail(self, obj):
        return self.response(self.serialize_object(obj))

    def save_related_objects(self, instance, data):
        """
        For each included resource whose key in ``data`` holds a dict,
        deserialize and save the related object, then assign it to ``instance``.
        """
        for k, v in data.items():
            if k in self._resources and isinstance(v, dict):
                rel_resource = self._resources[k]
                rel_obj, _ = rel_resource.deserialize_object(v, getattr(instance, k))
                rel_resource.save_related_objects(rel_obj, v)
                setattr(instance, k, rel_resource.save_object(rel_obj, v))

    def read_request_data(self):
        """
        Decode the request payload: a JSON body, else a JSON ``data`` form
        field, else the raw form.

        Raises ValueError on malformed JSON.
        """
        if request.data:
            return json.loads(request.data.decode('utf-8'))
        elif request.form.get('data'):
            return json.loads(request.form['data'])
        else:
            # NOTE(review): depending on the Werkzeug version, dict() of a
            # MultiDict may yield list values -- confirm the deserializer
            # copes, or consider request.form.to_dict().
            return dict(request.form)

    def _update_from_request(self, instance):
        """Shared body of create/edit: apply request data to ``instance``, save, respond."""
        try:
            data = self.read_request_data()
        except ValueError:
            return self.response_bad_request()

        obj, _ = self.deserialize_object(data, instance)

        self.save_related_objects(obj, data)
        obj = self.save_object(obj, data)

        return self.response(self.serialize_object(obj))

    def create(self):
        """Create a new object from the request data."""
        return self._update_from_request(self.model())

    def edit(self, obj):
        """Update ``obj`` from the request data."""
        return self._update_from_request(obj)

    def delete(self, obj):
        """Delete ``obj`` (recursively if ``delete_recursive``) and report the result."""
        res = obj.delete_instance(recursive=self.delete_recursive)
        return self.response({'deleted': res})


class RestrictOwnerResource(RestResource):
    # restrict PUT/DELETE to owner of an object, likewise apply owner to any
    # incoming POSTs
    owner_field = 'user'

    def validate_owner(self, user, obj):
        return user == getattr(obj, self.owner_field)

    def set_owner(self, obj, user):
        setattr(obj, self.owner_field, user)

    def check_put(self, obj):
        return self.validate_owner(g.user, obj)

    def check_delete(self, obj):
        return self.validate_owner(g.user, obj)

    def save_object(self, instance, raw_data):
        # Every save (create or edit) is stamped with the requesting user.
        self.set_owner(instance, g.user)
        return super(RestrictOwnerResource, self).save_object(instance, raw_data)


class RestAPI(object):
    """
    Registry of model -> resource providers, exposed as routes on a Flask
    blueprint mounted at ``prefix``.
    """
    def __init__(self, app, prefix='/api', default_auth=None, name='api'):
        self.app = app

        self._registry = {}

        self.url_prefix = prefix
        self.blueprint = self.get_blueprint(name)

        self.default_auth = default_auth or Authentication()

    def register(self, model, provider=RestResource, auth=None, allowed_methods=None):
        """Expose ``model`` via ``provider``, using ``default_auth`` when ``auth`` is omitted."""
        self._registry[model] = provider(self, model, auth or self.default_auth, allowed_methods)

    def unregister(self, model):
        del self._registry[model]

    def is_registered(self, model):
        """Return the registered provider for ``model`` (truthy), or None."""
        return self._registry.get(model)

    def response_auth_failed(self):
        return Response('Authentication failed', 401, {
            'WWW-Authenticate': 'Basic realm="Login Required"'
        })

    def auth_wrapper(self, func, provider):
        """Wrap ``func`` so it only runs when ``provider.authorize()`` is truthy; otherwise 401."""
        @functools.wraps(func)
        def inner(*args, **kwargs):
            if not provider.authorize():
                return self.response_auth_failed()
            return func(*args, **kwargs)
        return inner

    def get_blueprint(self, blueprint_name):
        return Blueprint(blueprint_name, __name__)

    def get_urls(self):
        """Hook for API-level (non-resource) routes as ``(url, view)`` pairs."""
        return ()

    def configure_routes(self):
        """Add API-level routes and every registered resource's routes to the blueprint."""
        for url, callback in self.get_urls():
            self.blueprint.route(url)(callback)

        for provider in self._registry.values():
            api_name = provider.get_api_name()
            for url, callback in provider.get_urls():
                full_url = '/%s%s' % (api_name, url)
                # Endpoint "<api_name>_<view name>" matches RestResource.get_url_name.
                self.blueprint.add_url_rule(
                    full_url,
                    '%s_%s' % (api_name, callback.__name__),
                    self.auth_wrapper(callback, provider),
                    methods=provider.allowed_methods,
                )

    def register_blueprint(self, **kwargs):
        self.app.register_blueprint(self.blueprint, url_prefix=self.url_prefix, **kwargs)

    def setup(self):
        """Configure routes, then register the blueprint on the app."""
        self.configure_routes()
        self.register_blueprint()