1import base64
2import hashlib
3import json
4import random
5import re
6import string
7
8import struct
9from copy import deepcopy
10from typing import Dict
11from xml.sax.saxutils import escape
12
13from boto3 import Session
14
15from moto.core.exceptions import RESTError
16from moto.core import BaseBackend, BaseModel, CloudFormationModel
17from moto.core.utils import (
18    camelcase_to_underscores,
19    get_random_message_id,
20    unix_time,
21    unix_time_millis,
22    tags_from_cloudformation_tags_list,
23)
24from .utils import generate_receipt_handle
25from .exceptions import (
26    MessageAttributesInvalid,
27    MessageNotInflight,
28    QueueDoesNotExist,
29    QueueAlreadyExists,
30    ReceiptHandleIsInvalid,
31    InvalidBatchEntryId,
32    BatchRequestTooLong,
33    BatchEntryIdsNotDistinct,
34    TooManyEntriesInBatchRequest,
35    InvalidAttributeName,
36    InvalidParameterValue,
37    MissingParameter,
38    OverLimit,
39    InvalidAttributeValue,
40)
41
42from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
43
DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"

# Largest message body SQS accepts (256 KiB). A queue's MaximumMessageSize
# attribute has to fall inside [LOWER_BOUND, UPPER_BOUND].
MAXIMUM_MESSAGE_LENGTH = 262144
MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND = 1024
MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND = MAXIMUM_MESSAGE_LENGTH

# One-byte transport-type code per message attribute data type: binary data
# maps to 0x02, every string-like type (including Number) to 0x01.
TRANSPORT_TYPE_ENCODINGS = {
    "String": b"\x01",
    "Binary": b"\x02",
    "Number": b"\x01",
    "String.custom": b"\x01",
}

# Field markers fed into the message-attribute MD5 digest ahead of the value.
STRING_TYPE_FIELD_INDEX = 1
BINARY_TYPE_FIELD_INDEX = 2
STRING_LIST_TYPE_FIELD_INDEX = 3
BINARY_LIST_TYPE_FIELD_INDEX = 4

# The rules for valid attribute names can be found at
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
ATTRIBUTE_NAME_PATTERN = re.compile(r"^([a-z]|[A-Z]|[0-9]|[_.\-])+$")

# A FIFO message whose deduplication ID repeats within this many seconds is dropped.
DEDUPLICATION_TIME_IN_SECONDS = 300
68
69
class Message(BaseModel):
    """A single SQS message plus the bookkeeping needed to deliver it.

    Timestamps (``sent_timestamp``, ``approximate_first_receive_timestamp``,
    ``visible_at``, ``delayed_until``) are milliseconds since the epoch, as
    produced by ``unix_time_millis``.
    """

    def __init__(self, message_id, body, system_attributes=None):
        self.id = message_id
        self._body = body
        self.message_attributes = {}
        self.receipt_handle = None
        self.sender_id = DEFAULT_SENDER_ID
        self.sent_timestamp = None
        self.approximate_first_receive_timestamp = None
        self.approximate_receive_count = 0
        self.deduplication_id = None
        self.group_id = None
        self.sequence_number = None
        self.visible_at = 0
        self.delayed_until = 0
        # Give every message its own dict: a literal ``{}`` default would be a
        # single object shared by all messages created without the argument.
        self.system_attributes = (
            {} if system_attributes is None else system_attributes
        )

    @property
    def body_md5(self):
        """Hex MD5 digest of the UTF-8 encoded (unescaped) body."""
        return hashlib.md5(self._body.encode("utf-8")).hexdigest()

    @property
    def attribute_md5(self):
        """Hex MD5 digest over the message attributes.

        Attributes are visited in sorted name order. For each one the digest
        receives the name and the data type, then a one-byte field marker
        (``*_TYPE_FIELD_INDEX``) followed by the value(s). Every
        variable-length item is prefixed with its 4-byte big-endian length,
        and binary values are base64-decoded before hashing.

        Raises:
            MessageAttributesInvalid: if an attribute name is not valid.
        """
        md5 = hashlib.md5()

        for attr_name in sorted(self.message_attributes):
            self.validate_attribute_name(attr_name)
            attr_value = self.message_attributes[attr_name]
            # Encode name and type
            self.update_binary_length_and_value(md5, self.utf8(attr_name))
            self.update_binary_length_and_value(
                md5, self.utf8(attr_value["data_type"])
            )

            # Attribute dicts need not contain every value key, so test for
            # presence (not truthiness) and use .get() for the list variants.
            # Otherwise an empty string value falls through to a KeyError.
            string_value = attr_value.get("string_value")
            binary_value = attr_value.get("binary_value")
            if string_value is not None:
                md5.update(bytearray([STRING_TYPE_FIELD_INDEX]))
                self.update_binary_length_and_value(md5, self.utf8(string_value))
            elif binary_value is not None:
                md5.update(bytearray([BINARY_TYPE_FIELD_INDEX]))
                self.update_binary_length_and_value(
                    md5, base64.b64decode(binary_value)
                )
            # The list value types are reserved for future use by SQS. See
            # https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_MessageAttributeValue.html
            elif attr_value.get("string_list_value"):
                md5.update(bytearray([STRING_LIST_TYPE_FIELD_INDEX]))
                for member in attr_value["string_list_value"]:
                    self.update_binary_length_and_value(md5, self.utf8(member))
            elif attr_value.get("binary_list_value"):
                md5.update(bytearray([BINARY_LIST_TYPE_FIELD_INDEX]))
                for member in attr_value["binary_list_value"]:
                    self.update_binary_length_and_value(
                        md5, base64.b64decode(member)
                    )

        return md5.hexdigest()

    @staticmethod
    def update_binary_length_and_value(md5, value):
        """Feed ``value`` to ``md5`` preceded by its length as a 4-byte big-endian int."""
        md5.update(struct.pack("!I", len(value)))
        md5.update(value)

    @staticmethod
    def validate_attribute_name(name):
        """Raise MessageAttributesInvalid unless ``name`` matches ATTRIBUTE_NAME_PATTERN.

        ``fullmatch`` is used because ``$`` alone also matches just before a
        trailing newline, which would let a name such as ``"foo\\n"`` through.
        """
        if not ATTRIBUTE_NAME_PATTERN.fullmatch(name):
            raise MessageAttributesInvalid(
                "The message attribute name '{0}' is invalid. "
                "Attribute name can contain A-Z, a-z, 0-9, "
                "underscore (_), hyphen (-), and period (.) characters.".format(name)
            )

    @staticmethod
    def utf8(string):
        """Return ``string`` UTF-8 encoded if it is a str, otherwise unchanged."""
        if isinstance(string, str):
            return string.encode("utf-8")
        return string

    @property
    def body(self):
        """The message body with XML special characters escaped."""
        return escape(self._body)

    def mark_sent(self, delay_seconds=None):
        """Record the send time and apply an optional delivery delay."""
        self.sent_timestamp = int(unix_time_millis())
        if delay_seconds:
            self.delay(delay_seconds=delay_seconds)

    def mark_received(self, visibility_timeout=None):
        """
        When a message is received we will set the first receive timestamp,
        tap the ``approximate_receive_count`` and the ``visible_at`` time.
        A fresh receipt handle is issued on every receive.
        """
        if visibility_timeout:
            visibility_timeout = int(visibility_timeout)
        else:
            visibility_timeout = 0

        if not self.approximate_first_receive_timestamp:
            self.approximate_first_receive_timestamp = int(unix_time_millis())

        self.approximate_receive_count += 1

        # Make message visible again in the future unless its
        # destroyed.
        if visibility_timeout:
            self.change_visibility(visibility_timeout)

        self.receipt_handle = generate_receipt_handle()

    def change_visibility(self, visibility_timeout):
        """Hide the message for ``visibility_timeout`` seconds from now."""
        # We're dealing with milliseconds internally
        visibility_timeout_msec = int(visibility_timeout) * 1000
        self.visible_at = unix_time_millis() + visibility_timeout_msec

    def delay(self, delay_seconds):
        """Postpone delivery for ``delay_seconds`` seconds from now."""
        delay_msec = int(delay_seconds) * 1000
        self.delayed_until = unix_time_millis() + delay_msec

    @property
    def visible(self):
        """True once the visibility timeout (if any) has elapsed."""
        return unix_time_millis() > self.visible_at

    @property
    def delayed(self):
        """True while the delivery delay (if any) is still running."""
        return unix_time_millis() < self.delayed_until
205
206
class Queue(CloudFormationModel):
    """In-memory model of one SQS queue.

    SQS attributes are stored as snake_case instance attributes derived from
    their CamelCase names (``VisibilityTimeout`` -> ``visibility_timeout``)
    by ``_set_attributes``; the ``attributes`` property maps them back.
    """

    BASE_ATTRIBUTES = [
        "ApproximateNumberOfMessages",
        "ApproximateNumberOfMessagesDelayed",
        "ApproximateNumberOfMessagesNotVisible",
        "CreatedTimestamp",
        "DelaySeconds",
        "LastModifiedTimestamp",
        "MaximumMessageSize",
        "MessageRetentionPeriod",
        "QueueArn",
        "Policy",
        "RedrivePolicy",
        "ReceiveMessageWaitTimeSeconds",
        "VisibilityTimeout",
    ]
    FIFO_ATTRIBUTES = [
        "ContentBasedDeduplication",
        "DeduplicationScope",
        "FifoQueue",
        "FifoThroughputLimit",
    ]
    KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"]
    ALLOWED_PERMISSIONS = (
        "*",
        "ChangeMessageVisibility",
        "DeleteMessage",
        "GetQueueAttributes",
        "GetQueueUrl",
        "ListDeadLetterSourceQueues",
        "PurgeQueue",
        "ReceiveMessage",
        "SendMessage",
    )

    def __init__(self, name, region, **kwargs):
        """Create a queue called ``name`` in ``region``.

        ``kwargs`` are CamelCase SQS attributes that override the defaults.

        Raises:
            InvalidParameterValue: FIFO queue whose name lacks ``.fifo``.
            InvalidAttributeValue: MaximumMessageSize outside
                [MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND,
                MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND].
            RESTError: invalid RedrivePolicy (see ``_setup_dlq``).
        """
        self.name = name
        self.region = region
        self.tags = {}
        self.permissions = {}

        self._messages = []
        self._pending_messages = set()

        now = unix_time()
        self.created_timestamp = now
        self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format(
            self.region, DEFAULT_ACCOUNT_ID, self.name
        )
        self.dead_letter_queue = None

        self.lambda_event_source_mappings = {}

        # default settings for a non fifo queue
        defaults = {
            "ContentBasedDeduplication": "false",
            "DeduplicationScope": "queue",
            "DelaySeconds": 0,
            "FifoQueue": "false",
            "FifoThroughputLimit": "perQueue",
            "KmsDataKeyReusePeriodSeconds": 300,  # five minutes
            "KmsMasterKeyId": None,
            "MaximumMessageSize": MAXIMUM_MESSAGE_LENGTH,
            "MessageRetentionPeriod": 86400 * 4,  # four days
            "Policy": None,
            "ReceiveMessageWaitTimeSeconds": 0,
            "RedrivePolicy": None,
            "VisibilityTimeout": 30,
        }

        defaults.update(kwargs)
        self._set_attributes(defaults, now)

        # Check some conditions
        if self.fifo_queue and not self.name.endswith(".fifo"):
            raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues")
        if (
            self.maximum_message_size < MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND
            or self.maximum_message_size > MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND
        ):
            raise InvalidAttributeValue("MaximumMessageSize")

    @property
    def pending_messages(self):
        """Messages handed out by a receive that are not yet deleted or timed out."""
        return self._pending_messages

    @property
    def pending_message_groups(self):
        """Group IDs of the pending messages (ungrouped messages excluded)."""
        return set(
            message.group_id
            for message in self._pending_messages
            if message.group_id is not None
        )

    def _set_attributes(self, attributes, now=None):
        """Apply a mapping of CamelCase SQS attributes to this queue.

        Integer-valued attributes are coerced with ``int``. Boolean-valued
        attributes are parsed from ``"true"``/``"false"`` or from a real bool,
        e.g. one coming from a parsed CloudFormation template. Non-None
        ``Policy``/``RedrivePolicy`` values are handled after the generic loop.
        ``last_modified_timestamp`` is set to ``now``, defaulting to the
        current time.
        """
        if not now:
            now = unix_time()

        # These names must match the CamelCase SQS attribute names exactly;
        # a misspelt entry leaves string values from the API unconverted.
        integer_fields = (
            "DelaySeconds",
            "KmsDataKeyReusePeriodSeconds",
            "MaximumMessageSize",
            "MessageRetentionPeriod",
            "ReceiveMessageWaitTimeSeconds",
            "VisibilityTimeout",
        )
        bool_fields = ("ContentBasedDeduplication", "FifoQueue")

        for key, value in attributes.items():
            if key in integer_fields:
                value = int(value)
            if key in bool_fields:
                # str() so that a real True is not compared against "true"
                value = str(value).lower() == "true"

            # Non-None policies need parsing / DLQ lookup and are applied
            # below; None falls through so the defaults get initialised.
            if key in ["Policy", "RedrivePolicy"] and value is not None:
                continue

            setattr(self, camelcase_to_underscores(key), value)

        if attributes.get("RedrivePolicy", None):
            self._setup_dlq(attributes["RedrivePolicy"])

        if attributes.get("Policy"):
            self.policy = attributes["Policy"]

        self.last_modified_timestamp = now

    def _setup_dlq(self, policy):
        """Validate a RedrivePolicy and link the dead-letter queue it names.

        ``policy`` may be a JSON string or a dict and must contain
        ``deadLetterTargetArn`` and ``maxReceiveCount`` (stored as int). The
        target is looked up by ARN among this region's queues.

        Raises:
            RESTError: malformed policy, missing keys, unknown target queue,
                or a FIFO queue pointing at a non-FIFO dead-letter queue.
        """
        if isinstance(policy, str):
            try:
                self.redrive_policy = json.loads(policy)
            except ValueError:
                raise RESTError(
                    "InvalidParameterValue",
                    "Redrive policy is not a dict or valid json",
                )
        elif isinstance(policy, dict):
            self.redrive_policy = policy
        else:
            raise RESTError(
                "InvalidParameterValue", "Redrive policy is not a dict or valid json"
            )

        if "deadLetterTargetArn" not in self.redrive_policy:
            raise RESTError(
                "InvalidParameterValue",
                "Redrive policy does not contain deadLetterTargetArn",
            )
        if "maxReceiveCount" not in self.redrive_policy:
            raise RESTError(
                "InvalidParameterValue",
                "Redrive policy does not contain maxReceiveCount",
            )

        # 'maxReceiveCount' is stored as int
        self.redrive_policy["maxReceiveCount"] = int(
            self.redrive_policy["maxReceiveCount"]
        )

        for queue in sqs_backends[self.region].queues.values():
            if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]:
                self.dead_letter_queue = queue

                if self.fifo_queue and not queue.fifo_queue:
                    raise RESTError(
                        "InvalidParameterCombination",
                        "Fifo queues cannot use non fifo dead letter queues",
                    )
                break
        else:
            raise RESTError(
                "AWS.SimpleQueueService.NonExistentQueue",
                "Could not find DLQ for {0}".format(
                    self.redrive_policy["deadLetterTargetArn"]
                ),
            )

    @staticmethod
    def cloudformation_name_type():
        return "QueueName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html
        return "AWS::SQS::Queue"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        properties = deepcopy(cloudformation_json["Properties"])
        # remove Tags from properties and convert tags list to dict
        tags = properties.pop("Tags", [])
        tags_dict = tags_from_cloudformation_tags_list(tags)

        # Could be passed as an integer - just treat it as a string
        resource_name = str(resource_name)

        sqs_backend = sqs_backends[region_name]
        return sqs_backend.create_queue(
            name=resource_name, tags=tags_dict, region=region_name, **properties
        )

    @classmethod
    def update_from_cloudformation_json(
        cls, original_resource, new_resource_name, cloudformation_json, region_name
    ):
        properties = cloudformation_json["Properties"]
        queue_name = original_resource.name

        sqs_backend = sqs_backends[region_name]
        queue = sqs_backend.get_queue(queue_name)
        if "VisibilityTimeout" in properties:
            queue.visibility_timeout = int(properties["VisibilityTimeout"])

        if "ReceiveMessageWaitTimeSeconds" in properties:
            queue.receive_message_wait_time_seconds = int(
                properties["ReceiveMessageWaitTimeSeconds"]
            )
        return queue

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name
    ):
        sqs_backend = sqs_backends[region_name]
        sqs_backend.delete_queue(resource_name)

    @property
    def approximate_number_of_messages_delayed(self):
        return len([m for m in self._messages if m.delayed])

    @property
    def approximate_number_of_messages_not_visible(self):
        return len([m for m in self._messages if not m.visible])

    @property
    def approximate_number_of_messages(self):
        return len(self.messages)

    @property
    def physical_resource_id(self):
        return self.name

    @property
    def attributes(self):
        """CamelCase attribute dict as returned by GetQueueAttributes.

        FIFO/KMS attributes are included only for FIFO / KMS-enabled queues,
        RedrivePolicy is JSON-encoded and bools are rendered "true"/"false".
        Note: this recounts the queue's messages on every access.
        """
        result = {}

        for attribute in self.BASE_ATTRIBUTES:
            attr = getattr(self, camelcase_to_underscores(attribute))
            result[attribute] = attr

        if self.fifo_queue:
            for attribute in self.FIFO_ATTRIBUTES:
                attr = getattr(self, camelcase_to_underscores(attribute))
                result[attribute] = attr

        if self.kms_master_key_id:
            for attribute in self.KMS_ATTRIBUTES:
                attr = getattr(self, camelcase_to_underscores(attribute))
                result[attribute] = attr

        if self.policy:
            result["Policy"] = self.policy

        if self.redrive_policy:
            result["RedrivePolicy"] = json.dumps(self.redrive_policy)

        for key in result:
            if isinstance(result[key], bool):
                result[key] = str(result[key]).lower()

        return result

    def url(self, request_url):
        return "{0}://{1}/{2}/{3}".format(
            request_url.scheme, request_url.netloc, DEFAULT_ACCOUNT_ID, self.name
        )

    @property
    def messages(self):
        """Messages that are currently visible and not delayed."""
        # TODO: This can become very inefficient if a large number of messages are in-flight
        return [
            message
            for message in self._messages
            if message.visible and not message.delayed
        ]

    def add_message(self, message):
        """Enqueue ``message`` and feed any Lambda event source mappings.

        On a FIFO queue, a message is silently dropped if its deduplication ID
        matches a message still held here that was sent less than
        ``DEDUPLICATION_TIME_IN_SECONDS`` earlier. The ID may be explicit or,
        with ContentBasedDeduplication, derived from the body.
        """
        # Read the flags directly instead of building ``self.attributes``,
        # which walks every message. Explicit deduplication IDs count too,
        # not only content-based ones.
        if self.fifo_queue and (
            self.content_based_deduplication or message.deduplication_id is not None
        ):
            for m in self._messages:
                if m.deduplication_id == message.deduplication_id:
                    diff = message.sent_timestamp - m.sent_timestamp
                    # if a duplicate message is received within the deduplication time then it should
                    # not be added to the queue
                    if diff / 1000 < DEDUPLICATION_TIME_IN_SECONDS:
                        return

        self._messages.append(message)

        for arn, esm in self.lambda_event_source_mappings.items():
            backend = sqs_backends[self.region]

            # Lambda polls the queue and invokes the function synchronously
            # with a batch of messages. On success the batch is deleted;
            # otherwise the messages are made visible again.
            messages = backend.receive_messages(
                self.name,
                esm.batch_size,
                self.receive_message_wait_time_seconds,
                self.visibility_timeout,
            )

            # Imported lazily, as in the original code (presumably to avoid
            # an import cycle).
            from moto.awslambda import lambda_backends

            result = lambda_backends[self.region].send_sqs_batch(
                arn, messages, self.queue_arn
            )

            if result:
                for m in messages:
                    backend.delete_message(self.name, m.receipt_handle)
            else:
                # Make messages visible again
                for m in messages:
                    backend.change_message_visibility(
                        self.name, m.receipt_handle, visibility_timeout=0
                    )

    @classmethod
    def has_cfn_attr(cls, attribute_name):
        return attribute_name in ["Arn", "QueueName"]

    def get_cfn_attribute(self, attribute_name):
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Arn":
            return self.queue_arn
        elif attribute_name == "QueueName":
            return self.name
        raise UnformattedGetAttTemplateException()

    @property
    def policy(self):
        """JSON policy document, or None when it has no statements."""
        if self._policy_json.get("Statement"):
            return json.dumps(self._policy_json)
        else:
            return None

    @policy.setter
    def policy(self, policy):
        if policy:
            self._policy_json = json.loads(policy)
        else:
            self._policy_json = {
                "Version": "2012-10-17",
                "Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
                "Statement": [],
            }
574
575
576def _filter_message_attributes(message, input_message_attributes):
577    filtered_message_attributes = {}
578    return_all = "All" in input_message_attributes
579    for key, value in message.message_attributes.items():
580        if return_all or key in input_message_attributes:
581            filtered_message_attributes[key] = value
582    message.message_attributes = filtered_message_attributes
583
584
585class SQSBackend(BaseBackend):
586    def __init__(self, region_name):
587        self.region_name = region_name
588        self.queues: Dict[str, Queue] = {}
589        super(SQSBackend, self).__init__()
590
591    def reset(self):
592        region_name = self.region_name
593        self._reset_model_refs()
594        self.__dict__ = {}
595        self.__init__(region_name)
596
597    @staticmethod
598    def default_vpc_endpoint_service(service_region, zones):
599        """Default VPC endpoint service."""
600        return BaseBackend.default_vpc_endpoint_service_factory(
601            service_region, zones, "sqs"
602        )
603
604    def create_queue(self, name, tags=None, **kwargs):
605        queue = self.queues.get(name)
606        if queue:
607            try:
608                kwargs.pop("region")
609            except KeyError:
610                pass
611
612            new_queue = Queue(name, region=self.region_name, **kwargs)
613
614            queue_attributes = queue.attributes
615            new_queue_attributes = new_queue.attributes
616
617            # only the attributes which are being sent for the queue
618            # creation have to be compared if the queue is existing.
619            for key in kwargs:
620                if queue_attributes.get(key) != new_queue_attributes.get(key):
621                    raise QueueAlreadyExists("The specified queue already exists.")
622        else:
623            try:
624                kwargs.pop("region")
625            except KeyError:
626                pass
627            queue = Queue(name, region=self.region_name, **kwargs)
628            self.queues[name] = queue
629
630        if tags:
631            queue.tags = tags
632
633        return queue
634
635    def get_queue_url(self, queue_name):
636        return self.get_queue(queue_name)
637
638    def list_queues(self, queue_name_prefix):
639        re_str = ".*"
640        if queue_name_prefix:
641            re_str = "^{0}.*".format(queue_name_prefix)
642        prefix_re = re.compile(re_str)
643        qs = []
644        for name, q in self.queues.items():
645            if prefix_re.search(name):
646                qs.append(q)
647        return qs[:1000]
648
649    def get_queue(self, queue_name):
650        queue = self.queues.get(queue_name)
651        if queue is None:
652            raise QueueDoesNotExist()
653        return queue
654
655    def delete_queue(self, queue_name):
656        self.get_queue(queue_name)
657
658        del self.queues[queue_name]
659
660    def get_queue_attributes(self, queue_name, attribute_names):
661        queue = self.get_queue(queue_name)
662        if not attribute_names:
663            return {}
664
665        valid_names = (
666            ["All"]
667            + queue.BASE_ATTRIBUTES
668            + queue.FIFO_ATTRIBUTES
669            + queue.KMS_ATTRIBUTES
670        )
671        invalid_name = next(
672            (name for name in attribute_names if name not in valid_names), None
673        )
674
675        if invalid_name or invalid_name == "":
676            raise InvalidAttributeName(invalid_name)
677
678        attributes = {}
679
680        if "All" in attribute_names:
681            attributes = queue.attributes
682        else:
683            for name in (name for name in attribute_names if name in queue.attributes):
684                if queue.attributes.get(name) is not None:
685                    attributes[name] = queue.attributes.get(name)
686
687        return attributes
688
689    def set_queue_attributes(self, queue_name, attributes):
690        queue = self.get_queue(queue_name)
691        queue._set_attributes(attributes)
692        return queue
693
694    def send_message(
695        self,
696        queue_name,
697        message_body,
698        message_attributes=None,
699        delay_seconds=None,
700        deduplication_id=None,
701        group_id=None,
702        system_attributes=None,
703    ):
704
705        queue = self.get_queue(queue_name)
706
707        if len(message_body) > queue.maximum_message_size:
708            msg = "One or more parameters are invalid. Reason: Message must be shorter than {} bytes.".format(
709                queue.maximum_message_size
710            )
711            raise InvalidParameterValue(msg)
712
713        if delay_seconds:
714            delay_seconds = int(delay_seconds)
715        else:
716            delay_seconds = queue.delay_seconds
717
718        message_id = get_random_message_id()
719        message = Message(message_id, message_body, system_attributes)
720
721        # if content based deduplication is set then set sha256 hash of the message
722        # as the deduplication_id
723        if queue.attributes.get("ContentBasedDeduplication") == "true":
724            sha256 = hashlib.sha256()
725            sha256.update(message_body.encode("utf-8"))
726            message.deduplication_id = sha256.hexdigest()
727
728        # Attributes, but not *message* attributes
729        if deduplication_id is not None:
730            message.deduplication_id = deduplication_id
731            message.sequence_number = "".join(
732                random.choice(string.digits) for _ in range(20)
733            )
734
735        if group_id is None:
736            # MessageGroupId is a mandatory parameter for all
737            # messages in a fifo queue
738            if queue.fifo_queue:
739                raise MissingParameter("MessageGroupId")
740        else:
741            message.group_id = group_id
742
743        if message_attributes:
744            message.message_attributes = message_attributes
745
746        message.mark_sent(delay_seconds=delay_seconds)
747
748        queue.add_message(message)
749
750        return message
751
752    def send_message_batch(self, queue_name, entries):
753        self.get_queue(queue_name)
754
755        if any(
756            not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
757        ):
758            raise InvalidBatchEntryId()
759
760        body_length = next(
761            (
762                len(entry["MessageBody"])
763                for entry in entries.values()
764                if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH
765            ),
766            False,
767        )
768        if body_length:
769            raise BatchRequestTooLong(body_length)
770
771        duplicate_id = self._get_first_duplicate_id(
772            [entry["Id"] for entry in entries.values()]
773        )
774        if duplicate_id:
775            raise BatchEntryIdsNotDistinct(duplicate_id)
776
777        if len(entries) > 10:
778            raise TooManyEntriesInBatchRequest(len(entries))
779
780        messages = []
781        for index, entry in entries.items():
782            # Loop through looking for messages
783            message = self.send_message(
784                queue_name,
785                entry["MessageBody"],
786                message_attributes=entry["MessageAttributes"],
787                delay_seconds=entry["DelaySeconds"],
788                group_id=entry.get("MessageGroupId"),
789                deduplication_id=entry.get("MessageDeduplicationId"),
790            )
791            message.user_id = entry["Id"]
792
793            messages.append(message)
794
795        return messages
796
797    def _get_first_duplicate_id(self, ids):
798        unique_ids = set()
799        for id in ids:
800            if id in unique_ids:
801                return id
802            unique_ids.add(id)
803        return None
804
    def receive_messages(
        self,
        queue_name,
        count,
        wait_seconds_timeout,
        visibility_timeout,
        message_attribute_names=None,
    ):
        """
        Attempt to retrieve visible messages from a queue.

        If a message was read by a client and not deleted it is considered to be
        "inflight" and cannot be read. We make attempts to obtain ``count``
        messages but we may return fewer if messages are in-flight or there
        are simply not enough messages in the queue.

        Polling: the loop stops as soon as at least one message has been
        collected. With ``wait_seconds_timeout == 0`` a single pass is made;
        otherwise passes repeat every 10 ms until the deadline passes.

        Each delivered message is marked received (receive count, receipt
        handle, visibility timeout) and added to the queue's pending set. The
        caller gets a deep copy whose attributes are filtered by
        ``message_attribute_names``; the stored message keeps all of them.

        Messages whose receive count has reached the redrive policy's
        ``maxReceiveCount`` are moved to the dead-letter queue instead of
        being delivered.

        :param string queue_name: The name of the queue to read from.
        :param int count: The maximum amount of messages to retrieve.
        :param int visibility_timeout: The number of seconds the message should remain invisible to other queue readers.
        :param int wait_seconds_timeout:  The duration (in seconds) for which the call waits for a message to arrive in
         the queue before returning. If a message is available, the call returns sooner than WaitTimeSeconds
        :param list message_attribute_names: Attribute names to return ("All" for every attribute).
        :return: list of deep-copied Message objects.
        """
        if message_attribute_names is None:
            message_attribute_names = []
        queue = self.get_queue(queue_name)
        result = []
        previous_result_count = len(result)

        polling_end = unix_time() + wait_seconds_timeout
        # Snapshot of FIFO groups that already have in-flight messages; it is
        # refreshed below whenever a timed-out pending message is released.
        currently_pending_groups = deepcopy(queue.pending_message_groups)

        # queue.messages only contains visible messages
        while True:

            if result or (wait_seconds_timeout and unix_time() > polling_end):
                break

            messages_to_dlq = []

            for message in queue.messages:
                # Redundant with the filtering done by queue.messages, kept as a guard.
                if not message.visible:
                    continue

                if message in queue.pending_messages:
                    # The message is pending but is visible again, so the
                    # consumer must have timed out.
                    queue.pending_messages.remove(message)
                    currently_pending_groups = deepcopy(queue.pending_message_groups)

                if message.group_id and queue.fifo_queue:
                    if message.group_id in currently_pending_groups:
                        # A previous call is still processing messages in this group, so we cannot deliver this one.
                        continue

                # Too many receives already: divert to the DLQ (after this pass).
                if (
                    queue.dead_letter_queue is not None
                    and queue.redrive_policy
                    and message.approximate_receive_count
                    >= queue.redrive_policy["maxReceiveCount"]
                ):
                    messages_to_dlq.append(message)
                    continue

                queue.pending_messages.add(message)
                message.mark_received(visibility_timeout=visibility_timeout)
                # Create deepcopy to not mutate the message state when filtering for attributes
                message_copy = deepcopy(message)
                _filter_message_attributes(message_copy, message_attribute_names)
                # NOTE(review): is_message_valid_based_on_retention_period is
                # defined outside this view (presumably an age check against
                # MessageRetentionPeriod). By this point the message has
                # already been marked received and added to pending_messages,
                # so a rejected message is not returned but remains pending.
                # Confirm this is intended.
                if not self.is_message_valid_based_on_retention_period(
                    queue_name, message
                ):
                    break
                result.append(message_copy)
                if len(result) >= count:
                    break

            # Removing is deferred until after the pass so queue._messages is
            # not mutated while queue.messages (built from it) is iterated.
            for message in messages_to_dlq:
                queue._messages.remove(message)
                queue.dead_letter_queue.add_message(message)

            if previous_result_count == len(result):
                if wait_seconds_timeout == 0:
                    # There is timeout and we have added no additional results,
                    # so break to avoid an infinite loop.
                    break

                import time

                # Nothing new yet: back off briefly and poll again.
                time.sleep(0.01)
                continue

            previous_result_count = len(result)

        return result
899
900    def delete_message(self, queue_name, receipt_handle):
901        queue = self.get_queue(queue_name)
902
903        if not any(
904            message.receipt_handle == receipt_handle for message in queue._messages
905        ):
906            raise ReceiptHandleIsInvalid()
907
908        # Delete message from queue regardless of pending state
909        new_messages = []
910        for message in queue._messages:
911            if message.receipt_handle == receipt_handle:
912                queue.pending_messages.discard(message)
913                continue
914            new_messages.append(message)
915        queue._messages = new_messages
916
917    def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout):
918        queue = self.get_queue(queue_name)
919        for message in queue._messages:
920            if message.receipt_handle == receipt_handle:
921                if message.visible:
922                    raise MessageNotInflight
923
924                visibility_timeout_msec = int(visibility_timeout) * 1000
925                given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
926                if given_visibility_timeout - message.sent_timestamp > 43200 * 1000:
927                    raise InvalidParameterValue(
928                        "Value {0} for parameter VisibilityTimeout is invalid. Reason: Total "
929                        "VisibilityTimeout for the message is beyond the limit [43200 seconds]".format(
930                            visibility_timeout
931                        )
932                    )
933
934                message.change_visibility(visibility_timeout)
935                if message.visible:
936                    # If the message is visible again, remove it from pending
937                    # messages.
938                    queue.pending_messages.remove(message)
939                return
940        raise ReceiptHandleIsInvalid
941
942    def purge_queue(self, queue_name):
943        queue = self.get_queue(queue_name)
944        queue._messages = []
945        queue._pending_messages = set()
946
947    def list_dead_letter_source_queues(self, queue_name):
948        dlq = self.get_queue(queue_name)
949
950        queues = []
951        for queue in self.queues.values():
952            if queue.dead_letter_queue is dlq:
953                queues.append(queue)
954
955        return queues
956
957    def add_permission(self, queue_name, actions, account_ids, label):
958        queue = self.get_queue(queue_name)
959
960        if not actions:
961            raise MissingParameter("Actions")
962
963        if not account_ids:
964            raise InvalidParameterValue(
965                "Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
966            )
967
968        count = len(actions)
969        if count > 7:
970            raise OverLimit(count)
971
972        invalid_action = next(
973            (action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
974            None,
975        )
976        if invalid_action:
977            raise InvalidParameterValue(
978                "Value SQS:{} for parameter ActionName is invalid. "
979                "Reason: Only the queue owner is allowed to invoke this action.".format(
980                    invalid_action
981                )
982            )
983
984        policy = queue._policy_json
985        statement = next(
986            (
987                statement
988                for statement in policy["Statement"]
989                if statement["Sid"] == label
990            ),
991            None,
992        )
993        if statement:
994            raise InvalidParameterValue(
995                "Value {} for parameter Label is invalid. "
996                "Reason: Already exists.".format(label)
997            )
998
999        principals = [
1000            "arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
1001        ]
1002        actions = ["SQS:{}".format(action) for action in actions]
1003
1004        statement = {
1005            "Sid": label,
1006            "Effect": "Allow",
1007            "Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
1008            "Action": actions[0] if len(actions) == 1 else actions,
1009            "Resource": queue.queue_arn,
1010        }
1011
1012        queue._policy_json["Statement"].append(statement)
1013
1014    def remove_permission(self, queue_name, label):
1015        queue = self.get_queue(queue_name)
1016
1017        statements = queue._policy_json["Statement"]
1018        statements_new = [
1019            statement for statement in statements if statement["Sid"] != label
1020        ]
1021
1022        if len(statements) == len(statements_new):
1023            raise InvalidParameterValue(
1024                "Value {} for parameter Label is invalid. "
1025                "Reason: can't find label on existing policy.".format(label)
1026            )
1027
1028        queue._policy_json["Statement"] = statements_new
1029
1030    def tag_queue(self, queue_name, tags):
1031        queue = self.get_queue(queue_name)
1032
1033        if not len(tags):
1034            raise MissingParameter("Tags")
1035
1036        if len(tags) > 50:
1037            raise InvalidParameterValue(
1038                "Too many tags added for queue {}.".format(queue_name)
1039            )
1040
1041        queue.tags.update(tags)
1042
1043    def untag_queue(self, queue_name, tag_keys):
1044        queue = self.get_queue(queue_name)
1045
1046        if not len(tag_keys):
1047            raise RESTError(
1048                "InvalidParameterValue",
1049                "Tag keys must be between 1 and 128 characters in length.",
1050            )
1051
1052        for key in tag_keys:
1053            try:
1054                del queue.tags[key]
1055            except KeyError:
1056                pass
1057
1058    def list_queue_tags(self, queue_name):
1059        return self.get_queue(queue_name)
1060
1061    def is_message_valid_based_on_retention_period(self, queue_name, message):
1062        message_attributes = self.get_queue_attributes(
1063            queue_name, ["MessageRetentionPeriod"]
1064        )
1065        retain_until = (
1066            message_attributes.get("MessageRetentionPeriod")
1067            + message.sent_timestamp / 1000
1068        )
1069        if retain_until <= unix_time():
1070            return False
1071        return True
1072
1073
# One SQSBackend per SQS region, across the standard, GovCloud and China partitions.
# ("aws" is the default partition of get_available_regions.)
sqs_backends = {}
for partition in ("aws", "aws-us-gov", "aws-cn"):
    for region in Session().get_available_regions("sqs", partition_name=partition):
        sqs_backends[region] = SQSBackend(region)
1081