1from samtranslator.model import PropertyType, Resource
2from samtranslator.model.types import is_type, one_of, is_str, list_of
3from samtranslator.model.intrinsics import ref, fnSub
4from samtranslator.model.exceptions import InvalidResourceException
5from samtranslator.translator.arn_generator import ArnGenerator
6
# OpenAPI vendor-extension key under which API Gateway authorizer settings are written
# inside each security scheme built by ApiGatewayV2Authorizer.generate_openapi.
APIGATEWAY_AUTHORIZER_KEY = "x-amazon-apigateway-authorizer"
8
9
class ApiGatewayV2HttpApi(Resource):
    """SAM model of a CloudFormation ``AWS::ApiGatewayV2::Api`` resource (an HTTP API)."""

    resource_type = "AWS::ApiGatewayV2::Api"
    # CloudFormation property name -> PropertyType(<flag, presumably "required">, <value validator>).
    # Only the properties listed here are declared on this model; all are optional.
    # NOTE(review): key order is kept as-is in case the Resource base serializes in this order.
    property_types = {
        "Body": PropertyType(False, is_type(dict)),
        "BodyS3Location": PropertyType(False, is_type(dict)),
        "Description": PropertyType(False, is_str()),
        "FailOnWarnings": PropertyType(False, is_type(bool)),
        "DisableExecuteApiEndpoint": PropertyType(False, is_type(bool)),
        "BasePath": PropertyType(False, is_str()),
        "CorsConfiguration": PropertyType(False, is_type(dict)),
    }

    # ``http_api_id`` resolves to ``ref(...)`` of this resource's logical id.
    runtime_attrs = {"http_api_id": lambda self: ref(self.logical_id)}
23
24
class ApiGatewayV2Stage(Resource):
    """SAM model of a CloudFormation ``AWS::ApiGatewayV2::Stage`` resource."""

    resource_type = "AWS::ApiGatewayV2::Stage"
    # CloudFormation property name -> PropertyType(<flag, presumably "required">, <value validator>).
    # ApiId is the only property flagged True here.
    property_types = {
        "AccessLogSettings": PropertyType(False, is_type(dict)),
        "DefaultRouteSettings": PropertyType(False, is_type(dict)),
        "RouteSettings": PropertyType(False, is_type(dict)),
        "ClientCertificateId": PropertyType(False, is_str()),
        "Description": PropertyType(False, is_str()),
        "ApiId": PropertyType(True, is_str()),
        # Accepts a plain string or a dict (presumably an intrinsic function such as Ref/Fn::Sub).
        "StageName": PropertyType(False, one_of(is_str(), is_type(dict))),
        "Tags": PropertyType(False, is_type(dict)),
        "StageVariables": PropertyType(False, is_type(dict)),
        "AutoDeploy": PropertyType(False, is_type(bool)),
    }

    # ``stage_name`` resolves to ``ref(...)`` of this resource's logical id.
    runtime_attrs = {"stage_name": lambda self: ref(self.logical_id)}
41
42
class ApiGatewayV2DomainName(Resource):
    """SAM model of a CloudFormation ``AWS::ApiGatewayV2::DomainName`` (custom domain) resource."""

    resource_type = "AWS::ApiGatewayV2::DomainName"
    # CloudFormation property name -> PropertyType(<flag, presumably "required">, <value validator>).
    property_types = {
        "DomainName": PropertyType(True, is_str()),
        # A list whose elements must each be dicts.
        "DomainNameConfigurations": PropertyType(False, list_of(is_type(dict))),
        "MutualTlsAuthentication": PropertyType(False, is_type(dict)),
        "Tags": PropertyType(False, is_type(dict)),
    }
51
52
class ApiGatewayV2ApiMapping(Resource):
    """SAM model of a CloudFormation ``AWS::ApiGatewayV2::ApiMapping`` resource.

    Ties an API (ApiId) and one of its stages (Stage) to a custom domain (DomainName),
    optionally under a path given by ApiMappingKey.
    """

    resource_type = "AWS::ApiGatewayV2::ApiMapping"
    # CloudFormation property name -> PropertyType(<flag, presumably "required">, <value validator>).
    property_types = {
        "ApiId": PropertyType(True, is_str()),
        "ApiMappingKey": PropertyType(False, is_str()),
        "DomainName": PropertyType(True, is_str()),
        "Stage": PropertyType(True, is_str()),
    }
61
62
class ApiGatewayV2Authorizer(object):
    """
    Authorizer attached to a SAM HTTP API (API Gateway V2).

    The authorizer type is inferred from the arguments: a truthy ``jwt_configuration`` makes it a
    ``"JWT"`` (OAuth2) authorizer, anything else a ``"REQUEST"`` (Lambda) authorizer.
    ``generate_openapi`` renders the authorizer as an OpenAPI security scheme.
    """

    # (Identity key, identity-source prefix) pairs, in the order they are emitted.
    _IDENTITY_SOURCE_PREFIXES = (
        ("Headers", "$request.header."),
        ("QueryStrings", "$request.querystring."),
        ("StageVariables", "$stageVariables."),
        ("Context", "$context."),
    )

    def __init__(
        self,
        api_logical_id=None,
        name=None,
        authorization_scopes=None,
        jwt_configuration=None,
        id_source=None,
        function_arn=None,
        function_invoke_role=None,
        identity=None,
        authorizer_payload_format_version=None,
        enable_simple_responses=None,
    ):
        """
        Creates an authorizer for use in V2 Http Apis

        :param api_logical_id: logical id of the owning API; reported in validation errors
        :param name: authorizer name; used as a prefix in validation error messages
        :param authorization_scopes: list of scopes; only allowed for JWT authorizers
        :param jwt_configuration: JWT settings; when truthy the authorizer is a JWT authorizer
        :param id_source: identity source; required for (and only allowed on) JWT authorizers
        :param function_arn: authorizer Lambda ARN; required for (and only allowed on) Lambda authorizers
        :param function_invoke_role: credentials used to invoke the Lambda; ``"NONE"`` means no role
        :param identity: map with optional Headers/QueryStrings/StageVariables/Context lists and
            ReauthorizeEvery; only allowed on Lambda authorizers
        :param authorizer_payload_format_version: required for (and only allowed on) Lambda authorizers
        :param enable_simple_responses: emitted when truthy; only allowed on Lambda authorizers
        :raises InvalidResourceException: if a property is set for the wrong authorizer type, a
            required property is missing, or Identity is malformed
        """
        self.api_logical_id = api_logical_id
        self.name = name
        self.authorization_scopes = authorization_scopes
        self.jwt_configuration = jwt_configuration
        self.id_source = id_source
        self.function_arn = function_arn
        self.function_invoke_role = function_invoke_role
        self.identity = identity
        self.authorizer_payload_format_version = authorizer_payload_format_version
        self.enable_simple_responses = enable_simple_responses

        self._validate_input_parameters()

        authorizer_type = self._get_auth_type()

        # Validate necessary parameters exist
        if authorizer_type == "JWT":
            self._validate_jwt_authorizer()
        else:
            self._validate_lambda_authorizer()

    def _get_auth_type(self):
        """Return ``"JWT"`` when a (truthy) JWT configuration is present, otherwise ``"REQUEST"``."""
        return "JWT" if self.jwt_configuration else "REQUEST"

    def _error_message(self, detail):
        """Prefix ``detail`` with the authorizer name; a missing name must not turn into a TypeError."""
        if self.name is None:
            return detail
        return self.name + " " + detail

    def _validate_input_parameters(self):
        """Reject properties that do not belong to the inferred authorizer type."""
        authorizer_type = self._get_auth_type()
        is_jwt = authorizer_type == "JWT"
        is_request = authorizer_type == "REQUEST"

        if self.authorization_scopes is not None and not isinstance(self.authorization_scopes, list):
            raise InvalidResourceException(self.api_logical_id, "AuthorizationScopes must be a list.")

        if self.authorization_scopes is not None and not is_jwt:
            raise InvalidResourceException(
                self.api_logical_id, "AuthorizationScopes must be defined only for OAuth2 Authorizer."
            )

        # The type is derived from jwt_configuration, so this only fires when it is set but falsy (e.g. {}).
        if self.jwt_configuration is not None and not is_jwt:
            raise InvalidResourceException(
                self.api_logical_id, "JwtConfiguration must be defined only for OAuth2 Authorizer."
            )

        if self.id_source is not None and not is_jwt:
            raise InvalidResourceException(
                self.api_logical_id, "IdentitySource must be defined only for OAuth2 Authorizer."
            )

        if self.function_arn is not None and not is_request:
            raise InvalidResourceException(
                self.api_logical_id, "FunctionArn must be defined only for Lambda Authorizer."
            )

        if self.function_invoke_role is not None and not is_request:
            raise InvalidResourceException(
                self.api_logical_id, "FunctionInvokeRole must be defined only for Lambda Authorizer."
            )

        if self.identity is not None and not is_request:
            raise InvalidResourceException(self.api_logical_id, "Identity must be defined only for Lambda Authorizer.")

        if self.authorizer_payload_format_version is not None and not is_request:
            raise InvalidResourceException(
                self.api_logical_id, "AuthorizerPayloadFormatVersion must be defined only for Lambda Authorizer."
            )

        if self.enable_simple_responses is not None and not is_request:
            raise InvalidResourceException(
                self.api_logical_id, "EnableSimpleResponses must be defined only for Lambda Authorizer."
            )

    def _validate_jwt_authorizer(self):
        """Ensure a JWT authorizer defines both JwtConfiguration and IdentitySource."""
        if not self.jwt_configuration:
            raise InvalidResourceException(
                self.api_logical_id, self._error_message("OAuth2 Authorizer must define 'JwtConfiguration'.")
            )
        if not self.id_source:
            raise InvalidResourceException(
                self.api_logical_id, self._error_message("OAuth2 Authorizer must define 'IdentitySource'.")
            )

    def _validate_lambda_authorizer(self):
        """Ensure a Lambda authorizer defines its required properties and a well-formed Identity."""
        if not self.function_arn:
            raise InvalidResourceException(
                self.api_logical_id, self._error_message("Lambda Authorizer must define 'FunctionArn'.")
            )
        if not self.authorizer_payload_format_version:
            raise InvalidResourceException(
                self.api_logical_id,
                self._error_message("Lambda Authorizer must define 'AuthorizerPayloadFormatVersion'."),
            )
        self._validate_identity()

    def _validate_identity(self):
        """
        Ensure Identity is a map whose source entries are lists.

        Without this check, a non-map Identity fails later with AttributeError. A scalar string
        such as ``Headers: Authorization`` would be iterated character by character into bogus
        identity sources.
        """
        if not self.identity:
            # Absent/empty Identity is ignored by generate_openapi.
            return
        if not isinstance(self.identity, dict):
            raise InvalidResourceException(
                self.api_logical_id, self._error_message("Lambda Authorizer property 'Identity' must be a map.")
            )
        for key, _prefix in self._IDENTITY_SOURCE_PREFIXES:
            value = self.identity.get(key)
            if value and not isinstance(value, list):
                raise InvalidResourceException(
                    self.api_logical_id,
                    self._error_message("Lambda Authorizer property 'Identity.{}' must be a list.".format(key)),
                )

    def generate_openapi(self):
        """
        Generates OAS for the securitySchemes section

        :returns: security scheme dict whose API Gateway authorizer settings are stored under
            ``APIGATEWAY_AUTHORIZER_KEY``
        """
        if self._get_auth_type() == "JWT":
            return {
                "type": "oauth2",
                APIGATEWAY_AUTHORIZER_KEY: {
                    "jwtConfiguration": self.jwt_configuration,
                    "identitySource": self.id_source,
                    "type": "jwt",
                },
            }

        authorizer = {"type": "request"}

        # Generate the lambda arn
        partition = ArnGenerator.get_partition_name()
        resource = "lambda:path/2015-03-31/functions/${__FunctionArn__}/invocations"
        authorizer["authorizerUri"] = fnSub(
            ArnGenerator.generate_arn(
                partition=partition, service="apigateway", resource=resource, include_account_id=False
            ),
            {"__FunctionArn__": self.function_arn},
        )

        # Set authorizerCredentials if present
        function_invoke_role = self._get_function_invoke_role()
        if function_invoke_role:
            authorizer["authorizerCredentials"] = function_invoke_role

        # Set authorizerResultTtlInSeconds if present; compared against None so a TTL of 0 is kept
        reauthorize_every = self._get_reauthorize_every()
        if reauthorize_every is not None:
            authorizer["authorizerResultTtlInSeconds"] = reauthorize_every

        # Set identitySource if present
        if self.identity:
            authorizer["identitySource"] = self._get_identity_source()

        # Set authorizerPayloadFormatVersion. It's a required parameter (enforced in _validate_lambda_authorizer)
        authorizer["authorizerPayloadFormatVersion"] = self.authorizer_payload_format_version

        # Set enableSimpleResponses only when it is enabled
        if self.enable_simple_responses:
            authorizer["enableSimpleResponses"] = self.enable_simple_responses

        return {
            "type": "apiKey",
            "name": "Unused",
            "in": "header",
            APIGATEWAY_AUTHORIZER_KEY: authorizer,
        }

    def _get_function_invoke_role(self):
        """Return the invoke role, or None when it is unset or the literal string ``"NONE"``."""
        if not self.function_invoke_role or self.function_invoke_role == "NONE":
            return None

        return self.function_invoke_role

    def _get_identity_source(self):
        """
        Build the identitySource list from the Identity map.

        Entries are emitted in the order Headers, QueryStrings, StageVariables, Context, each
        prefixed with its API Gateway selection-expression prefix; missing/empty keys are skipped.
        """
        identity_source = []
        for key, prefix in self._IDENTITY_SOURCE_PREFIXES:
            identity_source.extend(prefix + value for value in self.identity.get(key) or [])
        return identity_source

    def _get_reauthorize_every(self):
        """Return Identity.ReauthorizeEvery, or None when Identity is absent/empty or lacks it."""
        if not self.identity:
            return None

        return self.identity.get("ReauthorizeEvery")
271