1import http.client
2import inspect
3from enum import Enum
4from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
5
6from fastapi import routing
7from fastapi.datastructures import DefaultPlaceholder
8from fastapi.dependencies.models import Dependant
9from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
10from fastapi.encoders import jsonable_encoder
11from fastapi.openapi.constants import (
12    METHODS_WITH_BODY,
13    REF_PREFIX,
14    STATUS_CODES_WITH_NO_BODY,
15)
16from fastapi.openapi.models import OpenAPI
17from fastapi.params import Body, Param
18from fastapi.responses import Response
19from fastapi.utils import (
20    deep_dict_update,
21    generate_operation_id_for_path,
22    get_model_definitions,
23)
24from pydantic import BaseModel
25from pydantic.fields import ModelField, Undefined
26from pydantic.schema import (
27    field_schema,
28    get_flat_models_from_fields,
29    get_model_name_map,
30)
31from pydantic.utils import lenient_issubclass
32from starlette.responses import JSONResponse
33from starlette.routing import BaseRoute
34from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
35
# JSON schema for one entry of the ``detail`` list in the default 422
# "Validation Error" response body.
# ``loc`` is the path to the offending input. Besides location/field names
# (strings), pydantic error locations can contain integer list indexes, so
# each item may be either a string or an integer.
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}
46
# JSON schema of the body documented for the automatic 422 "Validation Error"
# response added by ``get_openapi_path``: an object whose ``detail`` is a list
# of ``ValidationError`` items. The items are referenced by ``$ref``, so the
# ``ValidationError`` schema must be registered alongside this one.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for status-code keys in ``route.responses`` that are
# ranges ("1XX".."5XX") or the catch-all default. Keys are upper-case because
# they are looked up with the upper-cased status code.
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}
67
68
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect security schemes and per-operation security requirements.

    :param flat_dependant: A flattened dependant whose
        ``security_requirements`` are inspected.
    :returns: ``(definitions, requirements)`` where ``definitions`` maps each
        scheme name to its JSON-encoded scheme model (aliases applied, ``None``
        values dropped). ``requirements`` holds one ``{scheme_name: scopes}``
        mapping per security requirement, in iteration order.
    """
    definitions_by_name: Dict[str, Any] = {}
    requirements: List[Dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        encoded_model = jsonable_encoder(
            scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        name = scheme.scheme_name
        definitions_by_name[name] = encoded_model
        requirements.append({name: requirement.scopes})
    return definitions_by_name, requirements
84
85
def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
    """Build one OpenAPI Parameter Object per route parameter field.

    :param all_route_params: Parameter fields; their ``field_info`` is treated
        as a ``Param`` (location, description, examples, deprecation).
    :param model_name_map: Model/enum to schema-name mapping used when
        generating ``$ref`` schemas.
    :returns: Parameter dicts in the same order as ``all_route_params``.
        ``examples`` takes precedence over ``example``; optional keys are only
        present when set.
    """

    def _describe(field: ModelField) -> Dict[str, Any]:
        # Route parameters always carry a Param-derived field_info.
        info = cast(Param, field.field_info)
        schema = field_schema(
            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )[0]
        entry: Dict[str, Any] = {
            "name": field.alias,
            "in": info.in_.value,
            "required": field.required,
            "schema": schema,
        }
        if info.description:
            entry["description"] = info.description
        if info.examples:
            entry["examples"] = jsonable_encoder(info.examples)
        elif info.example != Undefined:
            entry["example"] = jsonable_encoder(info.example)
        if info.deprecated:
            entry["deprecated"] = info.deprecated
        return entry

    return [_describe(field) for field in all_route_params]
113
114
def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI Request Body Object for a route's body field.

    :param body_field: The route's combined body field, or ``None``.
    :param model_name_map: Model/enum to schema-name mapping used for ``$ref``s.
    :returns: ``None`` when there is no body field; otherwise a dict with an
        optional ``required`` flag (only set when true) and ``content`` keyed
        by the body's media type. ``examples`` takes precedence over
        ``example``.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema, _, _ = field_schema(
        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    info = cast(Body, body_field.field_info)
    media_type = info.media_type
    is_required = body_field.required

    media_content: Dict[str, Any] = {"schema": schema}
    if info.examples:
        media_content["examples"] = jsonable_encoder(info.examples)
    elif info.example != Undefined:
        media_content["example"] = jsonable_encoder(info.example)

    request_body: Dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required
    request_body["content"] = {media_type: media_content}
    return request_body
139
140
def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
    """Return the route's ``operationId``.

    A truthy ``route.operation_id`` is used as-is; otherwise the id is derived
    from the route name, its formatted path and ``method`` via
    ``generate_operation_id_for_path``.
    """
    explicit_id = route.operation_id
    if not explicit_id:
        return generate_operation_id_for_path(
            name=route.name, path=route.path_format, method=method
        )
    return explicit_id
146
147
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the route's summary.

    A truthy ``route.summary`` wins; otherwise the route name is turned into a
    title (underscores become spaces, words title-cased). ``method`` is part of
    the signature but not used.
    """
    return route.summary or route.name.replace("_", " ").title()
152
153
def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
    """Build the descriptive part of an OpenAPI Operation Object.

    Always includes ``summary`` and ``operationId``; ``tags``, ``description``
    and ``deprecated`` are included only when the route sets them to a truthy
    value. Keys are inserted in the order tags, summary, description,
    operationId, deprecated.
    """
    tags = route.tags
    description = route.description
    deprecated = route.deprecated

    metadata: Dict[str, Any] = {}
    if tags:
        metadata["tags"] = tags
    metadata["summary"] = generate_operation_summary(route=route, method=method)
    if description:
        metadata["description"] = description
    metadata["operationId"] = generate_operation_id(route=route, method=method)
    if deprecated:
        metadata["deprecated"] = deprecated
    return metadata
167
168
def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI path item for a single ``APIRoute``.

    One Operation Object is generated per HTTP method in ``route.methods``.
    Nothing is generated when ``route.include_in_schema`` is false.

    :param route: The route to document.
    :param model_name_map: Mapping from models/enums to the schema names used
        in ``$ref`` pointers.
    :returns: ``(path, security_schemes, definitions)``:

        - ``path`` maps lower-cased method names to operation dicts;
        - ``security_schemes`` holds the security scheme definitions referenced
          by those operations and by their callbacks;
        - ``definitions`` holds extra component schemas the operations (or
          their callbacks) reference, i.e. ``ValidationError`` /
          ``HTTPValidationError`` when a 422 response is added.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        # The flattened dependency graph and parameter list do not depend on the
        # HTTP method, so compute them once instead of once per method.
        flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
        all_route_params = get_flat_params(route.dependant)
        for method in route.methods:
            operation = get_openapi_operation_metadata(route=route, method=method)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params, model_name_map=model_name_map
            )
            if parameters:
                # OpenAPI identifies a parameter by the pair (location, name): a
                # query "id" and a header "id" are distinct and must both be
                # kept. A repeated pair keeps its first position but takes the
                # last definition.
                unique_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                operation["parameters"] = list(unique_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback, model_name_map=model_name_map
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can reference security schemes by
                        # name and the HTTPValidationError schema by $ref; pass
                        # them up so those references resolve in the document.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
                # NOTE(review): if the response class's __init__ has no
                # ``status_code`` parameter with an int default, ``status_code``
                # is never bound and the next statement raises
                # UnboundLocalError. Confirm whether a fallback is wanted.
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if (
                route_response_media_type
                and route.status_code not in STATUS_CODES_WITH_NO_BODY
            ):
                # Non-JSON responses are documented as plain strings; JSON
                # responses use the response model's schema, or an empty
                # (unconstrained) schema when there is no response model.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema, _, _ = field_schema(
                            route.response_field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    process_response = additional_response.copy()
                    # "model" is a FastAPI-only key; its schema is merged in below
                    # from route.response_fields instead.
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema, _, _ = field_schema(
                            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # Description precedence: explicit, already present, range /
                    # HTTP reason phrase, then a generic fallback.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Document the automatic validation-error response only when the
            # route validates input and the user has not already covered 422.
            if (all_route_params or route.body_field) and not any(
                [
                    status in operation["responses"]
                    for status in [http422, "4XX", "default"]
                ]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
324
325
def get_flat_models_from_routes(
    routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
    """Collect every model/enum referenced by the schema-visible API routes.

    Only ``APIRoute`` instances with a truthy ``include_in_schema`` contribute.
    Their body field, response field(s) and flattened parameter fields are
    gathered, and callbacks are processed recursively.

    :param routes: Routes to scan.
    :returns: The union of models found in callbacks and in the gathered fields.
    """
    body_fields: List[ModelField] = []
    response_model_fields: List[ModelField] = []
    param_fields: List[ModelField] = []
    models_from_callbacks: Set[Union[Type[BaseModel], Type[Enum]]] = set()
    for route in routes:
        is_documented = getattr(route, "include_in_schema", None)
        if not (is_documented and isinstance(route, routing.APIRoute)):
            continue
        body_field = route.body_field
        if body_field:
            assert isinstance(
                body_field, ModelField
            ), "A request body must be a Pydantic Field"
            body_fields.append(body_field)
        if route.response_field:
            response_model_fields.append(route.response_field)
        if route.response_fields:
            response_model_fields.extend(route.response_fields.values())
        if route.callbacks:
            models_from_callbacks |= get_flat_models_from_routes(route.callbacks)
        param_fields.extend(get_flat_params(route.dependant))

    models_from_fields = get_flat_models_from_fields(
        body_fields + response_model_fields + param_fields,
        known_models=set(),
    )
    return models_from_callbacks | models_from_fields
356
357
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
) -> Dict[str, Any]:
    """Generate the full OpenAPI document for ``routes``.

    Optional ``info`` fields, ``servers`` and ``tags`` are emitted only when
    truthy. Component schemas are the model definitions plus any extra
    definitions returned per path, sorted by name. The result is validated
    through the ``OpenAPI`` model and JSON-encoded with aliases and without
    ``None`` values.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    optional_info_fields = (
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    )
    for info_key, info_value in optional_info_fields:
        if info_value:
            info[info_key] = info_value

    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    flat_models = get_flat_models_from_routes(routes)
    model_name_map = get_model_name_map(flat_models)
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )

    api_routes = (r for r in routes if isinstance(r, routing.APIRoute))
    for api_route in api_routes:
        path_item, route_security_schemes, route_definitions = get_openapi_path(
            route=api_route, model_name_map=model_name_map
        )
        if path_item:
            paths.setdefault(api_route.path_format, {}).update(path_item)
        if route_security_schemes:
            components.setdefault("securitySchemes", {}).update(
                route_security_schemes
            )
        if route_definitions:
            definitions.update(route_definitions)

    if definitions:
        # Keys are unique, so sorting items orders purely by schema name.
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore
411