1import logging
2from collections import OrderedDict
3
4from rest_framework.request import is_form_media_type
5from rest_framework.schemas import AutoSchema
6from rest_framework.status import is_success
7
8from .. import openapi
9from ..errors import SwaggerGenerationError
10from ..utils import (
11    filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
12    merge_params, no_body, param_list_to_odict
13)
14from .base import ViewInspector, call_view_method
15
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
17
18
class SwaggerAutoSchema(ViewInspector):
    """Inspector that generates the Swagger :class:`.Operation` for a single view and HTTP method.

    Every generation step is an overridable ``get_*`` method. Most of them first consult
    ``self.overrides`` (per-operation overrides, see :func:`@swagger_auto_schema <.swagger_auto_schema>`)
    and only fall back to introspecting the view when no override is present.
    """

    def __init__(self, view, path, method, components, request, overrides, operation_keys=None):
        """
        :param tuple[str] operation_keys: keys derived from the path; used as the default value for the
            ``operation_keys`` argument of :meth:`.get_operation`, :meth:`.get_operation_id` and :meth:`.get_tags`
        """
        super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
        # DRF's AutoSchema is only used to extract the view description (see get_summary_and_description)
        self._sch = AutoSchema()
        self._sch.view = view
        self.operation_keys = operation_keys

    def get_operation(self, operation_keys=None):
        """Build the complete :class:`.Operation` object for this view method.

        :param tuple[str] operation_keys: path-derived keys; falls back to the keys given to the constructor
        :return: the generated operation
        :rtype: openapi.Operation
        """
        operation_keys = operation_keys or self.operation_keys

        consumes = self.get_consumes()
        produces = self.get_produces()

        body = self.get_request_body_parameters(consumes)
        query = self.get_query_parameters()
        parameters = filter_none(body + query)
        # manual parameters are applied last so that they can replace generated ones
        parameters = self.add_manual_parameters(parameters)

        operation_id = self.get_operation_id(operation_keys)
        summary, description = self.get_summary_and_description()
        security = self.get_security()
        assert security is None or isinstance(security, list), "security must be a list of security requirement objects"
        deprecated = self.is_deprecated()
        tags = self.get_tags(operation_keys)

        responses = self.get_responses()

        return openapi.Operation(
            operation_id=operation_id,
            description=force_real_str(description),
            summary=force_real_str(summary),
            responses=responses,
            parameters=parameters,
            consumes=consumes,
            produces=produces,
            tags=tags,
            security=security,
            deprecated=deprecated
        )

    def get_request_body_parameters(self, consumes):
        """Return the request body parameters for this view. |br|
        This is either:

        -  a list with a single object Parameter with a :class:`.Schema` derived from the request serializer
        -  a list of primitive Parameters parsed as form data

        :param list[str] consumes: a list of accepted MIME types as returned by :meth:`.get_consumes`
        :return: a (potentially empty) list of :class:`.Parameter`\\ s either ``in: body`` or ``in: formData``
        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if a form-consuming endpoint is given a Schema as request body
        """
        serializer = self.get_request_serializer()
        if serializer is None:
            return []

        # the request body may already be given as a Schema/SchemaRef instead of a serializer
        schema = serializer if isinstance(serializer, openapi.Schema.OR_REF) else None

        if any(is_form_media_type(encoding) for encoding in consumes):
            # form data is described field by field, which requires an actual serializer
            if schema is not None:
                raise SwaggerGenerationError("form request body cannot be a Schema")
            return self.get_request_form_parameters(serializer)

        if schema is None:
            schema = self.get_request_body_schema(serializer)
        return [self.make_body_parameter(schema)] if schema is not None else []

    def get_view_serializer(self):
        """Return the serializer as defined by the view's ``get_serializer()`` method.

        :return: the view's ``Serializer``
        :rtype: rest_framework.serializers.Serializer
        """
        return call_view_method(self.view, 'get_serializer')

    def _get_request_body_override(self):
        """Parse the request_body key in the override dict. This method is not public API.

        :return: ``None`` if there is no override, ``no_body`` if the body was explicitly disabled,
            the override itself if it is a :class:`.Schema`/:class:`.SchemaRef`, otherwise a serializer instance
        :raises SwaggerGenerationError: if a request body is given for a method not in ``self.body_methods``
        """
        body_override = self.overrides.get('request_body', None)

        if body_override is not None:
            if body_override is no_body:
                return no_body
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("request_body can only be applied to (" + ','.join(self.body_methods) +
                                             "); are you looking for query_serializer or manual_parameters?")
            if isinstance(body_override, openapi.Schema.OR_REF):
                return body_override
            return force_serializer_instance(body_override)

        return body_override

    def get_request_serializer(self):
        """Return the request serializer (used for parsing the request payload) for this endpoint.

        Without an override, only methods in ``self.implicit_body_methods`` fall back to the view's serializer.

        :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
        :rtype: rest_framework.serializers.Serializer
        """
        body_override = self._get_request_body_override()

        if body_override is None and self.method in self.implicit_body_methods:
            return self.get_view_serializer()

        if body_override is no_body:
            return None

        return body_override

    def get_request_form_parameters(self, serializer):
        """Given a Serializer, return a list of ``in: formData`` :class:`.Parameter`\\ s.

        :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
        :rtype: list[openapi.Parameter]
        """
        return self.serializer_to_parameters(serializer, in_=openapi.IN_FORM)

    def get_request_body_schema(self, serializer):
        """Return the :class:`.Schema` for a given request's body data. Only applies to PUT, PATCH and POST requests.

        :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
        :rtype: openapi.Schema
        """
        return self.serializer_to_schema(serializer)

    def make_body_parameter(self, schema):
        """Given a :class:`.Schema` object, create a required ``in: body`` :class:`.Parameter` named ``data``.

        :param openapi.Schema schema: the request body schema
        :rtype: openapi.Parameter
        """
        return openapi.Parameter(name='data', in_=openapi.IN_BODY, required=True, schema=schema)

    def add_manual_parameters(self, parameters):
        """Add/replace parameters from the given list of automatically generated request parameters.

        :param list[openapi.Parameter] parameters: generated parameters
        :return: modified parameters
        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if a manual parameter is ``in: body``, or if ``in: formData`` parameters
            are added to an endpoint that has a body parameter, does not consume form data, or uses a method
            not in ``self.body_methods``
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []

        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
            if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)

    def get_responses(self):
        """Get the possible responses for this view as a swagger :class:`.Responses` object.

        :return: the documented responses
        :rtype: openapi.Responses
        """
        response_serializers = self.get_response_serializers()
        return openapi.Responses(
            responses=self.get_response_schemas(response_serializers)
        )

    def get_default_response_serializer(self):
        """Return the default response serializer for this endpoint. This is derived from either the ``request_body``
        override or the request serializer (:meth:`.get_view_serializer`).

        :return: response serializer, :class:`.Schema`, :class:`.SchemaRef`, ``None``
        """
        body_override = self._get_request_body_override()
        if body_override and body_override is not no_body:
            return body_override

        return self.get_view_serializer()

    def get_default_responses(self):
        """Get the default responses determined for this view from the request serializer and request method.

        Only GET, POST, PUT and PATCH get a default response schema; for any other method the guessed status
        code maps to an empty string, which :meth:`.get_response_schemas` turns into a response with an empty
        description and no schema.

        :return: a single-item dict mapping the guessed status code to its response schema (or ``''``)
        :rtype: dict[str, openapi.Schema or openapi.SchemaRef or str]
        """
        method = self.method.lower()

        default_status = guess_response_status(method)
        default_schema = ''
        if method in ('get', 'post', 'put', 'patch'):
            default_schema = self.get_default_response_serializer() or ''

        # A Schema or SchemaRef (e.g. coming from a request_body override) is already a schema object and
        # must be used as-is; only actual serializers need to be converted.
        if default_schema and not isinstance(default_schema, openapi.Schema.OR_REF):
            default_schema = self.serializer_to_schema(default_schema) or ''

        if default_schema:
            if self.has_list_response():
                default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
            if self.should_page():
                default_schema = self.get_paginated_response(default_schema) or default_schema

        return OrderedDict({str(default_status): default_schema})

    def get_response_serializers(self):
        """Return the response codes that this view is expected to return, and the serializer for each response body.
        The return value should be a dict where the keys are possible status codes, and values are either strings,
        ``Serializer``\\ s, :class:`.Schema`, :class:`.SchemaRef` or :class:`.Response` objects. See
        :func:`@swagger_auto_schema <.swagger_auto_schema>` for more details.

        The default responses are only included when the manual ``responses`` override contains no
        success (2xx) status code.

        :return: the response serializers
        :rtype: dict
        """
        manual_responses = self.overrides.get('responses', None) or {}
        manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())

        responses = OrderedDict()
        if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
            responses = self.get_default_responses()

        # keys of manual_responses are already normalized to str above
        responses.update(manual_responses)
        return responses

    def get_response_schemas(self, response_serializers):
        """Return the :class:`.openapi.Response` objects calculated for this view.

        Falsy (non-string) entries are skipped; strings become schema-less responses used as the description.

        :param dict response_serializers: response serializers as returned by :meth:`.get_response_serializers`
        :return: a dictionary of status code to :class:`.Response` object
        :rtype: dict[str, openapi.Response]
        """
        responses = OrderedDict()
        for sc, serializer in response_serializers.items():
            if isinstance(serializer, str):
                response = openapi.Response(
                    description=force_real_str(serializer)
                )
            elif not serializer:
                continue
            elif isinstance(serializer, openapi.Response):
                response = serializer
                # NOTE(review): the caller-supplied Response is modified in place, so a Response object shared
                # between several views will already carry the converted schema on later uses
                if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
                    serializer = force_serializer_instance(response.schema)
                    response.schema = self.serializer_to_schema(serializer)
            elif isinstance(serializer, openapi.Schema.OR_REF):
                response = openapi.Response(
                    description='',
                    schema=serializer,
                )
            else:
                serializer = force_serializer_instance(serializer)
                response = openapi.Response(
                    description='',
                    schema=self.serializer_to_schema(serializer),
                )

            responses[str(sc)] = response

        return responses

    def get_query_serializer(self):
        """Return the query serializer (used for parsing query parameters) for this endpoint.

        :return: the query serializer, or ``None``
        """
        query_serializer = self.overrides.get('query_serializer', None)
        if query_serializer is not None:
            query_serializer = force_serializer_instance(query_serializer)
        return query_serializer

    def get_query_parameters(self):
        """Return the query parameters accepted by this view.

        These are the filter and pagination parameters followed by the fields of the query serializer, if any.

        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if the query serializer's parameters collide with filter/pagination ones
        """
        natural_parameters = self.get_filter_parameters() + self.get_pagination_parameters()

        query_serializer = self.get_query_serializer()
        serializer_parameters = []
        if query_serializer is not None:
            serializer_parameters = self.serializer_to_parameters(query_serializer, in_=openapi.IN_QUERY)

            conflicts = set(param_list_to_odict(natural_parameters)) & set(param_list_to_odict(serializer_parameters))
            if conflicts:
                raise SwaggerGenerationError(
                    "your query_serializer contains fields that conflict with the "
                    "filter_backend or paginator_class on the view - %s %s" % (self.method, self.path)
                )

        return natural_parameters + serializer_parameters

    def get_operation_id(self, operation_keys=None):
        """Return an unique ID for this operation. The ID must be unique across
        all :class:`.Operation` objects in the API.

        Defaults to the operation keys joined with ``_`` unless an ``operation_id`` override is given.

        :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
            of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
        :rtype: str
        """
        operation_keys = operation_keys or self.operation_keys

        operation_id = self.overrides.get('operation_id', '')
        if not operation_id:
            operation_id = '_'.join(operation_keys)
        return operation_id

    def split_summary_from_description(self, description):
        """Decide if and how to split a summary out of the given description. The default implementation
        uses the first paragraph of the description as a summary if it is less than 120 characters long.

        :param description: the full description to be analyzed
        :return: summary and description; the summary is ``None`` and the description is returned unchanged
            when there is a single paragraph or the first one is too long
        :rtype: (str,str)
        """
        # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
        summary = None
        summary_max_len = 120  # OpenAPI 2.0 spec says summary should be under 120 characters
        sections = description.split('\n\n', 1)
        if len(sections) == 2:
            sections[0] = sections[0].strip()
            if len(sections[0]) < summary_max_len:
                summary, description = sections
                description = description.strip()

        return summary, description

    def get_summary_and_description(self):
        """Return an operation summary and description determined from the view's docstring.

        The ``operation_description`` and ``operation_summary`` overrides take precedence; a summary is only
        split out of the docstring when neither is given.

        :return: summary and description
        :rtype: (str,str)
        """
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)
        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and (summary is None):
                # description from docstring... do summary magic
                summary, description = self.split_summary_from_description(description)

        return summary, description

    def get_security(self):
        """Return a list of security requirements for this operation.

        Returning an empty list marks the endpoint as unauthenticated (i.e. removes all accepted
        authentication schemes). Returning ``None`` will inherit the top-level security requirements.

        :return: security requirements
        :rtype: list[dict[str,list[str]]]"""
        return self.overrides.get('security', None)

    def is_deprecated(self):
        """Return ``True`` if this operation is to be marked as deprecated.

        :return: deprecation status, or ``None`` when no ``deprecated`` override is given
        :rtype: bool
        """
        return self.overrides.get('deprecated', None)

    def get_tags(self, operation_keys=None):
        """Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI
        each tag will show as a group containing the operations that use it. If not provided in overrides,
        tags will be inferred from the operation url.

        :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
            of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
        :rtype: list[str]
        """
        operation_keys = operation_keys or self.operation_keys

        tags = self.overrides.get('tags')
        if not tags:
            # default: a single tag made of the first path-derived key
            tags = [operation_keys[0]]

        return tags

    def get_consumes(self):
        """Return the MIME types this endpoint can consume, derived from the view's parser classes.

        :rtype: list[str]
        """
        return get_consumes(self.get_parser_classes())

    def get_produces(self):
        """Return the MIME types this endpoint can produce, derived from the view's renderer classes.

        :rtype: list[str]
        """
        return get_produces(self.get_renderer_classes())
407