1from samtranslator.model import ResourceMacro, PropertyType
2from samtranslator.model.types import is_type, is_str
3
4from samtranslator.model.lambda_ import LambdaEventSourceMapping
5from samtranslator.translator.arn_generator import ArnGenerator
6from samtranslator.model.exceptions import InvalidEventException
7from samtranslator.model.iam import IAMRolePolicies
8
9
class PullEventSource(ResourceMacro):
    """Base class for pull event sources for SAM Functions.

    The pull events are Kinesis Streams, DynamoDB Streams, Kafka Topics (MSK), Amazon MQ Queues and SQS Queues. All of
    these correspond to an EventSourceMapping in Lambda, and require that the function's execution role be given
    access to the corresponding stream, topic, broker or queue.

    Subclasses supply those permissions by implementing :meth:`get_policy_arn` (an AWS managed policy ARN, or None)
    and :meth:`get_policy_statements` (a list of inline policies, or None).
    """

    resource_type = None
    property_types = {
        "Stream": PropertyType(False, is_str()),
        "Queue": PropertyType(False, is_str()),
        "BatchSize": PropertyType(False, is_type(int)),
        "StartingPosition": PropertyType(False, is_str()),
        "Enabled": PropertyType(False, is_type(bool)),
        "MaximumBatchingWindowInSeconds": PropertyType(False, is_type(int)),
        "MaximumRetryAttempts": PropertyType(False, is_type(int)),
        "BisectBatchOnFunctionError": PropertyType(False, is_type(bool)),
        "MaximumRecordAgeInSeconds": PropertyType(False, is_type(int)),
        "DestinationConfig": PropertyType(False, is_type(dict)),
        "ParallelizationFactor": PropertyType(False, is_type(int)),
        "Topics": PropertyType(False, is_type(list)),
        "Broker": PropertyType(False, is_str()),
        "Queues": PropertyType(False, is_type(list)),
        "SourceAccessConfigurations": PropertyType(False, is_type(list)),
        "SecretsManagerKmsKeyId": PropertyType(False, is_str()),
        "TumblingWindowInSeconds": PropertyType(False, is_type(int)),
        "FunctionResponseTypes": PropertyType(False, is_type(list)),
    }

    def get_policy_arn(self):
        """Return the ARN of the managed policy to attach to the execution role, or None if there is none.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclass must implement this method")

    def get_policy_statements(self):
        """Return a list of inline policies (PolicyName/PolicyDocument dicts) for the execution role, or None.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclass must implement this method")

    def to_cloudformation(self, **kwargs):
        """Returns the Lambda EventSourceMapping to which this pull event corresponds. Adds the appropriate managed
        policy to the function's execution role, if such a role is provided.

        :param dict kwargs: must contain ``function`` (the SAM function this event belongs to); may contain ``role``
            (the execution role generated for the function, or None when the user supplied their own role)
        :returns: a list of vanilla CloudFormation Resources, to which this pull event expands
        :rtype: list
        :raises TypeError: if ``function`` is missing
        :raises InvalidEventException: if the event properties are invalid
        """
        function = kwargs.get("function")

        if not function:
            raise TypeError("Missing required keyword argument: function")

        resources = []

        lambda_eventsourcemapping = LambdaEventSourceMapping(
            self.logical_id, attributes=function.get_passthrough_resource_attributes()
        )
        resources.append(lambda_eventsourcemapping)

        try:
            # Name will not be available for Alias resources
            function_name_or_arn = function.get_runtime_attr("name")
        except NotImplementedError:
            function_name_or_arn = function.get_runtime_attr("arn")

        if not self.Stream and not self.Queue and not self.Broker:
            raise InvalidEventException(
                self.relative_id,
                "No Queue (for SQS) or Stream (for Kinesis, DynamoDB or MSK) or Broker (for Amazon MQ) provided.",
            )

        if self.Stream and not self.StartingPosition:
            raise InvalidEventException(self.relative_id, "StartingPosition is required for Kinesis, DynamoDB and MSK.")

        lambda_eventsourcemapping.FunctionName = function_name_or_arn
        lambda_eventsourcemapping.EventSourceArn = self.Stream or self.Queue or self.Broker
        lambda_eventsourcemapping.StartingPosition = self.StartingPosition
        lambda_eventsourcemapping.BatchSize = self.BatchSize
        lambda_eventsourcemapping.Enabled = self.Enabled
        lambda_eventsourcemapping.MaximumBatchingWindowInSeconds = self.MaximumBatchingWindowInSeconds
        lambda_eventsourcemapping.MaximumRetryAttempts = self.MaximumRetryAttempts
        lambda_eventsourcemapping.BisectBatchOnFunctionError = self.BisectBatchOnFunctionError
        lambda_eventsourcemapping.MaximumRecordAgeInSeconds = self.MaximumRecordAgeInSeconds
        lambda_eventsourcemapping.ParallelizationFactor = self.ParallelizationFactor
        lambda_eventsourcemapping.Topics = self.Topics
        lambda_eventsourcemapping.Queues = self.Queues
        lambda_eventsourcemapping.SourceAccessConfigurations = self.SourceAccessConfigurations
        lambda_eventsourcemapping.TumblingWindowInSeconds = self.TumblingWindowInSeconds
        lambda_eventsourcemapping.FunctionResponseTypes = self.FunctionResponseTypes

        destination_config_policy = None
        if self.DestinationConfig:
            destination_config, destination_config_policy = self._process_destination_config()
            lambda_eventsourcemapping.DestinationConfig = destination_config

        if "role" in kwargs:
            self._link_policy(kwargs["role"], destination_config_policy)

        return resources

    def _process_destination_config(self):
        """Validate DestinationConfig and build the IAM policy for its OnFailure destination, if requested.

        The SAM-only ``OnFailure.Type`` key ('SQS' or 'SNS') tells SAM which policy to generate. It is stripped from
        the returned configuration, which is what gets passed through to the EventSourceMapping.

        :returns: tuple of (DestinationConfig for the EventSourceMapping, policy dict or None)
        :raises InvalidEventException: if OnFailure is missing or Type has an unsupported value
        """
        on_failure = self.DestinationConfig.get("OnFailure")
        # Validate before dereferencing OnFailure; previously a missing OnFailure crashed with AttributeError.
        if on_failure is None:
            raise InvalidEventException(self.logical_id, "'OnFailure' is a required field for " "'DestinationConfig'")

        # `Type` property is for sam to attach the right policies
        destination_type = on_failure.get("Type") if isinstance(on_failure, dict) else None

        # SAM attaches the policies for SQS or SNS only if 'Type' is given
        if not destination_type:
            return self.DestinationConfig, None

        # the values 'SQS' and 'SNS' are allowed. No intrinsics are allowed
        if destination_type not in ("SQS", "SNS"):
            raise InvalidEventException(self.logical_id, "The only valid values for 'Type' are 'SQS' and 'SNS'")

        # Strip the SAM-only `Type` key from a copy rather than deleting it from the template dict in place, so the
        # input is not mutated (it may be referenced elsewhere). Key order is preserved for the generated output.
        stripped_on_failure = {key: value for key, value in on_failure.items() if key != "Type"}
        destination_config = dict(self.DestinationConfig)
        destination_config["OnFailure"] = stripped_on_failure

        destination_arn = stripped_on_failure.get("Destination")
        if destination_type == "SQS":
            policy = IAMRolePolicies().sqs_send_message_role_policy(destination_arn, self.logical_id)
        else:
            policy = IAMRolePolicies().sns_publish_role_policy(destination_arn, self.logical_id)
        return destination_config, policy

    def _link_policy(self, role, destination_config_policy=None):
        """If this source triggers a Lambda function whose execution role is auto-generated by SAM, add the
        appropriate managed policy to this Role.

        :param model.iam.IAMRole role: the execution role generated for the function, or None
        :param dict destination_config_policy: optional inline policy granting access to the OnFailure destination
        """
        # Evaluate the subclass hooks even when there is no role: some of them (e.g. MQ) validate properties.
        policy_arn = self.get_policy_arn()
        policy_statements = self.get_policy_statements()
        if role is None:
            return

        if policy_arn is not None:
            if role.ManagedPolicyArns is None:
                role.ManagedPolicyArns = []
            if policy_arn not in role.ManagedPolicyArns:
                role.ManagedPolicyArns.append(policy_arn)

        if policy_statements is not None:
            if role.Policies is None:
                role.Policies = []
            for policy in policy_statements:
                self._append_policy_if_absent(role, policy)

        # add SQS or SNS policy only if role is present in kwargs
        if destination_config_policy:
            self._append_policy_if_absent(role, destination_config_policy)

    @staticmethod
    def _append_policy_if_absent(role, policy):
        """Append ``policy`` to ``role.Policies`` unless it, or a policy with the same PolicyDocument, is present.

        :param model.iam.IAMRole role: the role to update; ``Policies`` is initialised to a list if it is None
        :param dict policy: inline policy dict with a ``PolicyDocument`` key
        """
        if role.Policies is None:
            role.Policies = []
        if policy in role.Policies:
            return
        # Skip entries without a PolicyDocument (e.g. intrinsic-wrapped policies) instead of raising KeyError.
        existing_documents = [
            existing["PolicyDocument"]
            for existing in role.Policies
            if isinstance(existing, dict) and "PolicyDocument" in existing
        ]
        # do not add the policy if the same policy document is already present
        if policy.get("PolicyDocument") not in existing_documents:
            role.Policies.append(policy)
159
160
class Kinesis(PullEventSource):
    """Pull event source that reads records from a Kinesis data stream."""

    resource_type = "Kinesis"

    def get_policy_arn(self):
        """Return the ARN of the AWS managed ``AWSLambdaKinesisExecutionRole`` policy."""
        managed_policy_name = "service-role/AWSLambdaKinesisExecutionRole"
        return ArnGenerator.generate_aws_managed_policy_arn(managed_policy_name)

    def get_policy_statements(self):
        """Kinesis relies solely on the managed policy, so no inline statements are produced."""
        inline_statements = None
        return inline_statements
171
172
class DynamoDB(PullEventSource):
    """Pull event source that reads change records from a DynamoDB stream."""

    resource_type = "DynamoDB"

    def get_policy_arn(self):
        """Return the ARN of the AWS managed ``AWSLambdaDynamoDBExecutionRole`` policy."""
        managed_policy_name = "service-role/AWSLambdaDynamoDBExecutionRole"
        return ArnGenerator.generate_aws_managed_policy_arn(managed_policy_name)

    def get_policy_statements(self):
        """DynamoDB streams rely solely on the managed policy, so no inline statements are produced."""
        inline_statements = None
        return inline_statements
183
184
class SQS(PullEventSource):
    """Pull event source that polls messages from an SQS queue."""

    resource_type = "SQS"

    def get_policy_arn(self):
        """Return the ARN of the AWS managed ``AWSLambdaSQSQueueExecutionRole`` policy."""
        managed_policy_name = "service-role/AWSLambdaSQSQueueExecutionRole"
        return ArnGenerator.generate_aws_managed_policy_arn(managed_policy_name)

    def get_policy_statements(self):
        """SQS relies solely on the managed policy, so no inline statements are produced."""
        inline_statements = None
        return inline_statements
195
196
class MSK(PullEventSource):
    """Pull event source that consumes Kafka topics from an Amazon MSK cluster."""

    resource_type = "MSK"

    def get_policy_arn(self):
        """Return the ARN of the AWS managed ``AWSLambdaMSKExecutionRole`` policy."""
        managed_policy_name = "service-role/AWSLambdaMSKExecutionRole"
        return ArnGenerator.generate_aws_managed_policy_arn(managed_policy_name)

    def get_policy_statements(self):
        """MSK relies solely on the managed policy, so no inline statements are produced."""
        inline_statements = None
        return inline_statements
207
208
class MQ(PullEventSource):
    """Amazon MQ event source.

    No AWS managed policy is attached for this source. Instead an inline policy is generated that allows
    ``secretsmanager:GetSecretValue`` on the BASIC_AUTH secret and ``mq:DescribeBroker`` on the broker. If
    ``SecretsManagerKmsKeyId`` is set, it also allows ``kms:Decrypt`` on that key.
    """

    resource_type = "MQ"

    def get_policy_arn(self):
        """Amazon MQ has no managed policy; always returns None."""
        return None

    def get_policy_statements(self):
        """Validate SourceAccessConfigurations and build the inline execution-role policy for the broker.

        Exactly one BASIC_AUTH entry with a URI is required; VIRTUAL_HOST entries are accepted and ignored here.

        :returns: a single-element list containing the generated policy dict
        :raises InvalidEventException: if SourceAccessConfigurations is missing, not a list, contains an entry that
            is not a dict with a BASIC_AUTH/VIRTUAL_HOST Type, or does not contain exactly one BASIC_AUTH with a URI
        """
        if not self.SourceAccessConfigurations:
            raise InvalidEventException(
                self.relative_id,
                "No SourceAccessConfigurations for Amazon MQ event provided.",
            )
        if not isinstance(self.SourceAccessConfigurations, list):
            raise InvalidEventException(
                self.relative_id,
                "Provided SourceAccessConfigurations cannot be parsed into a list.",
            )
        basic_auth_uri = None
        for conf in self.SourceAccessConfigurations:
            # Non-dict entries previously crashed with AttributeError; report them as invalid instead.
            event_type = conf.get("Type") if isinstance(conf, dict) else None
            if event_type not in ("BASIC_AUTH", "VIRTUAL_HOST"):
                raise InvalidEventException(
                    self.relative_id,
                    "Invalid property specified in SourceAccessConfigurations for Amazon MQ event.",
                )
            if event_type == "BASIC_AUTH":
                if basic_auth_uri:
                    raise InvalidEventException(
                        self.relative_id,
                        "Multiple BASIC_AUTH properties specified in SourceAccessConfigurations for Amazon MQ event.",
                    )
                basic_auth_uri = conf.get("URI")
                if not basic_auth_uri:
                    raise InvalidEventException(
                        self.relative_id,
                        "No BASIC_AUTH URI property specified in SourceAccessConfigurations for Amazon MQ event.",
                    )

        if not basic_auth_uri:
            raise InvalidEventException(
                self.relative_id,
                "No BASIC_AUTH property specified in SourceAccessConfigurations for Amazon MQ event.",
            )
        document = {
            "PolicyName": "SamAutoGeneratedAMQPolicy",
            "PolicyDocument": {
                "Statement": [
                    {
                        "Action": [
                            "secretsmanager:GetSecretValue",
                        ],
                        "Effect": "Allow",
                        "Resource": basic_auth_uri,
                    },
                    {
                        "Action": [
                            "mq:DescribeBroker",
                        ],
                        "Effect": "Allow",
                        "Resource": self.Broker,
                    },
                ]
            },
        }
        if self.SecretsManagerKmsKeyId:
            # The key id is appended to a Fn::Sub template so partition, region and account are resolved at deploy time.
            kms_policy = {
                "Action": "kms:Decrypt",
                "Effect": "Allow",
                "Resource": {
                    "Fn::Sub": "arn:${AWS::Partition}:kms:${AWS::Region}:${AWS::AccountId}:key/"
                    + self.SecretsManagerKmsKeyId
                },
            }
            document["PolicyDocument"]["Statement"].append(kms_policy)
        return [document]
286