import base64
import hashlib
import json
import random
import re
import string
import struct
import time
from copy import deepcopy
from typing import Dict
from xml.sax.saxutils import escape

from boto3 import Session

from moto.core.exceptions import RESTError
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import (
    camelcase_to_underscores,
    get_random_message_id,
    unix_time,
    unix_time_millis,
    tags_from_cloudformation_tags_list,
)
from .utils import generate_receipt_handle
from .exceptions import (
    MessageAttributesInvalid,
    MessageNotInflight,
    QueueDoesNotExist,
    QueueAlreadyExists,
    ReceiptHandleIsInvalid,
    InvalidBatchEntryId,
    BatchRequestTooLong,
    BatchEntryIdsNotDistinct,
    TooManyEntriesInBatchRequest,
    InvalidAttributeName,
    InvalidParameterValue,
    MissingParameter,
    OverLimit,
    InvalidAttributeValue,
)

DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU"

MAXIMUM_MESSAGE_LENGTH = 262144  # 256 KiB

MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND = 1024
MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND = MAXIMUM_MESSAGE_LENGTH

TRANSPORT_TYPE_ENCODINGS = {
    "String": b"\x01",
    "Binary": b"\x02",
    "Number": b"\x01",
    "String.custom": b"\x01",
}

# One-byte markers written into the attribute MD5 before each value,
# identifying which value field the attribute carries.
STRING_TYPE_FIELD_INDEX = 1
BINARY_TYPE_FIELD_INDEX = 2
STRING_LIST_TYPE_FIELD_INDEX = 3
BINARY_LIST_TYPE_FIELD_INDEX = 4

# Valid attribute name rules can be found at
# https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/sqs-message-metadata.html
ATTRIBUTE_NAME_PATTERN = re.compile("^([a-z]|[A-Z]|[0-9]|[_.\\-])+$")

# Window within which a FIFO message with an already-seen deduplication id
# is dropped instead of being enqueued.
DEDUPLICATION_TIME_IN_SECONDS = 300


class Message(BaseModel):
    """A single SQS message plus its delivery bookkeeping.

    ``sent_timestamp``, ``approximate_first_receive_timestamp``,
    ``visible_at`` and ``delayed_until`` are milliseconds since the epoch,
    as produced by ``unix_time_millis``.
    """

    def __init__(self, message_id, body, system_attributes=None):
        self.id = message_id
        self._body = body
        self.message_attributes = {}
        self.receipt_handle = None
        self.sender_id = DEFAULT_SENDER_ID
        self.sent_timestamp = None
        self.approximate_first_receive_timestamp = None
        self.approximate_receive_count = 0
        self.deduplication_id = None
        self.group_id = None
        self.sequence_number = None
        self.visible_at = 0
        self.delayed_until = 0
        # A fresh dict per message; a shared mutable default argument would
        # leak system attributes between unrelated messages.
        self.system_attributes = (
            {} if system_attributes is None else system_attributes
        )

    @property
    def body_md5(self):
        """Hex MD5 digest of the UTF-8 encoded (unescaped) message body."""
        md5 = hashlib.md5()
        md5.update(self._body.encode("utf-8"))
        return md5.hexdigest()

    @property
    def attribute_md5(self):
        """Hex MD5 digest over all message attributes.

        Attributes are visited in sorted name order. For each one, the name,
        the data type, a one-byte field marker and the value(s) are fed to
        the digest, with names, types and values length-prefixed by
        :meth:`update_binary_length_and_value`.

        :raises MessageAttributesInvalid: if an attribute name is invalid.
        """
        md5 = hashlib.md5()

        for attr_name in sorted(self.message_attributes.keys()):
            self.validate_attribute_name(attr_name)
            attr_value = self.message_attributes[attr_name]
            # Encode name
            self.update_binary_length_and_value(md5, self.utf8(attr_name))
            # Encode type
            self.update_binary_length_and_value(
                md5, self.utf8(attr_value["data_type"])
            )

            string_value = attr_value.get("string_value")
            binary_value = attr_value.get("binary_value")
            # The list variants are reserved for future use by AWS (see
            # https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_MessageAttributeValue.html)
            # and may be missing entirely, so never index them directly.
            string_list_value = attr_value.get("string_list_value") or []
            binary_list_value = attr_value.get("binary_list_value") or []

            if string_value:
                md5.update(bytearray([STRING_TYPE_FIELD_INDEX]))
                self.update_binary_length_and_value(md5, self.utf8(string_value))
            elif binary_value:
                md5.update(bytearray([BINARY_TYPE_FIELD_INDEX]))
                self.update_binary_length_and_value(
                    md5, base64.b64decode(binary_value)
                )
            elif string_list_value:
                md5.update(bytearray([STRING_LIST_TYPE_FIELD_INDEX]))
                for member in string_list_value:
                    self.update_binary_length_and_value(md5, self.utf8(member))
            elif binary_list_value:
                md5.update(bytearray([BINARY_LIST_TYPE_FIELD_INDEX]))
                for member in binary_list_value:
                    self.update_binary_length_and_value(
                        md5, base64.b64decode(member)
                    )

        return md5.hexdigest()

    @staticmethod
    def update_binary_length_and_value(md5, value):
        """Feed ``value`` to ``md5`` prefixed by its length as a 4-byte big-endian int."""
        md5.update(struct.pack("!I", len(value)))
        md5.update(value)

    @staticmethod
    def validate_attribute_name(name):
        """Raise MessageAttributesInvalid unless ``name`` matches ATTRIBUTE_NAME_PATTERN."""
        if not ATTRIBUTE_NAME_PATTERN.match(name):
            raise MessageAttributesInvalid(
                "The message attribute name '{0}' is invalid. "
                "Attribute name can contain A-Z, a-z, 0-9, "
                "underscore (_), hyphen (-), and period (.) characters.".format(name)
            )

    @staticmethod
    def utf8(string):
        """Return ``string`` UTF-8 encoded if it is a str; otherwise return it unchanged."""
        if isinstance(string, str):
            return string.encode("utf-8")
        return string

    @property
    def body(self):
        """The message body, XML-escaped for embedding in responses."""
        return escape(self._body)

    def mark_sent(self, delay_seconds=None):
        """Record the send time and, if requested, delay delivery."""
        self.sent_timestamp = int(unix_time_millis())
        if delay_seconds:
            self.delay(delay_seconds=delay_seconds)

    def mark_received(self, visibility_timeout=None):
        """Record a receive of this message.

        Sets the first-receive timestamp (only once), bumps
        ``approximate_receive_count``, hides the message for
        ``visibility_timeout`` seconds when that is non-zero, and issues a
        new receipt handle.
        """
        visibility_timeout = int(visibility_timeout) if visibility_timeout else 0

        if not self.approximate_first_receive_timestamp:
            self.approximate_first_receive_timestamp = int(unix_time_millis())

        self.approximate_receive_count += 1

        # Make message visible again in the future unless it is destroyed.
        if visibility_timeout:
            self.change_visibility(visibility_timeout)

        self.receipt_handle = generate_receipt_handle()

    def change_visibility(self, visibility_timeout):
        """Hide the message until ``visibility_timeout`` seconds from now."""
        # We're dealing with milliseconds internally
        visibility_timeout_msec = int(visibility_timeout) * 1000
        self.visible_at = unix_time_millis() + visibility_timeout_msec

    def delay(self, delay_seconds):
        """Keep the message undeliverable until ``delay_seconds`` from now."""
        delay_msec = int(delay_seconds) * 1000
        self.delayed_until = unix_time_millis() + delay_msec

    @property
    def visible(self):
        """True once the current time has passed ``visible_at``."""
        return unix_time_millis() > self.visible_at

    @property
    def delayed(self):
        """True while the current time is before ``delayed_until``."""
        return unix_time_millis() < self.delayed_until


class Queue(CloudFormationModel):
    """An in-memory SQS queue.

    AWS attribute names (e.g. ``VisibilityTimeout``) are stored as snake_case
    instance attributes (``visibility_timeout``) via
    ``camelcase_to_underscores``.
    """

    BASE_ATTRIBUTES = [
        "ApproximateNumberOfMessages",
        "ApproximateNumberOfMessagesDelayed",
        "ApproximateNumberOfMessagesNotVisible",
        "CreatedTimestamp",
        "DelaySeconds",
        "LastModifiedTimestamp",
        "MaximumMessageSize",
        "MessageRetentionPeriod",
        "QueueArn",
        "Policy",
        "RedrivePolicy",
        "ReceiveMessageWaitTimeSeconds",
        "VisibilityTimeout",
    ]
    FIFO_ATTRIBUTES = [
        "ContentBasedDeduplication",
        "DeduplicationScope",
        "FifoQueue",
        "FifoThroughputLimit",
    ]
    KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"]
    ALLOWED_PERMISSIONS = (
        "*",
        "ChangeMessageVisibility",
        "DeleteMessage",
        "GetQueueAttributes",
        "GetQueueUrl",
        "ListDeadLetterSourceQueues",
        "PurgeQueue",
        "ReceiveMessage",
        "SendMessage",
    )

    # AWS attributes whose values are coerced to int / bool on assignment.
    _INTEGER_FIELDS = (
        "DelaySeconds",
        "KmsDataKeyReusePeriodSeconds",
        "MaximumMessageSize",
        "MessageRetentionPeriod",
        "ReceiveMessageWaitTimeSeconds",
        "VisibilityTimeout",
    )
    _BOOL_FIELDS = ("ContentBasedDeduplication", "FifoQueue")

    def __init__(self, name, region, **kwargs):
        """Create a queue; ``kwargs`` are AWS attribute names overriding the defaults.

        :raises InvalidParameterValue: a FIFO queue whose name lacks ``.fifo``.
        :raises InvalidAttributeValue: MaximumMessageSize out of range.
        """
        self.name = name
        self.region = region
        self.tags = {}
        self.permissions = {}

        self._messages = []
        self._pending_messages = set()

        now = unix_time()
        self.created_timestamp = now
        self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format(
            self.region, DEFAULT_ACCOUNT_ID, self.name
        )
        self.dead_letter_queue = None

        self.lambda_event_source_mappings = {}

        # default settings for a non fifo queue
        defaults = {
            "ContentBasedDeduplication": "false",
            "DeduplicationScope": "queue",
            "DelaySeconds": 0,
            "FifoQueue": "false",
            "FifoThroughputLimit": "perQueue",
            "KmsDataKeyReusePeriodSeconds": 300,  # five minutes
            "KmsMasterKeyId": None,
            "MaximumMessageSize": MAXIMUM_MESSAGE_LENGTH,
            "MessageRetentionPeriod": 86400 * 4,  # four days
            "Policy": None,
            "ReceiveMessageWaitTimeSeconds": 0,
            "RedrivePolicy": None,
            "VisibilityTimeout": 30,
        }

        defaults.update(kwargs)
        self._set_attributes(defaults, now)

        # Check some conditions
        if self.fifo_queue and not self.name.endswith(".fifo"):
            raise InvalidParameterValue("Queue name must end in .fifo for FIFO queues")
        if not (
            MAXIMUM_MESSAGE_SIZE_ATTR_LOWER_BOUND
            <= self.maximum_message_size
            <= MAXIMUM_MESSAGE_SIZE_ATTR_UPPER_BOUND
        ):
            raise InvalidAttributeValue("MaximumMessageSize")

    @property
    def pending_messages(self):
        """Set of messages that have been received but not yet deleted."""
        return self._pending_messages

    @property
    def pending_message_groups(self):
        """A new set of the group ids of all pending messages that have one."""
        return {
            message.group_id
            for message in self._pending_messages
            if message.group_id is not None
        }

    def _set_attributes(self, attributes, now=None):
        """Apply AWS-named ``attributes`` to this queue.

        Integer and boolean attributes are coerced; booleans accept both the
        ``"true"``/``"false"`` strings of the query API and real bools (e.g.
        from CloudFormation templates). ``RedrivePolicy`` and ``Policy`` are
        validated separately; an empty value clears them.
        """
        if not now:
            now = unix_time()

        for key, value in attributes.items():
            if key in ("Policy", "RedrivePolicy"):
                # Need parsing/validation; handled after the loop so that
                # attributes such as FifoQueue are already applied.
                continue
            if key in self._INTEGER_FIELDS:
                value = int(value)
            if key in self._BOOL_FIELDS:
                value = str(value).lower() == "true"
            setattr(self, camelcase_to_underscores(key), value)

        if "RedrivePolicy" in attributes:
            if attributes["RedrivePolicy"]:
                self._setup_dlq(attributes["RedrivePolicy"])
            else:
                self.redrive_policy = None
                self.dead_letter_queue = None

        if "Policy" in attributes:
            # The policy setter falls back to an empty default policy for
            # falsy values, which is how a policy is removed.
            self.policy = attributes["Policy"]

        self.last_modified_timestamp = now

    def _setup_dlq(self, policy):
        """Validate a redrive ``policy`` (dict or JSON string) and attach its DLQ.

        The queue's state is only updated once every check has passed.

        :raises RESTError: malformed policy, unknown DLQ, or a FIFO queue
            pointing at a non-FIFO DLQ.
        """
        if isinstance(policy, str):
            try:
                redrive_policy = json.loads(policy)
            except ValueError:
                raise RESTError(
                    "InvalidParameterValue",
                    "Redrive policy is not a dict or valid json",
                )
        else:
            redrive_policy = policy

        if not isinstance(redrive_policy, dict):
            raise RESTError(
                "InvalidParameterValue", "Redrive policy is not a dict or valid json"
            )

        if "deadLetterTargetArn" not in redrive_policy:
            raise RESTError(
                "InvalidParameterValue",
                "Redrive policy does not contain deadLetterTargetArn",
            )
        if "maxReceiveCount" not in redrive_policy:
            raise RESTError(
                "InvalidParameterValue",
                "Redrive policy does not contain maxReceiveCount",
            )

        # 'maxReceiveCount' is stored as int
        redrive_policy["maxReceiveCount"] = int(redrive_policy["maxReceiveCount"])

        target_arn = redrive_policy["deadLetterTargetArn"]
        dead_letter_queue = next(
            (
                queue
                for queue in sqs_backends[self.region].queues.values()
                if queue.queue_arn == target_arn
            ),
            None,
        )
        if dead_letter_queue is None:
            raise RESTError(
                "AWS.SimpleQueueService.NonExistentQueue",
                "Could not find DLQ for {0}".format(target_arn),
            )
        if self.fifo_queue and not dead_letter_queue.fifo_queue:
            raise RESTError(
                "InvalidParameterCombination",
                "Fifo queues cannot use non fifo dead letter queues",
            )

        self.redrive_policy = redrive_policy
        self.dead_letter_queue = dead_letter_queue

    @staticmethod
    def cloudformation_name_type():
        return "QueueName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-sqs-queue.html
        return "AWS::SQS::Queue"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create the queue described by a CloudFormation resource."""
        properties = deepcopy(cloudformation_json["Properties"])
        # remove Tags from properties and convert tags list to dict
        tags = properties.pop("Tags", [])
        tags_dict = tags_from_cloudformation_tags_list(tags)

        # Could be passed as an integer - just treat it as a string
        resource_name = str(resource_name)

        sqs_backend = sqs_backends[region_name]
        return sqs_backend.create_queue(
            name=resource_name, tags=tags_dict, region=region_name, **properties
        )

    @classmethod
    def update_from_cloudformation_json(
        cls, original_resource, new_resource_name, cloudformation_json, region_name
    ):
        """Apply the updatable timeout properties to the existing queue."""
        properties = cloudformation_json["Properties"]
        queue = sqs_backends[region_name].get_queue(original_resource.name)

        if "VisibilityTimeout" in properties:
            queue.visibility_timeout = int(properties["VisibilityTimeout"])

        if "ReceiveMessageWaitTimeSeconds" in properties:
            queue.receive_message_wait_time_seconds = int(
                properties["ReceiveMessageWaitTimeSeconds"]
            )
        return queue

    @classmethod
    def delete_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name
    ):
        sqs_backends[region_name].delete_queue(resource_name)

    @property
    def approximate_number_of_messages_delayed(self):
        return sum(1 for m in self._messages if m.delayed)

    @property
    def approximate_number_of_messages_not_visible(self):
        return sum(1 for m in self._messages if not m.visible)

    @property
    def approximate_number_of_messages(self):
        return len(self.messages)

    @property
    def physical_resource_id(self):
        return self.name

    @property
    def attributes(self):
        """AWS-named attribute dict; FIFO/KMS attributes only when relevant.

        Boolean values are rendered as ``"true"``/``"false"`` and the redrive
        policy as a JSON string.
        """
        names = list(self.BASE_ATTRIBUTES)
        if self.fifo_queue:
            names += self.FIFO_ATTRIBUTES
        if self.kms_master_key_id:
            names += self.KMS_ATTRIBUTES

        result = {name: getattr(self, camelcase_to_underscores(name)) for name in names}

        if self.policy:
            result["Policy"] = self.policy

        if self.redrive_policy:
            result["RedrivePolicy"] = json.dumps(self.redrive_policy)

        for key, value in result.items():
            if isinstance(value, bool):
                result[key] = str(value).lower()

        return result

    def url(self, request_url):
        return "{0}://{1}/{2}/{3}".format(
            request_url.scheme, request_url.netloc, DEFAULT_ACCOUNT_ID, self.name
        )

    @property
    def messages(self):
        """Messages that are currently deliverable (visible and not delayed)."""
        # TODO: This can become very inefficient if a large number of messages are in-flight
        return [
            message
            for message in self._messages
            if message.visible and not message.delayed
        ]

    def add_message(self, message):
        """Enqueue ``message`` and feed any Lambda event source mappings.

        On FIFO queues a message whose deduplication id matches one already
        queued within DEDUPLICATION_TIME_IN_SECONDS is dropped. This applies
        both to content-based ids and to an explicit MessageDeduplicationId.
        """
        if self.fifo_queue and (
            self.content_based_deduplication or message.deduplication_id
        ):
            for existing in self._messages:
                if existing.deduplication_id == message.deduplication_id:
                    elapsed_ms = message.sent_timestamp - existing.sent_timestamp
                    # if a duplicate message is received within the deduplication
                    # time then it should not be added to the queue
                    if elapsed_ms / 1000 < DEDUPLICATION_TIME_IN_SECONDS:
                        return

        self._messages.append(message)

        if not self.lambda_event_source_mappings:
            return

        from moto.awslambda import lambda_backends

        backend = sqs_backends[self.region]
        for arn, esm in self.lambda_event_source_mappings.items():
            # Lambda polls the queue and invokes your function synchronously with
            # an event that contains queue messages. Lambda reads messages in
            # batches and invokes your function once for each batch. When your
            # function successfully processes a batch, Lambda deletes its
            # messages from the queue.
            messages = backend.receive_messages(
                self.name,
                esm.batch_size,
                self.receive_message_wait_time_seconds,
                self.visibility_timeout,
            )

            result = lambda_backends[self.region].send_sqs_batch(
                arn, messages, self.queue_arn
            )

            for m in messages:
                if result:
                    backend.delete_message(self.name, m.receipt_handle)
                else:
                    # Make messages visible again
                    backend.change_message_visibility(
                        self.name, m.receipt_handle, visibility_timeout=0
                    )

    @classmethod
    def has_cfn_attr(cls, attribute_name):
        return attribute_name in ["Arn", "QueueName"]

    def get_cfn_attribute(self, attribute_name):
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Arn":
            return self.queue_arn
        elif attribute_name == "QueueName":
            return self.name
        raise UnformattedGetAttTemplateException()

    @property
    def policy(self):
        """The policy as a JSON string, or None when it has no statements."""
        if self._policy_json.get("Statement"):
            return json.dumps(self._policy_json)
        return None

    @policy.setter
    def policy(self, policy):
        """Parse a JSON policy; a falsy value resets to an empty default policy."""
        if policy:
            self._policy_json = json.loads(policy)
        else:
            self._policy_json = {
                "Version": "2012-10-17",
                "Id": "{}/SQSDefaultPolicy".format(self.queue_arn),
                "Statement": [],
            }


def _filter_message_attributes(message, input_message_attributes):
    """Keep only the requested message attributes on ``message`` ("All" keeps every one)."""
    return_all = "All" in input_message_attributes
    message.message_attributes = {
        key: value
        for key, value in message.message_attributes.items()
        if return_all or key in input_message_attributes
    }


class SQSBackend(BaseBackend):
    """Per-region store of SQS queues and the operations on them."""

    def __init__(self, region_name):
        self.region_name = region_name
        self.queues: Dict[str, Queue] = {}
        super(SQSBackend, self).__init__()

    def reset(self):
        region_name = self.region_name
        self._reset_model_refs()
        self.__dict__ = {}
        self.__init__(region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "sqs"
        )

    def create_queue(self, name, tags=None, **kwargs):
        """Create ``name`` or return the existing queue of that name.

        An existing queue is only returned when every attribute passed in
        ``kwargs`` matches its current value.

        :raises QueueAlreadyExists: an existing queue has different attributes.
        """
        # Queues always live in this backend's region, whatever the caller says.
        kwargs.pop("region", None)

        queue = self.queues.get(name)
        if queue:
            new_queue = Queue(name, region=self.region_name, **kwargs)

            queue_attributes = queue.attributes
            new_queue_attributes = new_queue.attributes

            # only the attributes which are being sent for the queue
            # creation have to be compared if the queue is existing.
            for key in kwargs:
                if queue_attributes.get(key) != new_queue_attributes.get(key):
                    raise QueueAlreadyExists("The specified queue already exists.")
        else:
            queue = Queue(name, region=self.region_name, **kwargs)
            self.queues[name] = queue

        if tags:
            queue.tags = tags

        return queue

    def get_queue_url(self, queue_name):
        return self.get_queue(queue_name)

    def list_queues(self, queue_name_prefix):
        """Return up to 1000 queues whose name starts with ``queue_name_prefix``.

        The prefix is matched literally (not as a regex), so characters such
        as ``.`` in queue names only match themselves.
        """
        prefix = queue_name_prefix or ""
        matching = [q for name, q in self.queues.items() if name.startswith(prefix)]
        return matching[:1000]

    def get_queue(self, queue_name):
        """Return the named queue or raise QueueDoesNotExist."""
        queue = self.queues.get(queue_name)
        if queue is None:
            raise QueueDoesNotExist()
        return queue

    def delete_queue(self, queue_name):
        self.get_queue(queue_name)
        del self.queues[queue_name]

    def get_queue_attributes(self, queue_name, attribute_names):
        """Return the requested attributes (``"All"`` returns every one).

        Attributes whose value is None are omitted.

        :raises InvalidAttributeName: for an unknown or empty name.
        """
        queue = self.get_queue(queue_name)
        if not attribute_names:
            return {}

        valid_names = (
            ["All"]
            + queue.BASE_ATTRIBUTES
            + queue.FIFO_ATTRIBUTES
            + queue.KMS_ATTRIBUTES
        )
        invalid_name = next(
            (name for name in attribute_names if name not in valid_names), None
        )

        if invalid_name or invalid_name == "":
            raise InvalidAttributeName(invalid_name)

        # Building the attribute dict scans every message; do it only once.
        queue_attributes = queue.attributes
        if "All" in attribute_names:
            return queue_attributes

        return {
            name: queue_attributes[name]
            for name in attribute_names
            if queue_attributes.get(name) is not None
        }

    def set_queue_attributes(self, queue_name, attributes):
        queue = self.get_queue(queue_name)
        queue._set_attributes(attributes)
        return queue

    def send_message(
        self,
        queue_name,
        message_body,
        message_attributes=None,
        delay_seconds=None,
        deduplication_id=None,
        group_id=None,
        system_attributes=None,
    ):
        """Create, stamp and enqueue a message; return the Message.

        :raises InvalidParameterValue: body longer than MaximumMessageSize.
        :raises MissingParameter: FIFO queue without a MessageGroupId.
        """
        queue = self.get_queue(queue_name)

        if len(message_body) > queue.maximum_message_size:
            msg = "One or more parameters are invalid. Reason: Message must be shorter than {} bytes.".format(
                queue.maximum_message_size
            )
            raise InvalidParameterValue(msg)

        if delay_seconds:
            delay_seconds = int(delay_seconds)
        else:
            delay_seconds = queue.delay_seconds

        message_id = get_random_message_id()
        message = Message(message_id, message_body, system_attributes)

        # if content based deduplication is set then set sha256 hash of the message
        # as the deduplication_id
        if queue.fifo_queue and queue.content_based_deduplication:
            message.deduplication_id = hashlib.sha256(
                message_body.encode("utf-8")
            ).hexdigest()

        # Attributes, but not *message* attributes
        if deduplication_id is not None:
            message.deduplication_id = deduplication_id
            message.sequence_number = "".join(
                random.choice(string.digits) for _ in range(20)
            )

        if group_id is None:
            # MessageGroupId is a mandatory parameter for all
            # messages in a fifo queue
            if queue.fifo_queue:
                raise MissingParameter("MessageGroupId")
        else:
            message.group_id = group_id

        if message_attributes:
            message.message_attributes = message_attributes

        message.mark_sent(delay_seconds=delay_seconds)

        queue.add_message(message)

        return message

    def send_message_batch(self, queue_name, entries):
        """Validate a batch of entries and send each one; return the Messages.

        Each returned message carries the entry's ``Id`` as ``user_id``.
        """
        self.get_queue(queue_name)

        if any(
            not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values()
        ):
            raise InvalidBatchEntryId()

        body_length = next(
            (
                len(entry["MessageBody"])
                for entry in entries.values()
                if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH
            ),
            False,
        )
        if body_length:
            raise BatchRequestTooLong(body_length)

        duplicate_id = self._get_first_duplicate_id(
            [entry["Id"] for entry in entries.values()]
        )
        if duplicate_id:
            raise BatchEntryIdsNotDistinct(duplicate_id)

        if len(entries) > 10:
            raise TooManyEntriesInBatchRequest(len(entries))

        messages = []
        for entry in entries.values():
            message = self.send_message(
                queue_name,
                entry["MessageBody"],
                message_attributes=entry["MessageAttributes"],
                delay_seconds=entry["DelaySeconds"],
                group_id=entry.get("MessageGroupId"),
                deduplication_id=entry.get("MessageDeduplicationId"),
            )
            message.user_id = entry["Id"]
            messages.append(message)

        return messages

    def _get_first_duplicate_id(self, ids):
        """Return the first id that occurs twice in ``ids``, or None."""
        seen = set()
        for entry_id in ids:
            if entry_id in seen:
                return entry_id
            seen.add(entry_id)
        return None

    def receive_messages(
        self,
        queue_name,
        count,
        wait_seconds_timeout,
        visibility_timeout,
        message_attribute_names=None,
    ):
        """
        Attempt to retrieve visible messages from a queue.

        If a message was read by client and not deleted it is considered to be
        "inflight" and cannot be read. We make attempts to obtain ``count``
        messages but we may return less if messages are in-flight or there
        are simply not enough messages in the queue.

        Messages that reached the redrive policy's maxReceiveCount are moved
        to the dead letter queue. Messages older than the queue's
        MessageRetentionPeriod are removed from the queue and never delivered.

        :param string queue_name: The name of the queue to read from.
        :param int count: The maximum amount of messages to retrieve.
        :param int visibility_timeout: The number of seconds the message should remain invisible to other queue readers.
        :param int wait_seconds_timeout: The duration (in seconds) for which the call waits for a message to arrive in
            the queue before returning. If a message is available, the call returns sooner than WaitTimeSeconds
        :return: copies of the delivered messages, filtered to the requested attributes.
        """
        if message_attribute_names is None:
            message_attribute_names = []
        queue = self.get_queue(queue_name)
        result = []
        previous_result_count = len(result)

        polling_end = unix_time() + wait_seconds_timeout
        # pending_message_groups builds a new set, so this is a snapshot.
        currently_pending_groups = queue.pending_message_groups

        # queue.messages only contains visible messages
        while True:
            if result or (wait_seconds_timeout and unix_time() > polling_end):
                break

            messages_to_dlq = []
            expired_messages = []

            for message in queue.messages:
                if not message.visible:
                    continue

                if message in queue.pending_messages:
                    # The message is pending but is visible again, so the
                    # consumer must have timed out.
                    queue.pending_messages.remove(message)
                    currently_pending_groups = queue.pending_message_groups

                if message.group_id and queue.fifo_queue:
                    if message.group_id in currently_pending_groups:
                        # A previous call is still processing messages in this group,
                        # so we cannot deliver this one.
                        continue

                if (
                    queue.dead_letter_queue is not None
                    and queue.redrive_policy
                    and message.approximate_receive_count
                    >= queue.redrive_policy["maxReceiveCount"]
                ):
                    messages_to_dlq.append(message)
                    continue

                # Check retention before touching the message so an expired
                # one is neither marked received nor left pending, and does
                # not stop newer messages from being delivered.
                if not self.is_message_valid_based_on_retention_period(
                    queue_name, message
                ):
                    expired_messages.append(message)
                    continue

                queue.pending_messages.add(message)
                message.mark_received(visibility_timeout=visibility_timeout)
                # Create deepcopy to not mutate the message state when filtering for attributes
                message_copy = deepcopy(message)
                _filter_message_attributes(message_copy, message_attribute_names)
                result.append(message_copy)
                if len(result) >= count:
                    break

            for message in messages_to_dlq:
                queue._messages.remove(message)
                queue.dead_letter_queue.add_message(message)

            for message in expired_messages:
                queue._messages.remove(message)
                queue.pending_messages.discard(message)

            if previous_result_count == len(result):
                if wait_seconds_timeout == 0:
                    # There is no timeout and we have added no additional results,
                    # so break to avoid an infinite loop.
                    break

                time.sleep(0.01)
                continue

            previous_result_count = len(result)

        return result

    def delete_message(self, queue_name, receipt_handle):
        """Delete the message(s) carrying ``receipt_handle``, pending or not.

        :raises ReceiptHandleIsInvalid: no message has that receipt handle.
        """
        queue = self.get_queue(queue_name)

        deleted = [m for m in queue._messages if m.receipt_handle == receipt_handle]
        if not deleted:
            raise ReceiptHandleIsInvalid()

        for message in deleted:
            queue.pending_messages.discard(message)
        queue._messages = [
            m for m in queue._messages if m.receipt_handle != receipt_handle
        ]

    def change_message_visibility(self, queue_name, receipt_handle, visibility_timeout):
        """Reset the visibility timeout of an in-flight message.

        :raises MessageNotInflight: the message is currently visible.
        :raises InvalidParameterValue: total visibility would exceed 43200s
            after the message was sent.
        :raises ReceiptHandleIsInvalid: no message has that receipt handle.
        """
        queue = self.get_queue(queue_name)
        for message in queue._messages:
            if message.receipt_handle != receipt_handle:
                continue

            if message.visible:
                raise MessageNotInflight

            visibility_timeout_msec = int(visibility_timeout) * 1000
            given_visibility_timeout = unix_time_millis() + visibility_timeout_msec
            if given_visibility_timeout - message.sent_timestamp > 43200 * 1000:
                raise InvalidParameterValue(
                    "Value {0} for parameter VisibilityTimeout is invalid. Reason: Total "
                    "VisibilityTimeout for the message is beyond the limit [43200 seconds]".format(
                        visibility_timeout
                    )
                )

            message.change_visibility(visibility_timeout)
            if message.visible:
                # If the message is visible again, remove it from pending
                # messages (it may already be gone, e.g. after a purge).
                queue.pending_messages.discard(message)
            return
        raise ReceiptHandleIsInvalid

    def purge_queue(self, queue_name):
        queue = self.get_queue(queue_name)
        queue._messages = []
        queue._pending_messages = set()

    def list_dead_letter_source_queues(self, queue_name):
        """Return the queues whose dead letter queue is ``queue_name``."""
        dlq = self.get_queue(queue_name)
        return [queue for queue in self.queues.values() if queue.dead_letter_queue is dlq]

    def add_permission(self, queue_name, actions, account_ids, label):
        """Append an Allow statement labelled ``label`` to the queue policy."""
        queue = self.get_queue(queue_name)

        if not actions:
            raise MissingParameter("Actions")

        if not account_ids:
            raise InvalidParameterValue(
                "Value [] for parameter PrincipalId is invalid. Reason: Unable to verify."
            )

        count = len(actions)
        if count > 7:
            raise OverLimit(count)

        invalid_action = next(
            (action for action in actions if action not in Queue.ALLOWED_PERMISSIONS),
            None,
        )
        if invalid_action:
            raise InvalidParameterValue(
                "Value SQS:{} for parameter ActionName is invalid. "
                "Reason: Only the queue owner is allowed to invoke this action.".format(
                    invalid_action
                )
            )

        # Statements from a user-supplied policy need not carry a Sid.
        if any(
            statement.get("Sid") == label
            for statement in queue._policy_json["Statement"]
        ):
            raise InvalidParameterValue(
                "Value {} for parameter Label is invalid. "
                "Reason: Already exists.".format(label)
            )

        principals = [
            "arn:aws:iam::{}:root".format(account_id) for account_id in account_ids
        ]
        actions = ["SQS:{}".format(action) for action in actions]

        statement = {
            "Sid": label,
            "Effect": "Allow",
            "Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
            "Action": actions[0] if len(actions) == 1 else actions,
            "Resource": queue.queue_arn,
        }

        queue._policy_json["Statement"].append(statement)

    def remove_permission(self, queue_name, label):
        """Remove every policy statement whose Sid is ``label``."""
        queue = self.get_queue(queue_name)

        statements = queue._policy_json["Statement"]
        statements_new = [
            statement for statement in statements if statement.get("Sid") != label
        ]

        if len(statements) == len(statements_new):
            raise InvalidParameterValue(
                "Value {} for parameter Label is invalid. "
                "Reason: can't find label on existing policy.".format(label)
            )

        queue._policy_json["Statement"] = statements_new

    def tag_queue(self, queue_name, tags):
        """Merge ``tags`` into the queue's tags (at most 50 per call)."""
        queue = self.get_queue(queue_name)

        if not tags:
            raise MissingParameter("Tags")

        if len(tags) > 50:
            raise InvalidParameterValue(
                "Too many tags added for queue {}.".format(queue_name)
            )

        queue.tags.update(tags)

    def untag_queue(self, queue_name, tag_keys):
        """Remove ``tag_keys`` from the queue; unknown keys are ignored."""
        queue = self.get_queue(queue_name)

        if not tag_keys:
            raise RESTError(
                "InvalidParameterValue",
                "Tag keys must be between 1 and 128 characters in length.",
            )

        for key in tag_keys:
            queue.tags.pop(key, None)

    def list_queue_tags(self, queue_name):
        return self.get_queue(queue_name)

    def is_message_valid_based_on_retention_period(self, queue_name, message):
        """True while ``message`` is younger than the queue's MessageRetentionPeriod."""
        queue = self.get_queue(queue_name)
        retain_until = queue.message_retention_period + message.sent_timestamp / 1000
        return retain_until > unix_time()


sqs_backends = {}
_session = Session()
for _partition in ("aws", "aws-us-gov", "aws-cn"):
    for region in _session.get_available_regions("sqs", partition_name=_partition):
        sqs_backends[region] = SQSBackend(region)