import http.client
import inspect
from enum import Enum
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast

from fastapi import routing
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
    METHODS_WITH_BODY,
    REF_PREFIX,
    STATUS_CODES_WITH_NO_BODY,
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.utils import (
    deep_dict_update,
    generate_operation_id_for_path,
    get_model_definitions,
)
from pydantic import BaseModel
from pydantic.fields import ModelField, Undefined
from pydantic.schema import (
    field_schema,
    get_flat_models_from_fields,
    get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY

# JSON Schema for one entry of the "detail" array in a validation error
# response (see validation_error_response_definition below).
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# JSON Schema for the body of the automatically documented 422 response.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for status-code range keys and "DEFAULT"; used when an
# additional response does not provide its own description.
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}


def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect the security schemes and requirements declared by a dependant.

    Returns a ``(security_definitions, operation_security)`` tuple where
    ``security_definitions`` maps each scheme name to its JSON-encoded scheme
    model (encoded by alias, ``None`` values dropped), and
    ``operation_security`` is a list of ``{scheme_name: scopes}`` requirement
    objects in the order they appear on ``flat_dependant``.
    """
    security_definitions = {}
    operation_security = []
    for security_requirement in flat_dependant.security_requirements:
        security_definition = jsonable_encoder(
            security_requirement.security_scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        security_name = security_requirement.security_scheme.scheme_name
        security_definitions[security_name] = security_definition
        operation_security.append({security_name: security_requirement.scopes})
    return security_definitions, operation_security


def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
    """Build OpenAPI Parameter Objects for the given route parameter fields.

    Each result has ``name`` (the field alias), ``in``, ``required`` and
    ``schema``. ``description``, ``examples`` / ``example`` and
    ``deprecated`` are added only when set on the field info; ``examples``
    takes precedence over ``example``.
    """
    parameters = []
    for param in all_route_params:
        # Route parameters carry a fastapi.params.Param, which provides ``in_``.
        field_info = cast(Param, param.field_info)
        parameter = {
            "name": param.alias,
            "in": field_info.in_.value,
            "required": param.required,
            "schema": field_schema(
                param, model_name_map=model_name_map, ref_prefix=REF_PREFIX
            )[0],
        }
        if field_info.description:
            parameter["description"] = field_info.description
        if field_info.examples:
            parameter["examples"] = jsonable_encoder(field_info.examples)
        elif field_info.example != Undefined:
            parameter["example"] = jsonable_encoder(field_info.example)
        if field_info.deprecated:
            parameter["deprecated"] = field_info.deprecated
        parameters.append(parameter)
    return parameters


def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI Request Body Object for ``body_field``.

    Returns ``None`` when there is no body field. The content is keyed by the
    body's ``media_type``; ``required`` is only emitted when truthy, and
    ``examples`` takes precedence over ``example``.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    body_schema, _, _ = field_schema(
        body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    field_info = cast(Body, body_field.field_info)
    request_media_type = field_info.media_type
    required = body_field.required
    request_body_oai: Dict[str, Any] = {}
    if required:
        request_body_oai["required"] = required
    request_media_content: Dict[str, Any] = {"schema": body_schema}
    if field_info.examples:
        request_media_content["examples"] = jsonable_encoder(field_info.examples)
    elif field_info.example != Undefined:
        request_media_content["example"] = jsonable_encoder(field_info.example)
    request_body_oai["content"] = {request_media_type: request_media_content}
    return request_body_oai


def generate_operation_id(*, route: routing.APIRoute, method: str) -> str:
    """Return ``route.operation_id`` if set, else derive one from the route's
    name, path format and the HTTP ``method``."""
    if route.operation_id:
        return route.operation_id
    path: str = route.path_format
    return generate_operation_id_for_path(name=route.name, path=path, method=method)


def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return ``route.summary`` if set, else the route name with underscores
    replaced by spaces and title-cased. ``method`` is not used here."""
    if route.summary:
        return route.summary
    return route.name.replace("_", " ").title()


def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str
) -> Dict[str, Any]:
    """Build the base Operation Object for one HTTP method of ``route``.

    Always contains ``summary`` and ``operationId``; ``tags``,
    ``description`` and ``deprecated`` are included only when truthy.
    """
    operation: Dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description
    operation["operationId"] = generate_operation_id(route=route, method=method)
    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation


def get_openapi_path(
    *, route: routing.APIRoute, model_name_map: Dict[type, str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI Path Item for ``route``.

    Returns ``(path, security_schemes, definitions)``:

    * ``path`` maps each lower-cased HTTP method to its Operation Object; it
      is empty when ``route.include_in_schema`` is false.
    * ``security_schemes`` holds the security scheme definitions referenced
      by the route's operations (and by its callbacks' operations).
    * ``definitions`` holds the ``ValidationError`` / ``HTTPValidationError``
      schemas when a 422 response was documented automatically (by this route
      or one of its callbacks).
    """
    path = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(route=route, method=method)
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params, model_name_map=model_name_map
            )
            parameters.extend(operation_parameters)
            if parameters:
                # OpenAPI identifies a parameter by its (location, name) pair,
                # so e.g. a query "id" and a header "id" are distinct and must
                # both be kept. When the same parameter is declared several
                # times (e.g. by different dependencies), a required
                # declaration takes precedence over an optional one.
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field, model_name_map=model_name_map
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback, model_name_map=model_name_map
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can reference security schemes and
                        # schemas (e.g. HTTPValidationError) that the parent
                        # route itself does not; propagate them so those
                        # references resolve in the final document.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if (
                route_response_media_type
                and route.status_code not in STATUS_CODES_WITH_NO_BODY
            ):
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema, _, _ = field_schema(
                            route.response_field,
                            model_name_map=model_name_map,
                            ref_prefix=REF_PREFIX,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    process_response = additional_response.copy()
                    # "model" is a FastAPI-only key; its schema is taken from
                    # route.response_fields instead of being emitted verbatim.
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema, _, _ = field_schema(
                            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Document a 422 validation error only when the operation has
            # inputs to validate and no client-error response is declared yet.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions


def get_flat_models_from_routes(
    routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
    """Collect every pydantic model / Enum used by schema-included API routes.

    Looks at body fields, response fields, additional-response fields and
    parameters of each ``APIRoute`` with ``include_in_schema`` set, and
    recurses into route callbacks.
    """
    body_fields_from_routes: List[ModelField] = []
    responses_from_routes: List[ModelField] = []
    request_fields_from_routes: List[ModelField] = []
    callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
    for route in routes:
        if getattr(route, "include_in_schema", None) and isinstance(
            route, routing.APIRoute
        ):
            if route.body_field:
                assert isinstance(
                    route.body_field, ModelField
                ), "A request body must be a Pydantic Field"
                body_fields_from_routes.append(route.body_field)
            if route.response_field:
                responses_from_routes.append(route.response_field)
            if route.response_fields:
                responses_from_routes.extend(route.response_fields.values())
            if route.callbacks:
                callback_flat_models |= get_flat_models_from_routes(route.callbacks)
            params = get_flat_params(route.dependant)
            request_fields_from_routes.extend(params)

    flat_models = callback_flat_models | get_flat_models_from_fields(
        body_fields_from_routes + responses_from_routes + request_fields_from_routes,
        known_models=set(),
    )
    return flat_models


def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.0.2",
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
) -> Dict[str, Any]:
    """Generate the OpenAPI document for ``routes`` as a JSON-compatible dict.

    Only ``APIRoute`` instances contribute paths. The document is validated
    through the ``OpenAPI`` model and then encoded by alias with ``None``
    values excluded. Component schemas are emitted sorted by name.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers
    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    flat_models = get_flat_models_from_routes(routes)
    model_name_map = get_model_name_map(flat_models)
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )
    for route in routes:
        if isinstance(route, routing.APIRoute):
            path, security_schemes, path_definitions = get_openapi_path(
                route=route, model_name_map=model_name_map
            )
            if path:
                paths.setdefault(route.path_format, {}).update(path)
            if security_schemes:
                components.setdefault("securitySchemes", {}).update(
                    security_schemes
                )
            if path_definitions:
                definitions.update(path_definitions)
    if definitions:
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    output["paths"] = paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore