1from collections import OrderedDict
2
3from django.conf import settings
4from django.core.exceptions import FieldDoesNotExist
5from django.http import Http404
6from django.shortcuts import redirect
7from django.urls import path, reverse
8from modelcluster.fields import ParentalKey
9from rest_framework import status
10from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
11from rest_framework.response import Response
12from rest_framework.viewsets import GenericViewSet
13
14from wagtail.api import APIField
15from wagtail.core.models import Page, Site
16
17from .filters import (
18    AncestorOfFilter, ChildOfFilter, DescendantOfFilter, FieldsFilter, LocaleFilter, OrderingFilter,
19    SearchFilter, TranslationOfFilter)
20from .pagination import WagtailPagination
21from .serializers import BaseSerializer, PageSerializer, get_serializer_class
22from .utils import (
23    BadRequestError, get_object_detail_url, page_models_from_string, parse_fields_parameter)
24
25
class BaseAPIViewSet(GenericViewSet):
    """
    Base viewset for a read-only Wagtail API endpoint.

    Exposes three views (``listing_view``, ``detail_view`` and ``find_view``)
    and the machinery that decides which fields get serialised. That choice is
    built from the class-level field lists, the model's optional
    ``api_fields`` / ``api_meta_fields`` attributes and the ``?fields=``
    query parameter. Subclasses are expected to set ``model`` and ``name``.
    """
    renderer_classes = [JSONRenderer, BrowsableAPIRenderer]

    pagination_class = WagtailPagination
    base_serializer_class = BaseSerializer
    filter_backends = []
    model = None  # Set on subclass

    # Query parameters that are accepted in addition to database field names
    # (see check_query_parameters).
    known_query_parameters = frozenset([
        'limit',
        'offset',
        'fields',
        'order',
        'search',
        'search_operator',

        # Used by jQuery for cache-busting. See #1671
        '_',

        # Required by BrowsableAPIRenderer
        'format',
    ])
    body_fields = ['id']
    meta_fields = ['type', 'detail_url']
    listing_default_fields = ['id', 'type', 'detail_url']
    nested_default_fields = ['id', 'type', 'detail_url']
    # Fields that are only available on the detail view (never in listings).
    detail_only_fields = []
    name = None  # Set on subclass.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # seen_types is a mapping of type name strings (format: "app_label.ModelName")
        # to model classes. When an object is serialised in the API, its model
        # is added to this mapping. This is used by the Admin API which appends a
        # summary of the used types to the response.
        self.seen_types = OrderedDict()

    def get_queryset(self):
        """Return all objects of ``self.model``, ordered by primary key ``id``."""
        return self.model.objects.all().order_by('id')

    def listing_view(self, request):
        """
        Return a paginated, serialised list of objects.

        Query parameters are validated first (unknown ones produce a 400). The
        queryset is then passed through ``filter_backends`` and paginated.
        """
        queryset = self.get_queryset()
        self.check_query_parameters(queryset)
        queryset = self.filter_queryset(queryset)
        queryset = self.paginate_queryset(queryset)
        serializer = self.get_serializer(queryset, many=True)
        return self.get_paginated_response(serializer.data)

    def detail_view(self, request, pk):
        """Return the single serialised object resolved by ``get_object()``."""
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def find_view(self, request):
        """
        Find one object via ``find_object()`` and redirect to its detail URL.

        Raises Http404 when no object matches. ``handle_exception`` renders
        that as a JSON 404 response.
        """
        queryset = self.get_queryset()

        try:
            obj = self.find_object(queryset, request)

            if obj is None:
                raise self.model.DoesNotExist

        except self.model.DoesNotExist:
            raise Http404("not found")

        # Generate redirect
        url = get_object_detail_url(self.request.wagtailapi_router, request, self.model, obj.pk)

        if url is None:
            # Shouldn't happen unless this endpoint isn't actually installed in the router
            raise Exception("Cannot generate URL to detail view. Is '{}' installed in the API router?".format(self.__class__.__name__))

        return redirect(url)

    def find_object(self, queryset, request):
        """
        Return the object matching the request's query parameters, or None.

        The base implementation only supports ``?id=``. Override this to
        implement more find methods. ``queryset.get`` may raise
        ``DoesNotExist``, which ``find_view`` converts into a 404.
        """
        if 'id' in request.GET:
            return queryset.get(id=request.GET['id'])
        return None

    def handle_exception(self, exc):
        """
        Render API errors as JSON ``{'message': ...}`` responses.

        Http404 becomes a 404 and BadRequestError becomes a 400. Every other
        exception is handed to DRF's default handling.
        """
        if isinstance(exc, Http404):
            data = {'message': str(exc)}
            return Response(data, status=status.HTTP_404_NOT_FOUND)
        elif isinstance(exc, BadRequestError):
            data = {'message': str(exc)}
            return Response(data, status=status.HTTP_400_BAD_REQUEST)
        return super().handle_exception(exc)

    @classmethod
    def _convert_api_fields(cls, fields):
        """Wrap any plain field-name strings in ``APIField``; leave APIFields as-is."""
        return [field if isinstance(field, APIField) else APIField(field)
                for field in fields]

    @classmethod
    def get_body_fields(cls, model):
        """Return ``body_fields`` plus the model's ``api_fields``, as APIField objects."""
        return cls._convert_api_fields(cls.body_fields + list(getattr(model, 'api_fields', ())))

    @classmethod
    def get_body_fields_names(cls, model):
        """Return the names of ``get_body_fields(model)``."""
        return [field.name for field in cls.get_body_fields(model)]

    @classmethod
    def get_meta_fields(cls, model):
        """Return ``meta_fields`` plus the model's ``api_meta_fields``, as APIField objects."""
        return cls._convert_api_fields(cls.meta_fields + list(getattr(model, 'api_meta_fields', ())))

    @classmethod
    def get_meta_fields_names(cls, model):
        """Return the names of ``get_meta_fields(model)``."""
        return [field.name for field in cls.get_meta_fields(model)]

    @classmethod
    def get_field_serializer_overrides(cls, model):
        """Map field name -> custom serializer for every APIField that declares one."""
        return {field.name: field.serializer
                for field in cls.get_body_fields(model) + cls.get_meta_fields(model)
                if field.serializer is not None}

    @classmethod
    def get_available_fields(cls, model, db_fields_only=False):
        """
        Returns a list of all the fields that can be used in the API for the
        specified model class.

        Setting db_fields_only to True will remove all fields that do not have
        an underlying column in the database (eg, type/detail_url and any custom
        fields that are callables)
        """
        fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)

        if db_fields_only:
            # Collect the model's field names, plus their attnames where they
            # exist (e.g. "owner_id" for a ForeignKey "owner"). Then drop any
            # API field that isn't one of them.
            database_fields = set()
            for field in model._meta.get_fields():
                database_fields.add(field.name)

                if hasattr(field, 'attname'):
                    database_fields.add(field.attname)

            fields = [field for field in fields if field in database_fields]

        return fields

    @classmethod
    def get_detail_default_fields(cls, model):
        """Fields shown on the detail view when ``?fields=`` is not given: all of them."""
        return cls.get_available_fields(model)

    @classmethod
    def get_listing_default_fields(cls, model):
        """Return a copy of ``listing_default_fields`` (safe for callers to mutate)."""
        return cls.listing_default_fields[:]

    @classmethod
    def get_nested_default_fields(cls, model):
        """Return a copy of ``nested_default_fields`` (safe for callers to mutate)."""
        return cls.nested_default_fields[:]

    def check_query_parameters(self, queryset):
        """
        Ensure that only valid query parameters are included in the URL.

        Raises BadRequestError listing every unrecognised parameter.
        """
        query_parameters = set(self.request.GET.keys())

        # All query parameters must be either a database field or an operation
        allowed_query_parameters = set(self.get_available_fields(queryset.model, db_fields_only=True)).union(self.known_query_parameters)
        unknown_parameters = query_parameters - allowed_query_parameters
        if unknown_parameters:
            raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters)))

    @classmethod
    def _get_serializer_class(cls, router, model, fields_config, show_details=False, nested=False):
        """
        Build a serializer class for ``model`` from a parsed fields config.

        ``fields_config`` is a sequence of ``(field_name, negated, sub_fields)``
        tuples. If the first entry's name is ``'*'``, selection starts from all
        available fields. If it is ``'_'``, selection starts from none.
        Otherwise it starts from the detail, nested or listing defaults.
        Relational fields get child serializers built recursively with
        ``nested=True``.

        Raises BadRequestError for unknown fields, or when sub fields are given
        for a non-relational field.
        """
        # Get all available fields
        body_fields = cls.get_body_fields_names(model)
        meta_fields = cls.get_meta_fields_names(model)
        all_fields = body_fields + meta_fields

        # Remove any duplicates (preserving order)
        all_fields = list(OrderedDict.fromkeys(all_fields))

        if not show_details:
            # Remove detail only fields. Filtering (rather than list.remove)
            # means a detail-only field that isn't available for this model is
            # simply ignored instead of raising ValueError.
            detail_only = set(cls.detail_only_fields)
            all_fields = [field for field in all_fields if field not in detail_only]

        # Get list of configured fields
        if show_details:
            fields = set(cls.get_detail_default_fields(model))
        elif nested:
            fields = set(cls.get_nested_default_fields(model))
        else:
            fields = set(cls.get_listing_default_fields(model))

        # If first field is '*' start with all fields
        # If first field is '_' start with no fields
        if fields_config and fields_config[0][0] == '*':
            fields = set(all_fields)
            fields_config = fields_config[1:]
        elif fields_config and fields_config[0][0] == '_':
            fields = set()
            fields_config = fields_config[1:]

        mentioned_fields = set()
        sub_fields = {}

        for field_name, negated, field_sub_fields in fields_config:
            if negated:
                fields.discard(field_name)
            else:
                fields.add(field_name)
                if field_sub_fields:
                    sub_fields[field_name] = field_sub_fields

            mentioned_fields.add(field_name)

        unknown_fields = mentioned_fields - set(all_fields)

        if unknown_fields:
            raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields)))

        # Build nested serialisers
        child_serializer_classes = {}

        for field_name in fields:
            try:
                django_field = model._meta.get_field(field_name)
            except FieldDoesNotExist:
                django_field = None

            if django_field and django_field.is_relation:
                child_sub_fields = sub_fields.get(field_name, [])

                # Inline (aka "child") models should display all fields by default
                if isinstance(getattr(django_field, 'field', None), ParentalKey):
                    if not child_sub_fields or child_sub_fields[0][0] not in ['*', '_']:
                        child_sub_fields = list(child_sub_fields)
                        child_sub_fields.insert(0, ('*', False, None))

                # Get a serializer class for the related object, using the
                # endpoint registered for that model if there is one.
                child_model = django_field.related_model
                child_endpoint_class = router.get_model_endpoint(child_model)
                child_endpoint_class = child_endpoint_class[1] if child_endpoint_class else BaseAPIViewSet
                child_serializer_classes[field_name] = child_endpoint_class._get_serializer_class(router, child_model, child_sub_fields, nested=True)

            else:
                if field_name in sub_fields:
                    # Sub fields were given for a non-related field
                    raise BadRequestError("'%s' does not support nested fields" % field_name)

        # Reorder fields so it matches the order of all_fields
        fields = [field for field in all_fields if field in fields]

        field_serializer_overrides = {
            field_name: serializer
            for field_name, serializer in cls.get_field_serializer_overrides(model).items()
            if field_name in fields
        }
        return get_serializer_class(
            model,
            fields,
            meta_fields=meta_fields,
            field_serializer_overrides=field_serializer_overrides,
            child_serializer_classes=child_serializer_classes,
            base=cls.base_serializer_class
        )

    def get_serializer_class(self):
        """
        Return the serializer class for the current request.

        The model is the queryset's model on the listing view, and the
        concrete type of the looked-up object otherwise. Fields come from
        ``?fields=``. Raises BadRequestError if that parameter cannot be parsed.
        """
        request = self.request

        # Get model
        if self.action == 'listing_view':
            model = self.get_queryset().model
        else:
            model = type(self.get_object())

        # Fields
        if 'fields' in request.GET:
            try:
                fields_config = parse_fields_parameter(request.GET['fields'])
            except ValueError as e:
                raise BadRequestError("fields error: %s" % str(e)) from e
        else:
            # Use default fields
            fields_config = []

        # Allow "detail_only" (eg parent) fields on detail view
        show_details = self.action != 'listing_view'

        return self._get_serializer_class(self.request.wagtailapi_router, model, fields_config, show_details=show_details)

    def get_serializer_context(self):
        """
        Return the context passed to serializers: the request, this view and
        the API router (taken from ``request.wagtailapi_router``).
        """
        return {
            'request': self.request,
            'view': self,
            'router': self.request.wagtailapi_router
        }

    def get_renderer_context(self):
        """Extend DRF's renderer context so JSON output is indented by 4 spaces."""
        context = super().get_renderer_context()
        context['indent'] = 4
        return context

    @classmethod
    def get_urlpatterns(cls):
        """
        This returns a list of URL patterns for the endpoint
        """
        return [
            path('', cls.as_view({'get': 'listing_view'}), name='listing'),
            path('<int:pk>/', cls.as_view({'get': 'detail_view'}), name='detail'),
            path('find/', cls.as_view({'get': 'find_view'}), name='find'),
        ]

    @classmethod
    def get_model_listing_urlpath(cls, model, namespace=''):
        """Reverse the ``listing`` URL, optionally within ``namespace``."""
        if namespace:
            url_name = namespace + ':listing'
        else:
            url_name = 'listing'

        return reverse(url_name)

    @classmethod
    def get_object_detail_urlpath(cls, model, pk, namespace=''):
        """Reverse the ``detail`` URL for ``pk``, optionally within ``namespace``."""
        if namespace:
            url_name = namespace + ':detail'
        else:
            url_name = 'detail'

        return reverse(url_name, args=(pk, ))
362
363
class PagesAPIViewSet(BaseAPIViewSet):
    """
    API endpoint for Wagtail pages.

    Only live, public pages are exposed. They must sit under the current
    site's root page or, when WAGTAIL_I18N_ENABLED is set, under one of that
    root page's translations. Adds page tree, locale, translation and search
    filters, plus page-specific body/meta fields.
    """
    base_serializer_class = PageSerializer
    filter_backends = [
        FieldsFilter,
        ChildOfFilter,
        AncestorOfFilter,
        DescendantOfFilter,
        OrderingFilter,
        TranslationOfFilter,
        LocaleFilter,
        SearchFilter,  # needs to be last, as SearchResults querysets cannot be filtered further
    ]
    known_query_parameters = BaseAPIViewSet.known_query_parameters.union([
        'type',
        'child_of',
        'ancestor_of',
        'descendant_of',
        'translation_of',
        'locale',
    ])
    body_fields = BaseAPIViewSet.body_fields + [
        'title',
    ]
    meta_fields = BaseAPIViewSet.meta_fields + [
        'html_url',
        'slug',
        'show_in_menus',
        'seo_title',
        'search_description',
        'first_published_at',
        'parent',
        'locale',
    ]
    listing_default_fields = BaseAPIViewSet.listing_default_fields + [
        'title',
        'html_url',
        'slug',
        'first_published_at',
    ]
    nested_default_fields = BaseAPIViewSet.nested_default_fields + [
        'title',
    ]
    detail_only_fields = ['parent']
    name = 'pages'
    model = Page

    @classmethod
    def get_detail_default_fields(cls, model):
        """
        Return the fields shown on the detail view by default.

        When i18n is disabled, ``locale`` is left out of the defaults. It can
        still be requested explicitly with ``?fields=locale``.
        """
        detail_default_fields = super().get_detail_default_fields(model)

        # When i18n is disabled, remove "locale" from default fields. Guard the
        # removal: a subclass may have dropped "locale" from meta_fields, and
        # list.remove() raises ValueError for a missing item.
        if not getattr(settings, 'WAGTAIL_I18N_ENABLED', False) and 'locale' in detail_default_fields:
            detail_default_fields.remove('locale')

        return detail_default_fields

    @classmethod
    def get_listing_default_fields(cls, model):
        """Return the listing defaults, plus ``locale`` when i18n is enabled."""
        listing_default_fields = super().get_listing_default_fields(model)

        # When i18n is enabled, add "locale" to default fields
        if getattr(settings, 'WAGTAIL_I18N_ENABLED', False):
            listing_default_fields.append('locale')

        return listing_default_fields

    def get_root_page(self):
        """
        Returns the page that is used when the `&child_of=root` filter is used.
        """
        # NOTE(review): assumes a Site matches the request. If
        # Site.find_for_request returns None this raises AttributeError, unlike
        # get_base_queryset, which handles the no-site case.
        return Site.find_for_request(self.request).root_page

    def get_base_queryset(self):
        """
        Returns a queryset containing all pages that can be seen by this user.

        This is used as the base for get_queryset and is also used to find the
        parent pages when using the child_of and descendant_of filters as well.
        Returns an empty queryset when no site matches the request.
        """
        # Get live pages that are not in a private section
        queryset = Page.objects.all().public().live()

        # Filter by site
        site = Site.find_for_request(self.request)
        if site:
            base_queryset = queryset
            queryset = base_queryset.descendant_of(site.root_page, inclusive=True)

            # If internationalisation is enabled, include pages from other language trees
            if getattr(settings, 'WAGTAIL_I18N_ENABLED', False):
                for translation in site.root_page.get_translations():
                    queryset |= base_queryset.descendant_of(translation, inclusive=True)

        else:
            # No sites configured
            queryset = queryset.none()

        return queryset

    def get_queryset(self):
        """
        Return the visible pages, optionally restricted by ``?type=``.

        Raises BadRequestError if ``type`` cannot be resolved to page models.
        """
        request = self.request

        # Allow pages to be filtered to a specific type
        try:
            models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page'))
        except (LookupError, ValueError):
            raise BadRequestError("type doesn't exist")

        if not models:
            return self.get_base_queryset()

        elif len(models) == 1:
            # If a single page type has been specified, swap out the Page-based queryset for one based on
            # the specific page model so that we can filter on any custom APIFields defined on that model
            return models[0].objects.filter(id__in=self.get_base_queryset().values_list('id', flat=True))

        else:  # len(models) > 1
            return self.get_base_queryset().type(*models)

    def get_object(self):
        """Return the looked-up page as its ``.specific`` instance."""
        base = super().get_object()
        return base.specific

    def find_object(self, queryset, request):
        """
        Find a page by ``?html_path=`` (resolved from the site root) or ``?id=``.

        If routing ``html_path`` raises Http404, returns None straight away
        without trying ``id``. A routed page outside ``queryset`` falls through
        to the base ``?id=`` lookup.
        """
        site = Site.find_for_request(request)
        if 'html_path' in request.GET and site is not None:
            path = request.GET['html_path']
            path_components = [component for component in path.split('/') if component]

            try:
                page, _, _ = site.root_page.specific.route(request, path_components)
            except Http404:
                return None

            if queryset.filter(id=page.id).exists():
                return page

        return super().find_object(queryset, request)

    def get_serializer_context(self):
        """
        Extend the base serializer context with ``base_queryset``, the set of
        pages visible to this request (see get_base_queryset).
        """
        context = super().get_serializer_context()
        context['base_queryset'] = self.get_base_queryset()
        return context
510