import logging
from collections import OrderedDict

from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema
from rest_framework.status import is_success

from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import (
    filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
    merge_params, no_body, param_list_to_odict
)
from .base import ViewInspector, call_view_method

logger = logging.getLogger(__name__)


class SwaggerAutoSchema(ViewInspector):
    """View inspector that assembles an :class:`.Operation` for a single view/method pair.

    Every piece of the operation (parameters, responses, tags, ...) is produced by a dedicated ``get_*`` method
    so that subclasses can override one aspect at a time. Values given in ``overrides`` take precedence over
    values derived from the view.
    """

    def __init__(self, view, path, method, components, request, overrides, operation_keys=None):
        """Create the inspector.

        ``view``, ``path``, ``method``, ``components``, ``request`` and ``overrides`` are forwarded unchanged
        to :class:`.ViewInspector`.

        :param operation_keys: default operation keys, used by :meth:`.get_operation`,
            :meth:`.get_operation_id` and :meth:`.get_tags` when they are called without explicit keys
        """
        super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
        # A DRF AutoSchema bound to the view; only used to extract the view description
        # in get_summary_and_description().
        self._sch = AutoSchema()
        self._sch.view = view
        self.operation_keys = operation_keys

    def get_operation(self, operation_keys=None):
        """Build the :class:`.Operation` object describing this view and method.

        :param operation_keys: keys describing the position of this view in the API; falls back to the
            ``operation_keys`` given to the constructor when empty
        :return: the assembled operation
        :rtype: openapi.Operation
        """
        operation_keys = operation_keys or self.operation_keys

        consumes = self.get_consumes()
        produces = self.get_produces()

        # body/form parameters come first, followed by query parameters; manual overrides are merged last
        body = self.get_request_body_parameters(consumes)
        query = self.get_query_parameters()
        parameters = body + query
        parameters = filter_none(parameters)
        parameters = self.add_manual_parameters(parameters)

        operation_id = self.get_operation_id(operation_keys)
        summary, description = self.get_summary_and_description()
        security = self.get_security()
        assert security is None or isinstance(security, list), "security must be a list of security requirement objects"
        deprecated = self.is_deprecated()
        tags = self.get_tags(operation_keys)

        responses = self.get_responses()

        return openapi.Operation(
            operation_id=operation_id,
            description=force_real_str(description),
            summary=force_real_str(summary),
            responses=responses,
            parameters=parameters,
            consumes=consumes,
            produces=produces,
            tags=tags,
            security=security,
            deprecated=deprecated
        )

    def get_request_body_parameters(self, consumes):
        """Return the request body parameters for this view. |br|
        This is either:

        - a list with a single object Parameter with a :class:`.Schema` derived from the request serializer
        - a list of primitive Parameters parsed as form data

        :param list[str] consumes: a list of accepted MIME types as returned by :meth:`.get_consumes`
        :return: a (potentially empty) list of :class:`.Parameter`\\ s either ``in: body`` or ``in: formData``
        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if the endpoint consumes form data but the request body was given
            as a :class:`.Schema` or :class:`.SchemaRef`
        """
        serializer = self.get_request_serializer()
        schema = None
        if serializer is None:
            return []

        # the request body may have been overridden directly with a Schema/SchemaRef
        if isinstance(serializer, openapi.Schema.OR_REF):
            schema = serializer

        if any(is_form_media_type(encoding) for encoding in consumes):
            if schema is not None:
                raise SwaggerGenerationError("form request body cannot be a Schema")
            return self.get_request_form_parameters(serializer)
        else:
            if schema is None:
                schema = self.get_request_body_schema(serializer)
            return [self.make_body_parameter(schema)] if schema is not None else []

    def get_view_serializer(self):
        """Return the serializer as defined by the view's ``get_serializer()`` method.

        :return: the view's ``Serializer``
        :rtype: rest_framework.serializers.Serializer
        """
        return call_view_method(self.view, 'get_serializer')

    def _get_request_body_override(self):
        """Parse the request_body key in the override dict. This method is not public API.

        :return: ``None`` if there is no override, the ``no_body`` marker, a :class:`.Schema` /
            :class:`.SchemaRef`, or the result of ``force_serializer_instance`` on the override
        :raises SwaggerGenerationError: if an override other than ``no_body`` is given for a method that is
            not in ``self.body_methods``
        """
        body_override = self.overrides.get('request_body', None)

        if body_override is not None:
            if body_override is no_body:
                return no_body
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("request_body can only be applied to (" + ','.join(self.body_methods) +
                                             "); are you looking for query_serializer or manual_parameters?")
            if isinstance(body_override, openapi.Schema.OR_REF):
                return body_override
            return force_serializer_instance(body_override)

        return body_override

    def get_request_serializer(self):
        """Return the request serializer (used for parsing the request payload) for this endpoint.

        :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
        :rtype: rest_framework.serializers.Serializer
        """
        body_override = self._get_request_body_override()

        # without an explicit override, only methods in implicit_body_methods fall back to the view serializer
        if body_override is None and self.method in self.implicit_body_methods:
            return self.get_view_serializer()

        if body_override is no_body:
            return None

        return body_override

    def get_request_form_parameters(self, serializer):
        """Given a Serializer, return a list of ``in: formData`` :class:`.Parameter`\\ s.

        :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
        :rtype: list[openapi.Parameter]
        """
        return self.serializer_to_parameters(serializer, in_=openapi.IN_FORM)

    def get_request_body_schema(self, serializer):
        """Return the :class:`.Schema` for a given request's body data. Only applies to PUT, PATCH and POST requests.

        :param serializer: the view's request serializer as returned by :meth:`.get_request_serializer`
        :rtype: openapi.Schema
        """
        return self.serializer_to_schema(serializer)

    def make_body_parameter(self, schema):
        """Given a :class:`.Schema` object, create an ``in: body`` :class:`.Parameter`.

        :param openapi.Schema schema: the request body schema
        :rtype: openapi.Parameter
        """
        return openapi.Parameter(name='data', in_=openapi.IN_BODY, required=True, schema=schema)

    def add_manual_parameters(self, parameters):
        """Add/replace parameters from the given list of automatically generated request parameters.

        :param list[openapi.Parameter] parameters: generated parameters
        :return: modified parameters
        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if the ``manual_parameters`` override contains an ``in: body``
            parameter, or ``in: formData`` parameters that cannot be combined with this request
        """
        manual_parameters = self.overrides.get('manual_parameters', None) or []

        if any(param.in_ == openapi.IN_BODY for param in manual_parameters):  # pragma: no cover
            raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
        if any(param.in_ == openapi.IN_FORM for param in manual_parameters):  # pragma: no cover
            has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
            if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
                raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
                                             "did you forget to set an appropriate parser class on the view?")
            if self.method not in self.body_methods:
                raise SwaggerGenerationError("form parameters can only be applied to "
                                             "(" + ','.join(self.body_methods) + ") HTTP methods")

        return merge_params(parameters, manual_parameters)

    def get_responses(self):
        """Get the possible responses for this view as a swagger :class:`.Responses` object.

        :return: the documented responses
        :rtype: openapi.Responses
        """
        response_serializers = self.get_response_serializers()
        return openapi.Responses(
            responses=self.get_response_schemas(response_serializers)
        )

    def get_default_response_serializer(self):
        """Return the default response serializer for this endpoint. This is derived from either the ``request_body``
        override or the request serializer (:meth:`.get_view_serializer`).

        :return: response serializer, :class:`.Schema`, :class:`.SchemaRef`, ``None``
        """
        body_override = self._get_request_body_override()
        if body_override and body_override is not no_body:
            return body_override

        return self.get_view_serializer()

    def get_default_responses(self):
        """Get the default responses determined for this view from the request serializer and request method.

        :type: dict[str, openapi.Schema]
        """
        method = self.method.lower()

        default_status = guess_response_status(method)
        default_schema = ''
        if method in ('get', 'post', 'put', 'patch'):
            default_schema = self.get_default_response_serializer()

        default_schema = default_schema or ''
        # The default serializer may already be a Schema *or a SchemaRef* (e.g. from a request_body override);
        # only real serializers need to be converted.
        if default_schema and not isinstance(default_schema, openapi.Schema.OR_REF):
            default_schema = self.serializer_to_schema(default_schema) or ''

        if default_schema:
            if self.has_list_response():
                default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
            if self.should_page():
                default_schema = self.get_paginated_response(default_schema) or default_schema

        return OrderedDict({str(default_status): default_schema})

    def get_response_serializers(self):
        """Return the response codes that this view is expected to return, and the serializer for each response body.
        The return value should be a dict where the keys are possible status codes, and values are either strings,
        ``Serializer``\\ s, :class:`.Schema`, :class:`.SchemaRef` or :class:`.Response` objects. See
        :func:`@swagger_auto_schema <.swagger_auto_schema>` for more details.

        :return: the response serializers
        :rtype: dict
        """
        manual_responses = self.overrides.get('responses', None) or {}
        manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())

        # the default response is only generated when no success status was specified manually
        responses = OrderedDict()
        if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
            responses = self.get_default_responses()

        # keys were already normalized to str above
        responses.update(manual_responses)
        return responses

    def get_response_schemas(self, response_serializers):
        """Return the :class:`.openapi.Response` objects calculated for this view.

        :param dict response_serializers: response serializers as returned by :meth:`.get_response_serializers`
        :return: a dictionary of status code to :class:`.Response` object
        :rtype: dict[str, openapi.Response]
        """
        responses = OrderedDict()
        for sc, serializer in response_serializers.items():
            if isinstance(serializer, str):
                # a plain string is used as the response description; this includes the empty string
                response = openapi.Response(
                    description=force_real_str(serializer)
                )
            elif not serializer:
                continue
            elif isinstance(serializer, openapi.Response):
                response = serializer
                # NOTE: the given Response object is updated in place when its schema is not yet a Schema/SchemaRef
                if hasattr(response, 'schema') and not isinstance(response.schema, openapi.Schema.OR_REF):
                    serializer = force_serializer_instance(response.schema)
                    response.schema = self.serializer_to_schema(serializer)
            elif isinstance(serializer, openapi.Schema.OR_REF):
                response = openapi.Response(
                    description='',
                    schema=serializer,
                )
            else:
                serializer = force_serializer_instance(serializer)
                response = openapi.Response(
                    description='',
                    schema=self.serializer_to_schema(serializer),
                )

            responses[str(sc)] = response

        return responses

    def get_query_serializer(self):
        """Return the query serializer (used for parsing query parameters) for this endpoint.

        :return: the query serializer, or ``None``
        """
        query_serializer = self.overrides.get('query_serializer', None)
        if query_serializer is not None:
            query_serializer = force_serializer_instance(query_serializer)
        return query_serializer

    def get_query_parameters(self):
        """Return the query parameters accepted by this view.

        :rtype: list[openapi.Parameter]
        :raises SwaggerGenerationError: if a ``query_serializer`` field clashes with a filter or pagination
            parameter
        """
        natural_parameters = self.get_filter_parameters() + self.get_pagination_parameters()

        query_serializer = self.get_query_serializer()
        serializer_parameters = []
        if query_serializer is not None:
            serializer_parameters = self.serializer_to_parameters(query_serializer, in_=openapi.IN_QUERY)

        # any key present in both parameter sets is a conflict
        if set(param_list_to_odict(natural_parameters)) & set(param_list_to_odict(serializer_parameters)):
            raise SwaggerGenerationError(
                "your query_serializer contains fields that conflict with the "
                "filter_backend or paginator_class on the view - %s %s" % (self.method, self.path)
            )

        return natural_parameters + serializer_parameters

    def get_operation_id(self, operation_keys=None):
        """Return an unique ID for this operation. The ID must be unique across
        all :class:`.Operation` objects in the API.

        :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
            of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
        :rtype: str
        """
        operation_keys = operation_keys or self.operation_keys

        operation_id = self.overrides.get('operation_id', '')
        if not operation_id:
            operation_id = '_'.join(operation_keys)
        return operation_id

    def split_summary_from_description(self, description):
        """Decide if and how to split a summary out of the given description. The default implementation
        uses the first paragraph of the description as a summary if it is less than 120 characters long.

        :param description: the full description to be analyzed
        :return: summary and description
        :rtype: (str,str)
        """
        # https://www.python.org/dev/peps/pep-0257/#multi-line-docstrings
        summary = None
        summary_max_len = 120  # OpenAPI 2.0 spec says summary should be under 120 characters
        sections = description.split('\n\n', 1)
        if len(sections) == 2:
            sections[0] = sections[0].strip()
            if len(sections[0]) < summary_max_len:
                summary, description = sections
                description = description.strip()

        return summary, description

    def get_summary_and_description(self):
        """Return an operation summary and description determined from the view's docstring.

        :return: summary and description
        :rtype: (str,str)
        """
        description = self.overrides.get('operation_description', None)
        summary = self.overrides.get('operation_summary', None)
        if description is None:
            description = self._sch.get_description(self.path, self.method) or ''
            description = description.strip().replace('\r', '')

            if description and (summary is None):
                # description from docstring... do summary magic
                summary, description = self.split_summary_from_description(description)

        return summary, description

    def get_security(self):
        """Return a list of security requirements for this operation.

        Returning an empty list marks the endpoint as unauthenticated (i.e. removes all accepted
        authentication schemes). Returning ``None`` will inherit the top-level security requirements.

        :return: security requirements
        :rtype: list[dict[str,list[str]]]"""
        return self.overrides.get('security', None)

    def is_deprecated(self):
        """Return ``True`` if this operation is to be marked as deprecated.

        :return: deprecation status; ``None`` when no ``deprecated`` override was given
        :rtype: bool
        """
        return self.overrides.get('deprecated', None)

    def get_tags(self, operation_keys=None):
        """Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI
        each tag will show as a group containing the operations that use it. If not provided in overrides,
        tags will be inferred from the operation url.

        :param tuple[str] operation_keys: an array of keys derived from the path describing the hierarchical layout
            of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
        :rtype: list[str]
        """
        operation_keys = operation_keys or self.operation_keys

        tags = self.overrides.get('tags')
        if not tags:
            tags = [operation_keys[0]]

        return tags

    def get_consumes(self):
        """Return the MIME types this endpoint can consume.

        :rtype: list[str]
        """
        return get_consumes(self.get_parser_classes())

    def get_produces(self):
        """Return the MIME types this endpoint can produce.

        :rtype: list[str]
        """
        return get_produces(self.get_renderer_classes())