1import functools
2import operator
3try:
4    import simplejson as json
5except ImportError:
6    import json
7
8from flask import Blueprint
9from flask import Response
10from flask import abort
11from flask import g
12from flask import redirect
13from flask import request
14from flask import session
15from flask import url_for
16from peewee import *
17from peewee import DJANGO_MAP
18
19from flask_peewee.filters import make_field_tree
20from flask_peewee.serializer import Deserializer
21from flask_peewee.serializer import Serializer
22from flask_peewee.utils import PaginatedQuery
23from flask_peewee.utils import get_object_or_404
24from flask_peewee.utils import slugify
25from flask_peewee._compat import reduce
26
27
class Authentication(object):
    """
    Baseline authentication: deny every request whose HTTP method is
    protected, allow everything else.

    ``protected_methods`` defaults to POST, PUT and DELETE.
    """

    def __init__(self, protected_methods=None):
        self.protected_methods = (
            ['POST', 'PUT', 'DELETE'] if protected_methods is None
            else protected_methods
        )

    def authorize(self):
        # read-only (unprotected) methods pass, protected ones are refused
        return request.method not in self.protected_methods
40
41
class APIKeyAuthentication(Authentication):
    """
    Authenticate protected requests against a model storing API credentials.

    The model must define the fields named by ``key_field`` and
    ``secret_field`` ("key" and "secret" by default); a request is
    authorized when a row matches both supplied values.
    """
    key_field = 'key'
    secret_field = 'secret'

    def __init__(self, model, protected_methods=None):
        super(APIKeyAuthentication, self).__init__(protected_methods)
        self.model = model
        model_fields = model._meta.fields
        self._key_field = model_fields[self.key_field]
        self._secret_field = model_fields[self.secret_field]

    def get_query(self):
        """Base query searched for credentials; override to narrow it."""
        return self.model.select()

    def get_key(self, k, s):
        """Return the row matching key ``k`` and secret ``s``, or None."""
        matches = self.get_query().where(
            self._key_field == k,
            self._secret_field == s,
        )
        try:
            return matches.get()
        except self.model.DoesNotExist:
            return None

    def get_key_secret(self):
        """
        Return ``(key, secret)`` from the first of query string, headers or
        form data that carries both "key" and "secret"; ``(None, None)`` if
        none does.
        """
        sources = (request.args, request.headers, request.form)
        for source in sources:
            if 'key' in source and 'secret' in source:
                return source['key'], source['secret']
        return None, None

    def authorize(self):
        # g.api_key is reset on every request so views can rely on it
        g.api_key = None

        if request.method in self.protected_methods:
            key, secret = self.get_key_secret()
            if key or secret:
                g.api_key = self.get_key(key, secret)
            # the matched row (truthy) or None (falsy)
            return g.api_key

        return True
85
86
class UserAuthentication(Authentication):
    """
    Authenticate protected requests with HTTP Basic credentials checked by
    ``auth.authenticate(username, password)``; the result is stored on
    ``g.user``.
    """

    def __init__(self, auth, protected_methods=None):
        super(UserAuthentication, self).__init__(protected_methods)
        self.auth = auth

    def authorize(self):
        # always reset so views never see a stale user
        g.user = None

        if request.method not in self.protected_methods:
            return True

        credentials = request.authorization
        if credentials:
            g.user = self.auth.authenticate(credentials.username, credentials.password)
            return g.user

        # protected method but no Basic credentials supplied
        return False
104
105
class AdminAuthentication(UserAuthentication):
    """User authentication that additionally requires ``verify_user`` to pass."""

    def verify_user(self, user):
        """Return whether ``user`` may access the API (its ``admin`` flag)."""
        return user.admin

    def authorize(self):
        granted = super(AdminAuthentication, self).authorize()

        # unprotected requests (granted is True, g.user is None) and failed
        # logins are passed through unchanged
        if not (granted and g.user):
            return granted
        return self.verify_user(g.user)
116
117
class RestResource(object):
    """
    Expose a peewee model as a JSON REST resource.

    Instances are created by ``RestAPI.register`` and provide the list
    (``/``), detail (``/<pk>/``) and delete (``/<pk>/delete/``) views
    returned by ``get_urls``. Behaviour is customised through the class
    attributes below and the ``check_*``, ``get_query``, ``prepare_data``
    and ``save_object`` hooks.
    """
    # default (and maximum) page size; when falsy, results are only
    # paginated if the client passes ?limit=
    paginate_by = 20
    # raw query-string values converted before filtering
    value_transforms = {'False': False, 'false': False,
                        'True': True, 'true': True,
                        'None': None, 'none': None}

    # serializing: field names of this model to include in (``fields``) or
    # omit from (``exclude``) the output; ``fields`` defaults to all fields
    fields = None
    exclude = None

    # exclude certain fields from being exposed as filters -- for related
    # fields use "__" notation, e.g. user__password
    filter_exclude = None
    filter_fields = None
    filter_recursive = True

    # mapping of field name to resource class
    include_resources = None

    # delete behavior
    delete_recursive = True

    def __init__(self, rest_api, model, authentication, allowed_methods=None):
        self.api = rest_api
        self.model = model
        self.pk = model._meta.primary_key

        self.authentication = authentication
        self.allowed_methods = allowed_methods or ['GET', 'POST', 'PUT', 'DELETE']

        # internal model -> field-names mappings, merged with those of any
        # nested resources below
        self._fields = {self.model: self.fields or self.model._meta.sorted_field_names}
        if self.exclude:
            self._exclude = {self.model: self.exclude}
        else:
            self._exclude = {}

        # Always copy: ``filter_fields``/``filter_exclude`` are typically class
        # attributes shared by every instance, and these lists are extended in
        # place below. Without the copy each instantiation would keep growing
        # the class-level list.
        self._filter_fields = list(self.filter_fields or self.model._meta.sorted_field_names)
        self._filter_exclude = list(self.filter_exclude or [])

        self._resources = {}

        # recurse into nested resources
        if self.include_resources:
            for field_name, resource in self.include_resources.items():
                field_obj = self.model._meta.fields[field_name]
                resource_obj = resource(self.api, field_obj.rel_model, self.authentication, self.allowed_methods)
                self._resources[field_name] = resource_obj
                self._fields.update(resource_obj._fields)
                self._exclude.update(resource_obj._exclude)

                self._filter_fields.extend(['%s__%s' % (field_name, ff) for ff in resource_obj._filter_fields])
                self._filter_exclude.extend(['%s__%s' % (field_name, ff) for ff in resource_obj._filter_exclude])

            self._include_foreign_keys = False
        else:
            self._include_foreign_keys = True

        self._field_tree = make_field_tree(self.model, self._filter_fields, self._filter_exclude, self.filter_recursive)

    def authorize(self):
        return self.authentication.authorize()

    def get_api_name(self):
        """URL segment for this resource: the slugified model class name."""
        return slugify(self.model.__name__)

    def get_url_name(self, name):
        """Full endpoint name ("<blueprint>.<api name>_<name>") for url_for."""
        return '%s.%s_%s' % (
            self.api.blueprint.name,
            self.get_api_name(),
            name,
        )

    def get_query(self):
        return self.model.select()

    def process_query(self, query):
        """
        Filter ``query`` by the request's query-string parameters.

        Parameters look like ``field__op=value`` (``field`` alone means
        equality); a leading ``-`` negates the filter, and an ``op`` not in
        DJANGO_MAP is treated as part of the field path. Only paths present
        in the filter field tree are applied.
        """
        raw_filters = {}

        # clean and normalize the request parameters
        for key in request.args:
            orig_key = key
            if key.startswith('-'):
                negated = True
                key = key[1:]
            else:
                negated = False
            if '__' in key:
                expr, op = key.rsplit('__', 1)
                if op not in DJANGO_MAP:
                    expr = key
                    op = 'eq'
            else:
                expr = key
                op = 'eq'
            raw_filters.setdefault(expr, [])
            raw_filters[expr].append((op, request.args.getlist(orig_key), negated))

        # do a breadth first search across the field tree created by filter_fields,
        # searching for matching keys in the request parameters -- when found,
        # filter the query accordingly
        queue = [(self._field_tree, '')]
        while queue:
            node, prefix = queue.pop(0)
            for field in node.fields:
                filter_expr = '%s%s' % (prefix, field.name)
                if filter_expr in raw_filters:
                    for op, arg_list, negated in raw_filters[filter_expr]:
                        clean_args = self.clean_arg_list(arg_list)
                        query = self.apply_filter(query, filter_expr, op, clean_args, negated)

            for child_prefix, child_node in node.children.items():
                queue.append((child_node, prefix + child_prefix + '__'))

        return query

    def clean_arg_list(self, arg_list):
        """Map literal strings such as "true"/"None" to Python values."""
        return [self.value_transforms.get(arg, arg) for arg in arg_list]

    def _split_in_args(self, arg_list):
        """Expand comma-separated ``__in`` arguments into one flat value list."""
        values = []
        for arg in arg_list:
            if hasattr(arg, 'split'):
                values.extend(item.strip() for item in arg.split(','))
            else:
                # value_transforms may already have turned the raw string into
                # True/False/None, which cannot be split
                values.append(arg)
        return values

    def apply_filter(self, query, expr, op, arg_list, negated):
        """
        Filter ``query`` by ``expr__op`` against ``arg_list``.

        Several values for the same parameter are OR-ed together; ``negated``
        inverts each clause.
        """
        query_expr = '%s__%s' % (expr, op)

        def constructor(kwargs):
            node = DQ(**kwargs)
            return ~node if negated else node

        if op == 'in':
            # "in" values arrive as comma-separated strings, e.g. "1,2,3,4";
            # turn them (across repeated parameters) into a single list
            return query.filter(constructor({query_expr: self._split_in_args(arg_list)}))
        elif len(arg_list) == 1:
            return query.filter(constructor({query_expr: arg_list[0]}))
        else:
            query_clauses = [constructor({query_expr: val}) for val in arg_list]
            return query.filter(reduce(operator.or_, query_clauses))

    def get_serializer(self):
        return Serializer()

    def get_deserializer(self):
        return Deserializer()

    def prepare_data(self, obj, data):
        """
        Hook for modifying outgoing data
        """
        return data

    def serialize_object(self, obj):
        s = self.get_serializer()
        return self.prepare_data(
            obj, s.serialize_object(obj, self._fields, self._exclude)
        )

    def serialize_query(self, query):
        s = self.get_serializer()
        return [
            self.prepare_data(obj, s.serialize_object(obj, self._fields, self._exclude))
            for obj in query
        ]

    def deserialize_object(self, data, instance):
        d = self.get_deserializer()
        return d.deserialize_object(instance, data)

    def response_forbidden(self):
        return Response('Forbidden', 403)

    def response_bad_method(self):
        return Response('Unsupported method "%s"' % (request.method), 405)

    def response_bad_request(self):
        return Response('Bad request', 400)

    def response(self, data):
        return Response(json.dumps(data), mimetype='application/json')

    def require_method(self, func, methods):
        """Wrap ``func`` so requests using other HTTP methods get a 405."""
        @functools.wraps(func)
        def inner(*args, **kwargs):
            if request.method not in methods:
                return self.response_bad_method()
            return func(*args, **kwargs)
        return inner

    def get_urls(self):
        return (
            ('/', self.require_method(self.api_list, ['GET', 'POST'])),
            ('/<pk>/', self.require_method(self.api_detail, ['GET', 'POST', 'PUT', 'DELETE'])),
            ('/<pk>/delete/', self.require_method(self.post_delete, ['POST', 'DELETE'])),
        )

    # Permission hooks, looked up as "check_<method>". The list view calls
    # check_get()/check_post() without an object; the detail view passes the
    # object being accessed.
    def check_get(self, obj=None):
        return True

    def check_post(self, obj=None):
        return True

    def check_put(self, obj):
        return True

    def check_delete(self, obj):
        return True

    def save_object(self, instance, raw_data):
        instance.save()
        return instance

    def api_list(self):
        if not getattr(self, 'check_%s' % request.method.lower())():
            return self.response_forbidden()

        if request.method == 'GET':
            return self.object_list()
        elif request.method == 'POST':
            return self.create()

    def api_detail(self, pk, method=None):
        obj = get_object_or_404(self.get_query(), self.pk == pk)

        method = method or request.method

        if not getattr(self, 'check_%s' % method.lower())(obj):
            return self.response_forbidden()

        if method == 'GET':
            return self.object_detail(obj)
        elif method in ('PUT', 'POST'):
            return self.edit(obj)
        elif method == 'DELETE':
            return self.delete(obj)

    def post_delete(self, pk):
        return self.api_detail(pk, 'DELETE')

    def apply_ordering(self, query):
        """Order by ``?ordering=field`` (``-field`` for descending) if valid."""
        ordering = request.args.get('ordering') or ''
        if ordering:
            desc, column = ordering.startswith('-'), ordering.lstrip('-')
            if column in self.model._meta.fields:
                field = self.model._meta.fields[column]
                query = query.order_by(field.asc() if not desc else field.desc())

        return query

    def get_request_metadata(self, paginated_query):
        """Build the "meta" section: model name, page and prev/next URLs."""
        var = paginated_query.page_var
        # keep every value of repeated parameters (e.g. ?id=1&id=2) so the
        # previous/next links apply exactly the same filters
        request_arguments = request.args.to_dict(flat=False)

        current_page = paginated_query.get_page()
        previous_url = next_url = ''
        endpoint = self.get_url_name('api_list')

        if current_page > 1:
            request_arguments[var] = current_page - 1
            previous_url = url_for(endpoint, **request_arguments)
        if current_page < paginated_query.get_pages():
            request_arguments[var] = current_page + 1
            next_url = url_for(endpoint, **request_arguments)

        return {
            'model': self.get_api_name(),
            'page': current_page,
            'previous': previous_url,
            'next': next_url,
        }

    def _coerce_paginate_by(self, raw_limit):
        """
        Turn a raw ``limit`` value into a page size.

        Non-integer and non-positive values fall back to ``paginate_by``;
        when ``paginate_by`` is set it also caps the result.
        """
        try:
            paginate_by = int(raw_limit)
        except (TypeError, ValueError):
            return self.paginate_by
        if paginate_by < 1:
            # a zero or negative page size cannot paginate anything
            return self.paginate_by
        if self.paginate_by:
            paginate_by = min(paginate_by, self.paginate_by)  # restrict
        return paginate_by

    def get_paginate_by(self):
        """Page size for this request, from ``?limit=`` or ``paginate_by``."""
        return self._coerce_paginate_by(request.args.get('limit', self.paginate_by))

    def paginated_object_list(self, filtered_query):
        paginate_by = self.get_paginate_by()
        pq = PaginatedQuery(filtered_query, paginate_by)
        meta_data = self.get_request_metadata(pq)

        query_dict = self.serialize_query(pq.get_list())

        return self.response({
            'meta': meta_data,
            'objects': query_dict,
        })

    def object_list(self):
        query = self.get_query()
        query = self.apply_ordering(query)

        # process any filters
        query = self.process_query(query)

        if self.paginate_by or 'limit' in request.args:
            return self.paginated_object_list(query)

        return self.response(self.serialize_query(query))

    def object_detail(self, obj):
        return self.response(self.serialize_object(obj))

    def save_related_objects(self, instance, data):
        """Deserialize and save nested objects supplied for included resources."""
        for k, v in data.items():
            if k in self._resources and isinstance(v, dict):
                rel_resource = self._resources[k]
                rel_obj, _rel_models = rel_resource.deserialize_object(v, getattr(instance, k))
                rel_resource.save_related_objects(rel_obj, v)
                setattr(instance, k, rel_resource.save_object(rel_obj, v))

    def read_request_data(self):
        """
        Return the submitted data: a JSON body, a JSON "data" form field, or
        the plain form fields (first value per key).

        Raises ValueError (including UnicodeDecodeError) for invalid JSON.
        """
        if request.data:
            return json.loads(request.data.decode('utf-8'))
        elif request.form.get('data'):
            return json.loads(request.form['data'])
        else:
            # explicit first-value-per-key; dict(MultiDict) differs across
            # Werkzeug versions (it may yield lists of values)
            return request.form.to_dict()

    def _save_from_request(self, instance):
        """Apply the request data to ``instance``, save it and serialize it."""
        try:
            data = self.read_request_data()
        except ValueError:
            return self.response_bad_request()

        if not isinstance(data, dict):
            # a JSON array or scalar cannot be mapped onto model fields
            return self.response_bad_request()

        obj, models = self.deserialize_object(data, instance)

        self.save_related_objects(obj, data)
        obj = self.save_object(obj, data)

        return self.response(self.serialize_object(obj))

    def create(self):
        return self._save_from_request(self.model())

    def edit(self, obj):
        return self._save_from_request(obj)

    def delete(self, obj):
        res = obj.delete_instance(recursive=self.delete_recursive)
        return self.response({'deleted': res})
464
465
class RestrictOwnerResource(RestResource):
    """
    Resource that restricts modification of an object to its owner.

    PUT, DELETE and POST-to-detail (which edits an existing object) are only
    allowed when ``g.user`` matches the object's ``owner_field``; every saved
    object, including newly created ones, has its owner set to ``g.user``.
    """
    # restrict PUT/DELETE to owner of an object, likewise apply owner to any
    # incoming POSTs
    owner_field = 'user'

    def validate_owner(self, user, obj):
        return user == getattr(obj, self.owner_field)

    def set_owner(self, obj, user):
        setattr(obj, self.owner_field, user)

    def check_post(self, obj=None):
        # POST to the list endpoint creates a new object (obj is None) and is
        # always allowed; save_object() assigns it to the current user.
        # POST to a detail URL edits an existing object, so it needs the same
        # owner check as PUT -- otherwise any user could edit (and take
        # ownership of) someone else's object.
        if obj is None:
            return True
        return self.validate_owner(g.user, obj)

    def check_put(self, obj):
        return self.validate_owner(g.user, obj)

    def check_delete(self, obj):
        return self.validate_owner(g.user, obj)

    def save_object(self, instance, raw_data):
        self.set_owner(instance, g.user)
        return super(RestrictOwnerResource, self).save_object(instance, raw_data)
486
487
class RestAPI(object):
    """
    Registry of REST resources served from one Flask blueprint.

    Register models with ``register()``, then call ``setup()`` to build the
    URL rules and attach the blueprint to the app under ``prefix``.
    """

    def __init__(self, app, prefix='/api', default_auth=None, name='api'):
        self.app = app

        # model class -> resource instance
        self._registry = {}

        self.url_prefix = prefix
        self.blueprint = self.get_blueprint(name)

        self.default_auth = default_auth or Authentication()

    def register(self, model, provider=RestResource, auth=None, allowed_methods=None):
        """Expose ``model`` through an instance of the ``provider`` resource class."""
        resource = provider(self, model, auth or self.default_auth, allowed_methods)
        self._registry[model] = resource

    def unregister(self, model):
        del self._registry[model]

    def is_registered(self, model):
        """Return the resource registered for ``model``, or None."""
        return self._registry.get(model)

    def response_auth_failed(self):
        return Response('Authentication failed', 401, {
            'WWW-Authenticate': 'Basic realm="Login Required"'
        })

    def auth_wrapper(self, func, provider):
        """Wrap ``func`` so it only runs when ``provider.authorize()`` is truthy."""
        @functools.wraps(func)
        def guarded(*args, **kwargs):
            if provider.authorize():
                return func(*args, **kwargs)
            return self.response_auth_failed()
        return guarded

    def get_blueprint(self, blueprint_name):
        return Blueprint(blueprint_name, __name__)

    def get_urls(self):
        """Extra API-level (url, view) pairs; none by default."""
        return ()

    def configure_routes(self):
        for rule, view in self.get_urls():
            self.blueprint.route(rule)(view)

        for provider in self._registry.values():
            segment = provider.get_api_name()
            for rule, view in provider.get_urls():
                # endpoint "<api name>_<view name>" is what
                # RestResource.get_url_name() builds for url_for
                endpoint = '%s_%s' % (segment, view.__name__)
                self.blueprint.add_url_rule(
                    '/%s%s' % (segment, rule),
                    endpoint,
                    self.auth_wrapper(view, provider),
                    methods=provider.allowed_methods,
                )

    def register_blueprint(self, **kwargs):
        self.app.register_blueprint(self.blueprint, url_prefix=self.url_prefix, **kwargs)

    def setup(self):
        """Create all URL rules, then register the blueprint on the app."""
        self.configure_routes()
        self.register_blueprint()
548