"""
generators.py   # Top-down schema generation

See schemas.__init__.py for package overview.
"""
import re
from importlib import import_module

from django.conf import settings
from django.contrib.admindocs.views import simplify_regex
from django.core.exceptions import PermissionDenied
from django.http import Http404
from django.urls import URLPattern, URLResolver

from rest_framework import exceptions
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
from rest_framework.utils.model_meta import _get_pk


def get_pk_name(model):
    """
    Return the name of the primary-key field of `model`.

    The field is looked up on `model._meta.concrete_model._meta` rather than on
    `model._meta` directly (presumably so proxy models resolve to the pk of
    their concrete model — confirm against Django's `concrete_model` docs).
    """
    meta = model._meta.concrete_model._meta
    return _get_pk(meta).name


def is_api_view(callback):
    """
    Return `True` if the given view callback is a REST framework view/viewset.

    The decision is based solely on the callback's `cls` attribute: it must be
    present, be a class, and subclass `APIView`.
    """
    # Avoid import cycle on APIView
    from rest_framework.views import APIView
    cls = getattr(callback, 'cls', None)
    # `issubclass` raises TypeError for non-class arguments, so make sure
    # `cls` really is a class before asking.
    return isinstance(cls, type) and issubclass(cls, APIView)


# Sort priority used by `endpoint_ordering`; any method not listed here sorts
# after all of these (priority 5).
_METHOD_PRIORITY = {
    'GET': 0,
    'POST': 1,
    'PUT': 2,
    'PATCH': 3,
    'DELETE': 4,
}


def endpoint_ordering(endpoint):
    """
    Sort key for `(path, method, callback)` endpoint tuples.

    Only the HTTP method contributes to the key. Because `sorted` is stable,
    endpoints sharing a method keep their original (URL conf) order.
    """
    _, method, _ = endpoint
    return (_METHOD_PRIORITY.get(method, 5),)


# Matches a Django 2.0+ path parameter such as `<int:pk>` or `<slug>`,
# capturing the optional converter name and the parameter name.
_PATH_PARAMETER_COMPONENT_RE = re.compile(
    r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
)


class EndpointEnumerator:
    """
    A class to determine the available API endpoints that a project exposes.
    """
    def __init__(self, patterns=None, urlconf=None):
        """
        :param patterns: URL patterns to inspect. When omitted they are read
            from the `urlpatterns` attribute of `urlconf`.
        :param urlconf: A module, or the dotted import path of one, exposing
            `urlpatterns`. Defaults to `settings.ROOT_URLCONF`. Only consulted
            when `patterns` is `None`.
        """
        if patterns is None:
            if urlconf is None:
                # Use the default Django URL conf
                urlconf = settings.ROOT_URLCONF

            # Load the given URLconf module
            urls = import_module(urlconf) if isinstance(urlconf, str) else urlconf
            patterns = urls.urlpatterns

        self.patterns = patterns

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        Each endpoint is a `(path, method, callback)` tuple. `URLResolver`
        entries (e.g. `include()`) are walked recursively, with their pattern
        string accumulated into `prefix`. The result is sorted with
        `endpoint_ordering`.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + str(pattern.pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    api_endpoints.extend(
                        (path, method, callback)
                        for method in self.get_allowed_methods(callback)
                    )

            elif isinstance(pattern, URLResolver):
                nested_endpoints = self.get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                )
                api_endpoints.extend(nested_endpoints)

        return sorted(api_endpoints, key=endpoint_ordering)

    def get_path_from_regex(self, path_regex):
        """
        Given a URL conf regex, return a URI template string.

        `<converter:name>` / `<name>` components are rewritten to `{name}`.
        """
        # NOTE: Would it be feasible to adjust this such that we generate the
        # path, plus the kwargs, plus the type from the convertor, such that we
        # could feed that straight into the parameter schema object?

        path = simplify_regex(path_regex)

        # Strip Django 2.0 convertors as they are incompatible with uritemplate format
        return _PATH_PARAMETER_COMPONENT_RE.sub(r'{\g<parameter>}', path)

    def should_include_endpoint(self, path, callback):
        """
        Return `True` if the given endpoint should be included.

        Excluded are: non-REST-framework callbacks, views whose class-level or
        `as_view()`-supplied `schema` is `None`, and `.{format}` suffix routes.
        """
        if not is_api_view(callback):
            return False  # Ignore anything except REST framework views.

        if callback.cls.schema is None:
            return False

        initkwargs = callback.initkwargs
        if 'schema' in initkwargs and initkwargs['schema'] is None:
            return False

        if path.endswith(('.{format}', '.{format}/')):
            return False  # Ignore .json style URLs.

        return True

    def get_allowed_methods(self, callback):
        """
        Return a list of the valid HTTP methods for this endpoint.

        Method names are upper-cased; OPTIONS and HEAD are always omitted.
        """
        if hasattr(callback, 'actions'):
            # Viewset: a method is valid when it is mapped to an action AND the
            # view class permits it. Walk `http_method_names` in declaration
            # order (deduplicated) instead of a set intersection, whose
            # iteration order depends on string hash randomization and so
            # varies between processes.
            action_names = set(callback.actions)
            methods = [
                method.upper()
                for method in dict.fromkeys(callback.cls.http_method_names)
                if method in action_names
            ]
        else:
            methods = callback.cls().allowed_methods

        return [method for method in methods if method not in ('OPTIONS', 'HEAD')]


class BaseSchemaGenerator:
    """
    Shared machinery for schema generators: endpoint discovery, view
    instantiation, path coercion and permission checks. Subclasses implement
    `get_schema`.
    """
    endpoint_inspector_cls = EndpointEnumerator

    # 'pk' isn't great as an externally exposed name for an identifier,
    # so by default we prefer to use the actual model field name for schemas.
    # Set by 'SCHEMA_COERCE_PATH_PK'.
    coerce_path_pk = None

    def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
        """
        :param url: Base URL for the schema; a trailing slash is appended if
            missing.
        :param patterns: / :param urlconf: Passed through to
            `endpoint_inspector_cls` when endpoints are first discovered.
        """
        if url and not url.endswith('/'):
            url += '/'

        self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK

        self.patterns = patterns
        self.urlconf = urlconf
        self.title = title
        self.description = description
        self.version = version
        self.url = url
        # Populated lazily by `_initialise_endpoints`.
        self.endpoints = None

    def _initialise_endpoints(self):
        """Discover endpoints once; later calls are no-ops."""
        if self.endpoints is None:
            inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
            self.endpoints = inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self, request):
        """
        Generate (path, method, view) given (path, method, callback) for paths.

        Returns a `(paths, view_endpoints)` pair of parallel lists, where each
        path has already been passed through `coerce_path`.
        """
        # Make sure endpoints exist even if the caller skipped initialisation;
        # otherwise iterating `None` below would raise TypeError.
        self._initialise_endpoints()

        paths = []
        view_endpoints = []
        for path, method, callback in self.endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, method, view)
            paths.append(path)
            view_endpoints.append((path, method, view))

        return paths, view_endpoints

    def create_view(self, callback, method, request=None):
        """
        Given a callback, return an actual view instance.

        The view is built from `callback.cls` and `callback.initkwargs`, given
        empty args/kwargs, and—for viewsets—an `action` derived from `method`.
        When `request` is given, it is cloned with `method` and attached.
        """
        view = callback.cls(**getattr(callback, 'initkwargs', {}))
        view.args = ()
        view.kwargs = {}
        view.format_kwarg = None
        view.request = None

        actions = getattr(callback, 'actions', None)
        view.action_map = actions

        if actions is not None:
            if method == 'OPTIONS':
                view.action = 'metadata'
            else:
                view.action = actions.get(method.lower())

        if request is not None:
            view.request = clone_request(request, method)

        return view

    def coerce_path(self, path, method, view):
        """
        Coerce {pk} path arguments into the name of the model field,
        where possible. This is cleaner for an external representation.
        (Ie. "this is an identifier", not "this is a database primary key")

        Falls back to `{id}` when the view has no queryset model.
        """
        if not self.coerce_path_pk or '{pk}' not in path:
            return path
        model = getattr(getattr(view, 'queryset', None), 'model', None)
        field_name = get_pk_name(model) if model else 'id'
        return path.replace('{pk}', '{%s}' % field_name)

    def get_schema(self, request=None, public=False):
        """Build and return the schema; must be implemented by subclasses."""
        raise NotImplementedError(".get_schema() must be implemented in subclasses.")

    def has_view_permissions(self, path, method, view):
        """
        Return `True` if the incoming request has the correct view permissions.

        Always `True` when the view has no attached request. Otherwise any
        `APIException`, `Http404` or `PermissionDenied` raised by
        `view.check_permissions` means `False`.
        """
        if view.request is None:
            return True

        try:
            view.check_permissions(view.request)
        except (exceptions.APIException, Http404, PermissionDenied):
            return False
        return True