1"""
2generators.py   # Top-down schema generation
3
4See schemas.__init__.py for package overview.
5"""
6import re
7from importlib import import_module
8
9from django.conf import settings
10from django.contrib.admindocs.views import simplify_regex
11from django.core.exceptions import PermissionDenied
12from django.http import Http404
13from django.urls import URLPattern, URLResolver
14
15from rest_framework import exceptions
16from rest_framework.request import clone_request
17from rest_framework.settings import api_settings
18from rest_framework.utils.model_meta import _get_pk
19
20
def get_pk_name(model):
    """
    Return the name of the primary-key field of ``model``.

    The lookup is done on the ``_meta`` of ``model._meta.concrete_model``
    rather than on ``model._meta`` directly (presumably so that proxy models
    report the pk of their concrete model — confirm against Django's
    ``Options.concrete_model``).
    """
    concrete_meta = model._meta.concrete_model._meta
    pk_field = _get_pk(concrete_meta)
    return pk_field.name
24
25
def is_api_view(callback):
    """
    Return `True` if the given view callback is a REST framework view/viewset.

    A callback qualifies when it exposes a ``cls`` attribute that is a
    subclass of ``APIView``; callbacks without ``cls`` (or with ``cls`` set
    to None) are rejected.
    """
    # Imported here rather than at module level to avoid an import cycle on APIView.
    from rest_framework.views import APIView

    view_class = getattr(callback, 'cls', None)
    if view_class is None:
        return False
    return issubclass(view_class, APIView)
34
35
def endpoint_ordering(endpoint):
    """
    Sort key for ``(path, method, callback)`` endpoint tuples.

    Endpoints are ranked by HTTP method only: GET, POST, PUT, PATCH, DELETE,
    and then every other method with a shared lowest priority. The key is a
    1-tuple so that it composes with other tuple keys.
    """
    _path, http_method, _callback = endpoint
    ranked = ('GET', 'POST', 'PUT', 'PATCH', 'DELETE')
    rank_of = {name: position for position, name in enumerate(ranked)}
    # Anything not listed sorts after all ranked methods.
    return (rank_of.get(http_method, len(ranked)),)
46
47
# One ``path()``-style URL component such as ``<int:pk>`` or ``<slug>``:
# the optional ``converter`` group holds the text before ':' and the
# ``parameter`` group holds the keyword-argument name.
_PATH_PARAMETER_COMPONENT_RE = re.compile(r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>')
51
52
class EndpointEnumerator:
    """
    A class to determine the available API endpoints that a project exposes.
    """
    def __init__(self, patterns=None, urlconf=None):
        """
        Remember the URL patterns to enumerate.

        ``patterns`` is used as-is when given. Otherwise the patterns are read
        from ``urlconf.urlpatterns``, where ``urlconf`` is a module or a dotted
        module path (imported here) and defaults to ``settings.ROOT_URLCONF``.
        """
        if patterns is None:
            if urlconf is None:
                # Use the default Django URL conf
                urlconf = settings.ROOT_URLCONF

            # Load the given URLconf module
            if isinstance(urlconf, str):
                urls = import_module(urlconf)
            else:
                urls = urlconf
            patterns = urls.urlpatterns

        self.patterns = patterns

    def get_api_endpoints(self, patterns=None, prefix=''):
        """
        Return a list of all available API endpoints by inspecting the URL conf.

        Each endpoint is a ``(path, method, callback)`` tuple. ``URLResolver``
        entries (included URL confs) are walked recursively, with their
        pattern text accumulated in ``prefix``. The result is sorted with
        ``endpoint_ordering``.
        """
        if patterns is None:
            patterns = self.patterns

        api_endpoints = []

        for pattern in patterns:
            path_regex = prefix + str(pattern.pattern)
            if isinstance(pattern, URLPattern):
                path = self.get_path_from_regex(path_regex)
                callback = pattern.callback
                if self.should_include_endpoint(path, callback):
                    for method in self.get_allowed_methods(callback):
                        endpoint = (path, method, callback)
                        api_endpoints.append(endpoint)

            elif isinstance(pattern, URLResolver):
                nested_endpoints = self.get_api_endpoints(
                    patterns=pattern.url_patterns,
                    prefix=path_regex
                )
                api_endpoints.extend(nested_endpoints)

        return sorted(api_endpoints, key=endpoint_ordering)

    def get_path_from_regex(self, path_regex):
        """
        Given a URL conf regex, return a URI template string.

        ``<converter:name>`` / ``<name>`` components left by
        ``simplify_regex`` are rewritten to ``{name}``.
        """
        # ???: Would it be feasible to adjust this such that we generate the
        # path, plus the kwargs, plus the type from the convertor, such that we
        # could feed that straight into the parameter schema object?

        path = simplify_regex(path_regex)

        # Strip Django 2.0 convertors as they are incompatible with uritemplate format
        return _PATH_PARAMETER_COMPONENT_RE.sub(r'{\g<parameter>}', path)

    def should_include_endpoint(self, path, callback):
        """
        Return `True` if the given endpoint should be included.

        Excluded are: non-REST-framework callbacks, views whose class-level
        ``schema`` is None, callbacks created with ``schema=None`` in their
        init kwargs, and format-suffix routes ending in ``.{format}``.
        """
        if not is_api_view(callback):
            return False  # Ignore anything except REST framework views.

        if callback.cls.schema is None:
            return False

        if 'schema' in callback.initkwargs and callback.initkwargs['schema'] is None:
            return False

        if path.endswith(('.{format}', '.{format}/')):
            return False  # Ignore .json style URLs.

        return True

    def get_allowed_methods(self, callback):
        """
        Return a list of the valid HTTP methods for this endpoint.

        For viewset callbacks (those carrying ``actions``) these are the
        action keys that also appear in the view's ``http_method_names``,
        upper-cased and in ``actions`` order. Otherwise the view class's
        ``allowed_methods`` are used. OPTIONS and HEAD are always dropped.
        """
        if hasattr(callback, 'actions'):
            http_method_names = set(callback.cls.http_method_names)
            # Iterate the actions mapping (not a set intersection) so the
            # result order does not depend on per-process string hashing.
            methods = [
                method.upper()
                for method in callback.actions
                if method in http_method_names
            ]
        else:
            methods = callback.cls().allowed_methods

        return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
144
145
class BaseSchemaGenerator:
    """
    Shared machinery for schema generators.

    Subclasses implement ``get_schema()``. This base class enumerates the
    API endpoints, builds a view instance per (path, method) pair, and
    provides helpers for ``{pk}`` path coercion and permission checks.
    """
    endpoint_inspector_cls = EndpointEnumerator

    # 'pk' isn't great as an externally exposed name for an identifier,
    # so by default we prefer to use the actual model field name for schemas.
    # Set by 'SCHEMA_COERCE_PATH_PK'.
    coerce_path_pk = None

    def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, version=None):
        """
        Store the schema metadata and the URL configuration to inspect.

        ``url``, when truthy, is normalised to end with '/'. ``patterns`` and
        ``urlconf`` are passed to ``endpoint_inspector_cls`` when endpoints
        are first enumerated.
        """
        if url and not url.endswith('/'):
            url += '/'

        self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK

        self.patterns = patterns
        self.urlconf = urlconf
        self.title = title
        self.description = description
        self.version = version
        self.url = url
        # Filled in lazily by _initialise_endpoints().
        self.endpoints = None

    def _initialise_endpoints(self):
        """Enumerate the endpoints once and cache them on ``self.endpoints``."""
        if self.endpoints is None:
            inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf)
            self.endpoints = inspector.get_api_endpoints()

    def _get_paths_and_endpoints(self, request):
        """
        Generate (path, method, view) given (path, method, callback) for paths.

        Returns ``(paths, view_endpoints)``: the list of (possibly
        pk-coerced) paths and a parallel list of ``(path, method, view)``
        tuples. Endpoints are enumerated first if that has not happened yet.
        """
        if self.endpoints is None:
            # Without this, a caller that skipped _initialise_endpoints()
            # would get an opaque "'NoneType' object is not iterable".
            self._initialise_endpoints()

        paths = []
        view_endpoints = []
        for path, method, callback in self.endpoints:
            view = self.create_view(callback, method, request)
            path = self.coerce_path(path, method, view)
            paths.append(path)
            view_endpoints.append((path, method, view))

        return paths, view_endpoints

    def create_view(self, callback, method, request=None):
        """
        Given a callback, return an actual view instance.

        The view is built from ``callback.cls`` with the callback's
        ``initkwargs`` (if any) and starts with empty ``args``/``kwargs``, no
        ``format_kwarg`` and no request. For callbacks carrying ``actions``,
        ``action_map`` is set to that mapping and ``action`` to the action for
        ``method`` (``'metadata'`` for OPTIONS). When ``request`` is given the
        view receives a clone of it for ``method``.
        """
        view = callback.cls(**getattr(callback, 'initkwargs', {}))
        view.args = ()
        view.kwargs = {}
        view.format_kwarg = None
        view.request = None

        actions = getattr(callback, 'actions', None)
        view.action_map = actions

        if actions is not None:
            if method == 'OPTIONS':
                view.action = 'metadata'
            else:
                view.action = actions.get(method.lower())

        if request is not None:
            view.request = clone_request(request, method)

        return view

    def coerce_path(self, path, method, view):
        """
        Coerce {pk} path arguments into the name of the model field,
        where possible. This is cleaner for an external representation.
        (Ie. "this is an identifier", not "this is a database primary key")

        Falls back to ``{id}`` when the view has no ``queryset.model``.
        Paths are returned unchanged when coercion is disabled or the path
        has no ``{pk}`` component.
        """
        if not self.coerce_path_pk or '{pk}' not in path:
            return path
        model = getattr(getattr(view, 'queryset', None), 'model', None)
        field_name = get_pk_name(model) if model else 'id'
        return path.replace('{pk}', '{%s}' % field_name)

    def get_schema(self, request=None, public=False):
        """Build and return the schema; must be provided by subclasses."""
        raise NotImplementedError(".get_schema() must be implemented in subclasses.")

    def has_view_permissions(self, path, method, view):
        """
        Return `True` if the incoming request has the correct view permissions.

        Views without an attached request are always permitted. Otherwise
        ``view.check_permissions`` is run and any APIException, Http404 or
        PermissionDenied it raises is reported as `False`.
        """
        if view.request is None:
            return True

        try:
            view.check_permissions(view.request)
        except (exceptions.APIException, Http404, PermissionDenied):
            return False
        return True
240