from collections import OrderedDict

from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.http import Http404
from django.shortcuts import redirect
from django.urls import path, reverse
from modelcluster.fields import ParentalKey
from rest_framework import status
from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet

from wagtail.api import APIField
from wagtail.core.models import Page, Site

from .filters import (
    AncestorOfFilter, ChildOfFilter, DescendantOfFilter, FieldsFilter, LocaleFilter, OrderingFilter,
    SearchFilter, TranslationOfFilter)
from .pagination import WagtailPagination
from .serializers import BaseSerializer, PageSerializer, get_serializer_class
from .utils import (
    BadRequestError, get_object_detail_url, page_models_from_string, parse_fields_parameter)


class BaseAPIViewSet(GenericViewSet):
    """
    Base viewset for API endpoints.

    Provides a listing view, a detail view and a "find" view (which redirects
    to a detail URL), validates query parameters, and builds serializer
    classes on the fly from the ``fields`` query parameter.
    """

    renderer_classes = [JSONRenderer, BrowsableAPIRenderer]

    pagination_class = WagtailPagination
    base_serializer_class = BaseSerializer
    filter_backends = []
    model = None  # Set on subclass

    # Query parameters that are operations rather than field filters
    known_query_parameters = frozenset([
        'limit',
        'offset',
        'fields',
        'order',
        'search',
        'search_operator',

        # Used by jQuery for cache-busting. See #1671
        '_',

        # Required by BrowsableAPIRenderer
        'format',
    ])
    body_fields = ['id']
    meta_fields = ['type', 'detail_url']
    listing_default_fields = ['id', 'type', 'detail_url']
    nested_default_fields = ['id', 'type', 'detail_url']
    # Fields that are only included when show_details is True (detail view)
    detail_only_fields = []
    name = None  # Set on subclass.

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # seen_types is a mapping of type name strings (format: "app_label.ModelName")
        # to model classes. When an object is serialised in the API, its model
        # is added to this mapping. This is used by the Admin API which appends a
        # summary of the used types to the response.
        self.seen_types = OrderedDict()

    def get_queryset(self):
        """Return all objects of ``self.model`` ordered by ``id``."""
        return self.model.objects.all().order_by('id')

    def listing_view(self, request):
        """
        Return a paginated listing of objects.

        Query parameters are validated against the unfiltered queryset before
        the filter backends and pagination are applied.
        """
        queryset = self.get_queryset()
        self.check_query_parameters(queryset)
        queryset = self.filter_queryset(queryset)
        queryset = self.paginate_queryset(queryset)
        serializer = self.get_serializer(queryset, many=True)
        return self.get_paginated_response(serializer.data)

    def detail_view(self, request, pk):
        """
        Return a single serialised object.

        ``pk`` is captured from the URL; the actual lookup is done by
        ``get_object()``.
        """
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return Response(serializer.data)

    def find_view(self, request):
        """
        Locate one object with ``find_object()`` and redirect to its detail URL.

        Raises Http404 when ``find_object()`` returns None or raises the
        model's ``DoesNotExist``; ``handle_exception`` turns that into a 404
        response.
        """
        queryset = self.get_queryset()

        try:
            obj = self.find_object(queryset, request)

            if obj is None:
                raise self.model.DoesNotExist

        except self.model.DoesNotExist:
            raise Http404("not found")

        # Generate redirect
        url = get_object_detail_url(self.request.wagtailapi_router, request, self.model, obj.pk)

        if url is None:
            # Shouldn't happen unless this endpoint isn't actually installed in the router
            raise Exception("Cannot generate URL to detail view. Is '{}' installed in the API router?".format(self.__class__.__name__))

        return redirect(url)

    def find_object(self, queryset, request):
        """
        Override this to implement more find methods.

        The base implementation looks an object up by the ``id`` query
        parameter, and returns None when that parameter is absent.
        """
        if 'id' in request.GET:
            return queryset.get(id=request.GET['id'])
        return None

    def handle_exception(self, exc):
        """
        Render Http404 and BadRequestError as JSON ``{"message": ...}``
        responses (404 and 400 respectively); defer everything else to DRF.
        """
        if isinstance(exc, Http404):
            data = {'message': str(exc)}
            return Response(data, status=status.HTTP_404_NOT_FOUND)
        elif isinstance(exc, BadRequestError):
            data = {'message': str(exc)}
            return Response(data, status=status.HTTP_400_BAD_REQUEST)
        return super().handle_exception(exc)

    @classmethod
    def _convert_api_fields(cls, fields):
        """Wrap any plain field names in ``APIField``; leave APIField instances as-is."""
        return [field if isinstance(field, APIField) else APIField(field)
                for field in fields]

    @classmethod
    def get_body_fields(cls, model):
        """Return the endpoint's body fields plus the model's ``api_fields``, as APIFields."""
        return cls._convert_api_fields(cls.body_fields + list(getattr(model, 'api_fields', ())))

    @classmethod
    def get_body_fields_names(cls, model):
        """Return the names of ``get_body_fields(model)``."""
        return [field.name for field in cls.get_body_fields(model)]

    @classmethod
    def get_meta_fields(cls, model):
        """Return the endpoint's meta fields plus the model's ``api_meta_fields``, as APIFields."""
        return cls._convert_api_fields(cls.meta_fields + list(getattr(model, 'api_meta_fields', ())))

    @classmethod
    def get_meta_fields_names(cls, model):
        """Return the names of ``get_meta_fields(model)``."""
        return [field.name for field in cls.get_meta_fields(model)]

    @classmethod
    def get_field_serializer_overrides(cls, model):
        """Map field name -> custom serializer for every body/meta APIField that declares one."""
        return {field.name: field.serializer
                for field in cls.get_body_fields(model) + cls.get_meta_fields(model)
                if field.serializer is not None}

    @classmethod
    def get_available_fields(cls, model, db_fields_only=False):
        """
        Returns a list of all the fields that can be used in the API for the
        specified model class.

        Setting db_fields_only to True will remove all fields that do not have
        an underlying column in the database (eg, type/detail_url and any custom
        fields that are callables)
        """
        fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)

        if db_fields_only:
            # Get list of available database fields then remove any fields in our
            # list that isn't a database field
            database_fields = set()
            for field in model._meta.get_fields():
                database_fields.add(field.name)

                # Also accept the attribute name, where the field defines one
                if hasattr(field, 'attname'):
                    database_fields.add(field.attname)

            fields = [field for field in fields if field in database_fields]

        return fields

    @classmethod
    def get_detail_default_fields(cls, model):
        """Fields shown by default in the detail view: every available field."""
        return cls.get_available_fields(model)

    @classmethod
    def get_listing_default_fields(cls, model):
        """Fields shown by default in listings (returns a fresh copy)."""
        return cls.listing_default_fields[:]

    @classmethod
    def get_nested_default_fields(cls, model):
        """Fields shown by default for nested objects (returns a fresh copy)."""
        return cls.nested_default_fields[:]

    def check_query_parameters(self, queryset):
        """
        Ensure that only valid query parameters are included in the URL.

        Raises BadRequestError listing every parameter that is neither a
        database field of ``queryset.model`` nor in ``known_query_parameters``.
        """
        query_parameters = set(self.request.GET.keys())

        # All query parameters must be either a database field or an operation
        allowed_query_parameters = set(
            self.get_available_fields(queryset.model, db_fields_only=True)
        ).union(self.known_query_parameters)
        unknown_parameters = query_parameters - allowed_query_parameters
        if unknown_parameters:
            raise BadRequestError("query parameter is not an operation or a recognised field: %s" % ', '.join(sorted(unknown_parameters)))

    @classmethod
    def _get_serializer_class(cls, router, model, fields_config, show_details=False, nested=False):
        """
        Build a serializer class for ``model`` according to ``fields_config``.

        :param router: the API router; used to find the endpoint class that
            serialises each related model.
        :param model: the model class being serialised.
        :param fields_config: a sequence of ``(field_name, negated, sub_fields)``
            tuples (as produced by ``parse_fields_parameter``). A leading ``'*'``
            entry starts from all fields and a leading ``'_'`` entry starts from
            no fields; otherwise the view's default fields are the start point.
        :param show_details: include ``detail_only_fields`` and start from the
            detail default fields.
        :param nested: start from the nested default fields (ignored when
            ``show_details`` is True).
        :raises BadRequestError: for unknown field names, or sub fields given
            for a non-relational field.
        """
        # Get all available fields
        body_fields = cls.get_body_fields_names(model)
        meta_fields = cls.get_meta_fields_names(model)
        all_fields = body_fields + meta_fields

        # Remove any duplicates
        all_fields = list(OrderedDict.fromkeys(all_fields))

        if not show_details:
            # Remove detail only fields. A filter is used rather than
            # list.remove(), which raises ValueError for fields that are absent.
            detail_only_fields = set(cls.detail_only_fields)
            all_fields = [field for field in all_fields if field not in detail_only_fields]

        # Get list of configured fields
        if show_details:
            fields = set(cls.get_detail_default_fields(model))
        elif nested:
            fields = set(cls.get_nested_default_fields(model))
        else:
            fields = set(cls.get_listing_default_fields(model))

        # If first field is '*' start with all fields
        # If first field is '_' start with no fields
        if fields_config and fields_config[0][0] == '*':
            fields = set(all_fields)
            fields_config = fields_config[1:]
        elif fields_config and fields_config[0][0] == '_':
            fields = set()
            fields_config = fields_config[1:]

        mentioned_fields = set()
        sub_fields = {}

        for field_name, negated, field_sub_fields in fields_config:
            if negated:
                fields.discard(field_name)
            else:
                fields.add(field_name)
                if field_sub_fields:
                    sub_fields[field_name] = field_sub_fields

            mentioned_fields.add(field_name)

        unknown_fields = mentioned_fields - set(all_fields)

        if unknown_fields:
            raise BadRequestError("unknown fields: %s" % ', '.join(sorted(unknown_fields)))

        # Build nested serialisers
        child_serializer_classes = {}

        for field_name in fields:
            try:
                django_field = model._meta.get_field(field_name)
            except FieldDoesNotExist:
                django_field = None

            if django_field and django_field.is_relation:
                child_sub_fields = sub_fields.get(field_name, [])

                # Inline (aka "child") models should display all fields by default
                if isinstance(getattr(django_field, 'field', None), ParentalKey):
                    if not child_sub_fields or child_sub_fields[0][0] not in ['*', '_']:
                        child_sub_fields = list(child_sub_fields)
                        child_sub_fields.insert(0, ('*', False, None))

                # Get a serializer class for the related object.
                # NOTE(review): get_model_endpoint() presumably returns a pair whose
                # second item is the endpoint class, or a falsy value if the model
                # has no endpoint — confirm against the router implementation.
                child_model = django_field.related_model
                child_endpoint_class = router.get_model_endpoint(child_model)
                child_endpoint_class = child_endpoint_class[1] if child_endpoint_class else BaseAPIViewSet
                child_serializer_classes[field_name] = child_endpoint_class._get_serializer_class(router, child_model, child_sub_fields, nested=True)

            elif field_name in sub_fields:
                # Sub fields were given for a non-related field
                raise BadRequestError("'%s' does not support nested fields" % field_name)

        # Reorder fields so it matches the order of all_fields
        fields = [field for field in all_fields if field in fields]

        field_serializer_overrides = {
            name: serializer
            for name, serializer in cls.get_field_serializer_overrides(model).items()
            if name in fields
        }
        return get_serializer_class(
            model,
            fields,
            meta_fields=meta_fields,
            field_serializer_overrides=field_serializer_overrides,
            child_serializer_classes=child_serializer_classes,
            base=cls.base_serializer_class
        )

    def get_serializer_class(self):
        """
        Build the serializer class for the current request.

        Uses the queryset's model for listings and the fetched object's
        concrete type otherwise, and applies the ``fields`` query parameter.
        """
        request = self.request

        # Get model
        if self.action == 'listing_view':
            model = self.get_queryset().model
        else:
            model = type(self.get_object())

        # Fields
        if 'fields' in request.GET:
            try:
                fields_config = parse_fields_parameter(request.GET['fields'])
            except ValueError as e:
                raise BadRequestError("fields error: %s" % str(e))
        else:
            # Use default fields
            fields_config = []

        # Allow "detail_only" (eg parent) fields on detail view
        show_details = self.action != 'listing_view'

        return self._get_serializer_class(self.request.wagtailapi_router, model, fields_config, show_details=show_details)

    def get_serializer_context(self):
        """
        Return the context passed to serializers: the request, this view and
        the API router. The same context is used for listing and detail views.
        """
        return {
            'request': self.request,
            'view': self,
            'router': self.request.wagtailapi_router
        }

    def get_renderer_context(self):
        """Extend DRF's renderer context so responses are indented by 4 spaces."""
        context = super().get_renderer_context()
        context['indent'] = 4
        return context

    @classmethod
    def get_urlpatterns(cls):
        """
        This returns a list of URL patterns for the endpoint
        """
        return [
            path('', cls.as_view({'get': 'listing_view'}), name='listing'),
            path('<int:pk>/', cls.as_view({'get': 'detail_view'}), name='detail'),
            path('find/', cls.as_view({'get': 'find_view'}), name='find'),
        ]

    @classmethod
    def get_model_listing_urlpath(cls, model, namespace=''):
        """Reverse the ``listing`` URL, optionally within ``namespace``."""
        if namespace:
            url_name = namespace + ':listing'
        else:
            url_name = 'listing'

        return reverse(url_name)

    @classmethod
    def get_object_detail_urlpath(cls, model, pk, namespace=''):
        """Reverse the ``detail`` URL for ``pk``, optionally within ``namespace``."""
        if namespace:
            url_name = namespace + ':detail'
        else:
            url_name = 'detail'

        return reverse(url_name, args=(pk, ))


class PagesAPIViewSet(BaseAPIViewSet):
    """API endpoint for live, public pages belonging to the current site."""

    base_serializer_class = PageSerializer
    filter_backends = [
        FieldsFilter,
        ChildOfFilter,
        AncestorOfFilter,
        DescendantOfFilter,
        OrderingFilter,
        TranslationOfFilter,
        LocaleFilter,
        SearchFilter,  # needs to be last, as SearchResults querysets cannot be filtered further
    ]
    known_query_parameters = BaseAPIViewSet.known_query_parameters.union([
        'type',
        'child_of',
        'ancestor_of',
        'descendant_of',
        'translation_of',
        'locale',
    ])
    body_fields = BaseAPIViewSet.body_fields + [
        'title',
    ]
    meta_fields = BaseAPIViewSet.meta_fields + [
        'html_url',
        'slug',
        'show_in_menus',
        'seo_title',
        'search_description',
        'first_published_at',
        'parent',
        'locale',
    ]
    listing_default_fields = BaseAPIViewSet.listing_default_fields + [
        'title',
        'html_url',
        'slug',
        'first_published_at',
    ]
    nested_default_fields = BaseAPIViewSet.nested_default_fields + [
        'title',
    ]
    detail_only_fields = ['parent']
    name = 'pages'
    model = Page

    @classmethod
    def get_detail_default_fields(cls, model):
        """Detail default fields, without ``locale`` unless WAGTAIL_I18N_ENABLED is set."""
        detail_default_fields = super().get_detail_default_fields(model)

        # When i18n is disabled, remove "locale" from default fields. Filtering
        # (rather than list.remove) drops every occurrence and tolerates
        # subclasses that don't expose "locale" at all.
        if not getattr(settings, 'WAGTAIL_I18N_ENABLED', False):
            detail_default_fields = [field for field in detail_default_fields if field != 'locale']

        return detail_default_fields

    @classmethod
    def get_listing_default_fields(cls, model):
        """Listing default fields, plus ``locale`` when WAGTAIL_I18N_ENABLED is set."""
        listing_default_fields = super().get_listing_default_fields(model)

        # When i18n is enabled, add "locale" to default fields
        if getattr(settings, 'WAGTAIL_I18N_ENABLED', False):
            listing_default_fields.append('locale')

        return listing_default_fields

    def get_root_page(self):
        """
        Returns the page that is used when the `&child_of=root` filter is used.
        """
        return Site.find_for_request(self.request).root_page

    def get_base_queryset(self):
        """
        Returns a queryset containing all pages that can be seen by this user.

        This is used as the base for get_queryset, and also to find the parent
        pages when using the child_of and descendant_of filters.
        """
        # Get live pages that are not in a private section
        queryset = Page.objects.all().public().live()

        # Filter by site
        site = Site.find_for_request(self.request)
        if site:
            base_queryset = queryset
            queryset = base_queryset.descendant_of(site.root_page, inclusive=True)

            # If internationalisation is enabled, include pages from other language trees
            if getattr(settings, 'WAGTAIL_I18N_ENABLED', False):
                for translation in site.root_page.get_translations():
                    queryset |= base_queryset.descendant_of(translation, inclusive=True)

        else:
            # No sites configured
            queryset = queryset.none()

        return queryset

    def get_queryset(self):
        """
        Return the pages for this request, narrowed by the ``type`` parameter.

        Raises BadRequestError when ``type`` names an unknown page model.
        """
        request = self.request

        # Allow pages to be filtered to a specific type
        try:
            models = page_models_from_string(request.GET.get('type', 'wagtailcore.Page'))
        except (LookupError, ValueError):
            raise BadRequestError("type doesn't exist")

        if not models:
            return self.get_base_queryset()

        elif len(models) == 1:
            # If a single page type has been specified, swap out the Page-based queryset for one based on
            # the specific page model so that we can filter on any custom APIFields defined on that model
            return models[0].objects.filter(id__in=self.get_base_queryset().values_list('id', flat=True))

        else:  # len(models) > 1
            return self.get_base_queryset().type(*models)

    def get_object(self):
        """Return the requested page as its specific (most derived) page instance."""
        base = super().get_object()
        return base.specific

    def find_object(self, queryset, request):
        """
        Find a page by ``html_path`` (routed from the current site's root page).

        Returns None if routing raises Http404. If the routed page isn't in
        ``queryset``, or ``html_path`` isn't given, or there is no site, falls
        back to the base ``id`` lookup.
        """
        site = Site.find_for_request(request)
        if 'html_path' in request.GET and site is not None:
            html_path = request.GET['html_path']
            path_components = [component for component in html_path.split('/') if component]

            try:
                page, _, _ = site.root_page.specific.route(request, path_components)
            except Http404:
                return None

            if queryset.filter(id=page.id).exists():
                return page

        return super().find_object(queryset, request)

    def get_serializer_context(self):
        """Extend the base serializer context with ``base_queryset`` (pages visible to this request)."""
        context = super().get_serializer_context()
        context['base_queryset'] = self.get_base_queryset()
        return context