import copy
import datetime

from boto3 import Session

from collections import OrderedDict
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.utilities.utils import random_string
from moto.ec2 import ec2_backends
from .exceptions import (
    ClusterAlreadyExistsFaultError,
    ClusterNotFoundError,
    ClusterParameterGroupNotFoundError,
    ClusterSecurityGroupNotFoundError,
    ClusterSnapshotAlreadyExistsError,
    ClusterSnapshotNotFoundError,
    ClusterSubnetGroupNotFoundError,
    InvalidParameterCombinationError,
    InvalidParameterValueError,
    InvalidSubnetError,
    ResourceNotFoundFaultError,
    SnapshotCopyAlreadyDisabledFaultError,
    SnapshotCopyAlreadyEnabledFaultError,
    SnapshotCopyDisabledFaultError,
    SnapshotCopyGrantAlreadyExistsFaultError,
    SnapshotCopyGrantNotFoundFaultError,
    UnknownSnapshotCopyRegionFaultError,
    ClusterSecurityGroupNotFoundFaultError,
)


from moto.core import ACCOUNT_ID


class TaggableResourceMixin(object):
    """Mixin giving a resource a region, a tag list and a Redshift ARN.

    Subclasses set ``resource_type`` and override ``resource_id``; both are
    interpolated into ``arn``.
    """

    resource_type = None

    def __init__(self, region_name, tags):
        self.region = region_name
        # Normalise a missing tag list so the tag helpers can always iterate.
        self.tags = tags or []

    @property
    def resource_id(self):
        """Identifier used in the ARN; the base implementation returns None."""
        return None

    @property
    def arn(self):
        """Return the ARN built from region, account, resource type and id."""
        return "arn:aws:redshift:{region}:{account_id}:{resource_type}:{resource_id}".format(
            region=self.region,
            account_id=ACCOUNT_ID,
            resource_type=self.resource_type,
            resource_id=self.resource_id,
        )

    def create_tags(self, tags):
        """Add ``tags``, replacing existing tags that share a key.

        :param tags: list of ``{"Key": ..., "Value": ...}`` dicts.
        :returns: the updated tag list.
        """
        new_keys = [tag_set["Key"] for tag_set in tags]
        # Drop any existing tag whose key is being re-supplied, then append.
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
        self.tags.extend(tags)
        return self.tags

    def delete_tags(self, tag_keys):
        """Remove every tag whose key is in ``tag_keys``; return what remains."""
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
        return self.tags


class Cluster(TaggableResourceMixin, CloudFormationModel):
    """In-memory model of a Redshift cluster."""

    resource_type = "cluster"

    def __init__(
        self,
        redshift_backend,
        cluster_identifier,
        node_type,
        master_username,
        master_user_password,
        db_name,
        cluster_type,
        cluster_security_groups,
        vpc_security_group_ids,
        cluster_subnet_group_name,
        availability_zone,
        preferred_maintenance_window,
        cluster_parameter_group_name,
        automated_snapshot_retention_period,
        port,
        cluster_version,
        allow_version_upgrade,
        number_of_nodes,
        publicly_accessible,
        encrypted,
        region_name,
        tags=None,
        iam_roles_arn=None,
        enhanced_vpc_routing=None,
        restored_from_snapshot=False,
        kms_key_id=None,
    ):
        """Store the cluster settings, substituting defaults for falsy values.

        Defaults applied here: db name ``"dev"``, version ``"1.0"``, port
        ``5439``, retention period ``1``, maintenance window
        ``"Mon:03:00-Mon:03:30"``, parameter group ``"default.redshift-1.0"``,
        security group ``"Default"``, availability zone ``region_name + "a"``
        and a single node.
        """
        super(Cluster, self).__init__(region_name, tags)
        self.redshift_backend = redshift_backend
        self.cluster_identifier = cluster_identifier
        self.create_time = iso_8601_datetime_with_milliseconds(
            datetime.datetime.utcnow()
        )
        self.status = "available"
        self.node_type = node_type
        self.master_username = master_username
        self.master_user_password = master_user_password
        self.db_name = db_name if db_name else "dev"
        self.vpc_security_group_ids = vpc_security_group_ids
        self.enhanced_vpc_routing = (
            enhanced_vpc_routing if enhanced_vpc_routing is not None else False
        )
        self.cluster_subnet_group_name = cluster_subnet_group_name
        self.publicly_accessible = publicly_accessible
        self.encrypted = encrypted

        self.allow_version_upgrade = (
            allow_version_upgrade if allow_version_upgrade is not None else True
        )
        self.cluster_version = cluster_version if cluster_version else "1.0"
        self.port = int(port) if port else 5439
        self.automated_snapshot_retention_period = (
            int(automated_snapshot_retention_period)
            if automated_snapshot_retention_period
            else 1
        )
        self.preferred_maintenance_window = (
            preferred_maintenance_window
            if preferred_maintenance_window
            else "Mon:03:00-Mon:03:30"
        )

        # Stored as a one-element list; ``parameter_groups`` does a membership
        # test against it.
        if cluster_parameter_group_name:
            self.cluster_parameter_group_name = [cluster_parameter_group_name]
        else:
            self.cluster_parameter_group_name = ["default.redshift-1.0"]

        if cluster_security_groups:
            self.cluster_security_groups = cluster_security_groups
        else:
            self.cluster_security_groups = ["Default"]

        if availability_zone:
            self.availability_zone = availability_zone
        else:
            # This could probably be smarter, but there doesn't appear to be a
            # way to pull AZs for a region in boto
            self.availability_zone = region_name + "a"

        # A single-node cluster always has exactly one node, whatever count
        # was requested.
        if cluster_type == "single-node":
            self.number_of_nodes = 1
        elif number_of_nodes:
            self.number_of_nodes = int(number_of_nodes)
        else:
            self.number_of_nodes = 1

        self.iam_roles_arn = iam_roles_arn or []
        self.restored_from_snapshot = restored_from_snapshot
        self.kms_key_id = kms_key_id

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-cluster.html
        return "AWS::Redshift::Cluster"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create a cluster in the region's backend from a template resource.

        The cluster identifier is ``resource_name``; all other settings are
        read from ``cloudformation_json["Properties"]``.
        """
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        if "ClusterSubnetGroupName" in properties:
            # The property value is used as an object exposing
            # ``cluster_subnet_group_name`` rather than as a plain string.
            subnet_group_name = properties[
                "ClusterSubnetGroupName"
            ].cluster_subnet_group_name
        else:
            subnet_group_name = None

        cluster = redshift_backend.create_cluster(
            cluster_identifier=resource_name,
            node_type=properties.get("NodeType"),
            master_username=properties.get("MasterUsername"),
            master_user_password=properties.get("MasterUserPassword"),
            db_name=properties.get("DBName"),
            cluster_type=properties.get("ClusterType"),
            cluster_security_groups=properties.get("ClusterSecurityGroups", []),
            vpc_security_group_ids=properties.get("VpcSecurityGroupIds", []),
            cluster_subnet_group_name=subnet_group_name,
            availability_zone=properties.get("AvailabilityZone"),
            preferred_maintenance_window=properties.get("PreferredMaintenanceWindow"),
            cluster_parameter_group_name=properties.get("ClusterParameterGroupName"),
            automated_snapshot_retention_period=properties.get(
                "AutomatedSnapshotRetentionPeriod"
            ),
            port=properties.get("Port"),
            cluster_version=properties.get("ClusterVersion"),
            allow_version_upgrade=properties.get("AllowVersionUpgrade"),
            enhanced_vpc_routing=properties.get("EnhancedVpcRouting"),
            number_of_nodes=properties.get("NumberOfNodes"),
            publicly_accessible=properties.get("PubliclyAccessible"),
            encrypted=properties.get("Encrypted"),
            region_name=region_name,
            kms_key_id=properties.get("KmsKeyId"),
        )
        return cluster

    @classmethod
    def has_cfn_attr(cls, attribute):
        return attribute in ["Endpoint.Address", "Endpoint.Port"]

    def get_cfn_attribute(self, attribute_name):
        """Return the value for a supported ``Fn::GetAtt`` attribute.

        :raises UnformattedGetAttTemplateException: for any attribute other
            than ``Endpoint.Address`` and ``Endpoint.Port``.
        """
        # Imported here rather than at module level, as in the original.
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Endpoint.Address":
            return self.endpoint
        elif attribute_name == "Endpoint.Port":
            return self.port
        raise UnformattedGetAttTemplateException()

    @property
    def endpoint(self):
        """Hostname string built from the cluster identifier and region."""
        return "{0}.cg034hpkmmjt.{1}.redshift.amazonaws.com".format(
            self.cluster_identifier, self.region
        )

    @property
    def security_groups(self):
        """Backend cluster security groups whose names this cluster lists."""
        return [
            security_group
            for security_group in self.redshift_backend.describe_cluster_security_groups()
            if security_group.cluster_security_group_name
            in self.cluster_security_groups
        ]

    @property
    def vpc_security_groups(self):
        """EC2 security groups whose ids this cluster lists."""
        return [
            security_group
            for security_group in self.redshift_backend.ec2_backend.describe_security_groups()
            if security_group.id in self.vpc_security_group_ids
        ]

    @property
    def parameter_groups(self):
        """Backend parameter groups whose names this cluster lists."""
        return [
            parameter_group
            for parameter_group in self.redshift_backend.describe_cluster_parameter_groups()
            if parameter_group.cluster_parameter_group_name
            in self.cluster_parameter_group_name
        ]

    @property
    def resource_id(self):
        return self.cluster_identifier

    def to_json(self):
        """Return the cluster as a dict keyed by Redshift API field names.

        The master password is always masked. ``RestoreStatus`` is included
        only for clusters restored from a snapshot, and
        ``ClusterSnapshotCopyStatus`` only while snapshot copy is enabled.
        """
        json_response = {
            "MasterUsername": self.master_username,
            "MasterUserPassword": "****",
            "ClusterVersion": self.cluster_version,
            "VpcSecurityGroups": [
                {"Status": "active", "VpcSecurityGroupId": group.id}
                for group in self.vpc_security_groups
            ],
            "ClusterSubnetGroupName": self.cluster_subnet_group_name,
            "AvailabilityZone": self.availability_zone,
            "ClusterStatus": self.status,
            "NumberOfNodes": self.number_of_nodes,
            "AutomatedSnapshotRetentionPeriod": self.automated_snapshot_retention_period,
            "PubliclyAccessible": self.publicly_accessible,
            "Encrypted": self.encrypted,
            "DBName": self.db_name,
            "PreferredMaintenanceWindow": self.preferred_maintenance_window,
            "ClusterParameterGroups": [
                {
                    "ParameterApplyStatus": "in-sync",
                    "ParameterGroupName": group.cluster_parameter_group_name,
                }
                for group in self.parameter_groups
            ],
            "ClusterSecurityGroups": [
                {
                    "Status": "active",
                    "ClusterSecurityGroupName": group.cluster_security_group_name,
                }
                for group in self.security_groups
            ],
            "Port": self.port,
            "NodeType": self.node_type,
            "ClusterIdentifier": self.cluster_identifier,
            "AllowVersionUpgrade": self.allow_version_upgrade,
            "Endpoint": {"Address": self.endpoint, "Port": self.port},
            "ClusterCreateTime": self.create_time,
            "PendingModifiedValues": [],
            "Tags": self.tags,
            "EnhancedVpcRouting": self.enhanced_vpc_routing,
            "IamRoles": [
                {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn}
                for iam_role_arn in self.iam_roles_arn
            ],
            "KmsKeyId": self.kms_key_id,
        }
        if self.restored_from_snapshot:
            json_response["RestoreStatus"] = {
                "Status": "completed",
                "CurrentRestoreRateInMegaBytesPerSecond": 123.0,
                "SnapshotSizeInMegaBytes": 123,
                "ProgressInMegaBytes": 123,
                "ElapsedTimeInSeconds": 123,
                "EstimatedTimeToCompletionInSeconds": 123,
            }
        # The attribute only exists while snapshot copy is enabled (the
        # backend sets and deletes it), so its absence is expected.
        try:
            json_response[
                "ClusterSnapshotCopyStatus"
            ] = self.cluster_snapshot_copy_status
        except AttributeError:
            pass
        return json_response


class SnapshotCopyGrant(TaggableResourceMixin, BaseModel):
    """In-memory model of a snapshot copy grant."""

    resource_type = "snapshotcopygrant"

    def __init__(self, snapshot_copy_grant_name, kms_key_id, region_name=None, tags=None):
        """Store the grant name and KMS key.

        ``region_name`` and ``tags`` are optional so existing two-argument
        callers keep working; they initialise the mixin state that ``arn``
        and the tag helpers read.
        """
        super(SnapshotCopyGrant, self).__init__(region_name, tags)
        self.snapshot_copy_grant_name = snapshot_copy_grant_name
        self.kms_key_id = kms_key_id

    @property
    def resource_id(self):
        return self.snapshot_copy_grant_name

    def to_json(self):
        return {
            "SnapshotCopyGrantName": self.snapshot_copy_grant_name,
            "KmsKeyId": self.kms_key_id,
        }


class SubnetGroup(TaggableResourceMixin, CloudFormationModel):
    """In-memory model of a cluster subnet group."""

    resource_type = "subnetgroup"

    def __init__(
        self,
        ec2_backend,
        cluster_subnet_group_name,
        description,
        subnet_ids,
        region_name,
        tags=None,
    ):
        """Store the group settings.

        :raises InvalidSubnetError: if the EC2 backend returns no subnets for
            ``subnet_ids``.
        """
        super(SubnetGroup, self).__init__(region_name, tags)
        self.ec2_backend = ec2_backend
        self.cluster_subnet_group_name = cluster_subnet_group_name
        self.description = description
        self.subnet_ids = subnet_ids
        if not self.subnets:
            raise InvalidSubnetError(subnet_ids)

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clustersubnetgroup.html
        return "AWS::Redshift::ClusterSubnetGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create a subnet group in the region's backend from a template resource."""
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        subnet_group = redshift_backend.create_cluster_subnet_group(
            cluster_subnet_group_name=resource_name,
            description=properties.get("Description"),
            subnet_ids=properties.get("SubnetIds", []),
            region_name=region_name,
        )
        return subnet_group

    @property
    def subnets(self):
        """Subnets the EC2 backend returns for this group's subnet ids."""
        return self.ec2_backend.get_all_subnets(filters={"subnet-id": self.subnet_ids})

    @property
    def vpc_id(self):
        """VPC id of the first subnet in the group."""
        return self.subnets[0].vpc_id

    @property
    def resource_id(self):
        return self.cluster_subnet_group_name

    def to_json(self):
        return {
            "VpcId": self.vpc_id,
            "Description": self.description,
            "ClusterSubnetGroupName": self.cluster_subnet_group_name,
            "SubnetGroupStatus": "Complete",
            "Subnets": [
                {
                    "SubnetStatus": "Active",
                    "SubnetIdentifier": subnet.id,
                    "SubnetAvailabilityZone": {"Name": subnet.availability_zone},
                }
                for subnet in self.subnets
            ],
            "Tags": self.tags,
        }


class SecurityGroup(TaggableResourceMixin, BaseModel):
    """In-memory model of a cluster security group."""

    resource_type = "securitygroup"

    def __init__(
        self, cluster_security_group_name, description, region_name, tags=None
    ):
        super(SecurityGroup, self).__init__(region_name, tags)
        self.cluster_security_group_name = cluster_security_group_name
        self.description = description
        # CIDR strings appended by the backend's authorize call.
        self.ingress_rules = []

    @property
    def resource_id(self):
        return self.cluster_security_group_name

    def to_json(self):
        return {
            "EC2SecurityGroups": [],
            "IPRanges": [],
            "Description": self.description,
            "ClusterSecurityGroupName": self.cluster_security_group_name,
            "Tags": self.tags,
        }


class ParameterGroup(TaggableResourceMixin, CloudFormationModel):
    """In-memory model of a cluster parameter group."""

    resource_type = "parametergroup"

    def __init__(
        self,
        cluster_parameter_group_name,
        group_family,
        description,
        region_name,
        tags=None,
    ):
        super(ParameterGroup, self).__init__(region_name, tags)
        self.cluster_parameter_group_name = cluster_parameter_group_name
        self.group_family = group_family
        self.description = description

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clusterparametergroup.html
        return "AWS::Redshift::ClusterParameterGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create a parameter group in the region's backend from a template resource."""
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        parameter_group = redshift_backend.create_cluster_parameter_group(
            cluster_parameter_group_name=resource_name,
            description=properties.get("Description"),
            group_family=properties.get("ParameterGroupFamily"),
            region_name=region_name,
        )
        return parameter_group

    @property
    def resource_id(self):
        return self.cluster_parameter_group_name

    def to_json(self):
        return {
            "ParameterGroupFamily": self.group_family,
            "Description": self.description,
            "ParameterGroupName": self.cluster_parameter_group_name,
            "Tags": self.tags,
        }


class Snapshot(TaggableResourceMixin, BaseModel):
    """In-memory model of a manual cluster snapshot."""

    resource_type = "snapshot"

    def __init__(
        self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None
    ):
        super(Snapshot, self).__init__(region_name, tags)
        # Shallow copy, so later reassignment of the live cluster's attributes
        # does not change what the snapshot reports.
        self.cluster = copy.copy(cluster)
        self.snapshot_identifier = snapshot_identifier
        self.snapshot_type = "manual"
        self.status = "available"
        # UTC, matching how Cluster stamps create_time with the same helper.
        self.create_time = iso_8601_datetime_with_milliseconds(
            datetime.datetime.utcnow()
        )
        self.iam_roles_arn = iam_roles_arn or []

    @property
    def resource_id(self):
        """Return ``"<cluster identifier>/<snapshot identifier>"``."""
        return "{cluster_id}/{snapshot_id}".format(
            cluster_id=self.cluster.cluster_identifier,
            snapshot_id=self.snapshot_identifier,
        )

    def to_json(self):
        return {
            "SnapshotIdentifier": self.snapshot_identifier,
            "ClusterIdentifier": self.cluster.cluster_identifier,
            "SnapshotCreateTime": self.create_time,
            "Status": self.status,
            "Port": self.cluster.port,
            "AvailabilityZone": self.cluster.availability_zone,
            "MasterUsername": self.cluster.master_username,
            "ClusterVersion": self.cluster.cluster_version,
            "SnapshotType": self.snapshot_type,
            "NodeType": self.cluster.node_type,
            "NumberOfNodes": self.cluster.number_of_nodes,
            "DBName": self.cluster.db_name,
            "Tags": self.tags,
            "EnhancedVpcRouting": self.cluster.enhanced_vpc_routing,
            "IamRoles": [
                {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn}
                for iam_role_arn in self.iam_roles_arn
            ],
        }


class RedshiftBackend(BaseBackend):
    """Per-region in-memory store and operations for Redshift resources."""

    def __init__(self, ec2_backend, region_name):
        self.region = region_name
        self.clusters = {}
        self.subnet_groups = {}
        self.security_groups = {
            "Default": SecurityGroup(
                "Default", "Default Redshift Security Group", self.region
            )
        }
        self.parameter_groups = {
            "default.redshift-1.0": ParameterGroup(
                "default.redshift-1.0",
                "redshift-1.0",
                "Default Redshift parameter group",
                self.region,
            )
        }
        self.ec2_backend = ec2_backend
        self.snapshots = OrderedDict()
        # Maps the resource-type segment of an ARN to the dict that stores
        # resources of that type; used by the tagging operations.
        self.RESOURCE_TYPE_MAP = {
            "cluster": self.clusters,
            "parametergroup": self.parameter_groups,
            "securitygroup": self.security_groups,
            "snapshot": self.snapshots,
            "subnetgroup": self.subnet_groups,
        }
        self.snapshot_copy_grants = {}

    def reset(self):
        """Discard all state and re-run ``__init__`` for the same region."""
        ec2_backend = self.ec2_backend
        region_name = self.region
        self.__dict__ = {}
        self.__init__(ec2_backend, region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "redshift"
        ) + BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "redshift-data", policy_supported=False
        )

    def _get_cluster(self, cluster_identifier):
        """Return the stored cluster or raise ClusterNotFoundError."""
        try:
            return self.clusters[cluster_identifier]
        except KeyError:
            raise ClusterNotFoundError(cluster_identifier)

    def enable_snapshot_copy(self, **kwargs):
        """Enable cross-region snapshot copy on a cluster.

        Reads ``cluster_identifier``, ``destination_region``,
        ``retention_period`` and ``snapshot_copy_grant_name`` from ``kwargs``.

        :raises ClusterNotFoundError: if the cluster does not exist.
        :raises InvalidParameterValueError: if ``cluster.encrypted == "true"``
            and no grant name was supplied.
        :raises UnknownSnapshotCopyRegionFaultError: if the destination is
            this backend's own region.
        :raises SnapshotCopyAlreadyEnabledFaultError: if already enabled.
        """
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self._get_cluster(cluster_identifier)
        if not hasattr(cluster, "cluster_snapshot_copy_status"):
            if (
                cluster.encrypted == "true"
                and kwargs["snapshot_copy_grant_name"] is None
            ):
                raise InvalidParameterValueError(
                    "SnapshotCopyGrantName is required for Snapshot Copy on KMS encrypted clusters."
                )
            if kwargs["destination_region"] == self.region:
                raise UnknownSnapshotCopyRegionFaultError(
                    "Invalid region {}".format(self.region)
                )
            status = {
                "DestinationRegion": kwargs["destination_region"],
                "RetentionPeriod": kwargs["retention_period"],
                "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"],
            }
            # Presence of this attribute is what marks snapshot copy as enabled.
            cluster.cluster_snapshot_copy_status = status
            return cluster
        else:
            raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier)

    def disable_snapshot_copy(self, **kwargs):
        """Disable snapshot copy on a cluster.

        :raises ClusterNotFoundError: if the cluster does not exist.
        :raises SnapshotCopyAlreadyDisabledFaultError: if not enabled.
        """
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self._get_cluster(cluster_identifier)
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            del cluster.cluster_snapshot_copy_status
            return cluster
        else:
            raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier)

    def modify_snapshot_copy_retention_period(
        self, cluster_identifier, retention_period
    ):
        """Update the retention period of an enabled snapshot copy.

        :raises ClusterNotFoundError: if the cluster does not exist.
        :raises SnapshotCopyDisabledFaultError: if snapshot copy is not enabled.
        """
        cluster = self._get_cluster(cluster_identifier)
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period
            return cluster
        else:
            raise SnapshotCopyDisabledFaultError(cluster_identifier)

    def create_cluster(self, **cluster_kwargs):
        """Create and store a cluster built from ``cluster_kwargs``.

        :raises ClusterAlreadyExistsFaultError: if the identifier is in use.
        """
        cluster_identifier = cluster_kwargs["cluster_identifier"]
        if cluster_identifier in self.clusters:
            raise ClusterAlreadyExistsFaultError()
        cluster = Cluster(self, **cluster_kwargs)
        self.clusters[cluster_identifier] = cluster
        return cluster

    def describe_clusters(self, cluster_identifier=None):
        """Return all clusters, or a one-element list for ``cluster_identifier``.

        :raises ClusterNotFoundError: if the identifier is unknown.
        """
        clusters = self.clusters.values()
        if cluster_identifier:
            if cluster_identifier in self.clusters:
                return [self.clusters[cluster_identifier]]
            else:
                raise ClusterNotFoundError(cluster_identifier)
        return clusters

    def modify_cluster(self, **cluster_kwargs):
        """Apply ``cluster_kwargs`` as attributes on an existing cluster.

        ``new_cluster_identifier``, when given, re-keys the cluster.

        :raises InvalidParameterValueError: for an unknown ``cluster_type``.
        :raises InvalidParameterCombinationError: for a multi-node cluster
            with fewer than two nodes.
        :raises ClusterNotFoundError: if the cluster does not exist.
        """
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None)

        cluster_type = cluster_kwargs.get("cluster_type")
        if cluster_type and cluster_type not in ["multi-node", "single-node"]:
            raise InvalidParameterValueError(
                "Invalid cluster type. Cluster type can be one of multi-node or single-node"
            )
        if cluster_type == "single-node":
            # AWS will always silently override this value for single-node clusters.
            cluster_kwargs["number_of_nodes"] = 1
        elif cluster_type == "multi-node":
            if cluster_kwargs.get("number_of_nodes", 0) < 2:
                raise InvalidParameterCombinationError(
                    "Number of nodes for cluster type multi-node must be greater than or equal to 2"
                )

        cluster = self.describe_clusters(cluster_identifier)[0]

        for key, value in cluster_kwargs.items():
            setattr(cluster, key, value)

        if new_cluster_identifier:
            # Renaming: drop the old key without a final snapshot, then store
            # the same object under the new identifier.
            dic = {
                "cluster_identifier": cluster_identifier,
                "skip_final_snapshot": True,
                "final_cluster_snapshot_identifier": None,
            }
            self.delete_cluster(**dic)
            cluster.cluster_identifier = new_cluster_identifier
            self.clusters[new_cluster_identifier] = cluster

        return cluster

    def delete_cluster(self, **cluster_kwargs):
        """Remove a cluster, optionally taking a final snapshot first.

        Reads ``cluster_identifier``, ``skip_final_snapshot`` and
        ``final_cluster_snapshot_identifier`` from ``cluster_kwargs``.

        :raises InvalidParameterCombinationError: if a final snapshot is
            required but no snapshot identifier was given.
        :raises ClusterNotFoundError: if the cluster does not exist.
        """
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot")
        cluster_snapshot_identifer = cluster_kwargs.pop(
            "final_cluster_snapshot_identifier"
        )

        if cluster_identifier in self.clusters:
            if (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifer is None
            ):
                raise InvalidParameterCombinationError(
                    "FinalClusterSnapshotIdentifier is required unless SkipFinalClusterSnapshot is specified."
                )
            elif (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifer is not None
            ):  # create snapshot
                cluster = self.describe_clusters(cluster_identifier)[0]
                self.create_cluster_snapshot(
                    cluster_identifier,
                    cluster_snapshot_identifer,
                    cluster.region,
                    cluster.tags,
                )

            return self.clusters.pop(cluster_identifier)
        raise ClusterNotFoundError(cluster_identifier)

    def create_cluster_subnet_group(
        self, cluster_subnet_group_name, description, subnet_ids, region_name, tags=None
    ):
        """Create and store a subnet group keyed by its name."""
        subnet_group = SubnetGroup(
            self.ec2_backend,
            cluster_subnet_group_name,
            description,
            subnet_ids,
            region_name,
            tags,
        )
        self.subnet_groups[cluster_subnet_group_name] = subnet_group
        return subnet_group

    def describe_cluster_subnet_groups(self, subnet_identifier=None):
        """Return all subnet groups, or a one-element list for the given name.

        :raises ClusterSubnetGroupNotFoundError: if the name is unknown.
        """
        subnet_groups = self.subnet_groups.values()
        if subnet_identifier:
            if subnet_identifier in self.subnet_groups:
                return [self.subnet_groups[subnet_identifier]]
            else:
                raise ClusterSubnetGroupNotFoundError(subnet_identifier)
        return subnet_groups

    def delete_cluster_subnet_group(self, subnet_identifier):
        """Remove and return a subnet group.

        :raises ClusterSubnetGroupNotFoundError: if the name is unknown.
        """
        if subnet_identifier in self.subnet_groups:
            return self.subnet_groups.pop(subnet_identifier)
        raise ClusterSubnetGroupNotFoundError(subnet_identifier)

    def create_cluster_security_group(
        self, cluster_security_group_name, description, region_name, tags=None
    ):
        """Create and store a cluster security group keyed by its name."""
        security_group = SecurityGroup(
            cluster_security_group_name, description, region_name, tags
        )
        self.security_groups[cluster_security_group_name] = security_group
        return security_group

    def describe_cluster_security_groups(self, security_group_name=None):
        """Return all security groups, or a one-element list for the given name.

        :raises ClusterSecurityGroupNotFoundError: if the name is unknown.
        """
        security_groups = self.security_groups.values()
        if security_group_name:
            if security_group_name in self.security_groups:
                return [self.security_groups[security_group_name]]
            else:
                raise ClusterSecurityGroupNotFoundError(security_group_name)
        return security_groups

    def delete_cluster_security_group(self, security_group_identifier):
        """Remove and return a security group.

        :raises ClusterSecurityGroupNotFoundError: if the name is unknown.
        """
        if security_group_identifier in self.security_groups:
            return self.security_groups.pop(security_group_identifier)
        raise ClusterSecurityGroupNotFoundError(security_group_identifier)

    def authorize_cluster_security_group_ingress(self, security_group_name, cidr_ip):
        """Record ``cidr_ip`` as an ingress rule on a security group.

        :raises ClusterSecurityGroupNotFoundFaultError: if the name is unknown.
        """
        security_group = self.security_groups.get(security_group_name)
        if not security_group:
            raise ClusterSecurityGroupNotFoundFaultError()

        # just adding the cidr_ip as ingress rule for now as there is no security rule
        security_group.ingress_rules.append(cidr_ip)

        return security_group

    def create_cluster_parameter_group(
        self,
        cluster_parameter_group_name,
        group_family,
        description,
        region_name,
        tags=None,
    ):
        """Create and store a parameter group keyed by its name."""
        parameter_group = ParameterGroup(
            cluster_parameter_group_name, group_family, description, region_name, tags
        )
        self.parameter_groups[cluster_parameter_group_name] = parameter_group

        return parameter_group

    def describe_cluster_parameter_groups(self, parameter_group_name=None):
        """Return all parameter groups, or a one-element list for the given name.

        :raises ClusterParameterGroupNotFoundError: if the name is unknown.
        """
        parameter_groups = self.parameter_groups.values()
        if parameter_group_name:
            if parameter_group_name in self.parameter_groups:
                return [self.parameter_groups[parameter_group_name]]
            else:
                raise ClusterParameterGroupNotFoundError(parameter_group_name)
        return parameter_groups

    def delete_cluster_parameter_group(self, parameter_group_name):
        """Remove and return a parameter group.

        :raises ClusterParameterGroupNotFoundError: if the name is unknown.
        """
        if parameter_group_name in self.parameter_groups:
            return self.parameter_groups.pop(parameter_group_name)
        raise ClusterParameterGroupNotFoundError(parameter_group_name)

    def create_cluster_snapshot(
        self, cluster_identifier, snapshot_identifier, region_name, tags
    ):
        """Create and store a snapshot of an existing cluster.

        :raises ClusterNotFoundError: if the cluster does not exist.
        :raises ClusterSnapshotAlreadyExistsError: if the snapshot identifier
            is already in use.
        """
        cluster = self.clusters.get(cluster_identifier)
        if not cluster:
            raise ClusterNotFoundError(cluster_identifier)
        if self.snapshots.get(snapshot_identifier) is not None:
            raise ClusterSnapshotAlreadyExistsError(snapshot_identifier)
        snapshot = Snapshot(cluster, snapshot_identifier, region_name, tags)
        self.snapshots[snapshot_identifier] = snapshot
        return snapshot

    def describe_cluster_snapshots(
        self, cluster_identifier=None, snapshot_identifier=None
    ):
        """Return snapshots filtered by cluster or by snapshot identifier.

        The cluster filter is tried first and wins when it matches anything;
        otherwise the snapshot identifier is looked up; with neither, every
        snapshot is returned.

        :raises ClusterSnapshotNotFoundError: if ``snapshot_identifier`` is
            given and unknown.
        """
        if cluster_identifier:
            cluster_snapshots = []
            for snapshot in self.snapshots.values():
                if snapshot.cluster.cluster_identifier == cluster_identifier:
                    cluster_snapshots.append(snapshot)
            if cluster_snapshots:
                return cluster_snapshots
            # NOTE(review): when the cluster filter matches nothing and no
            # snapshot identifier is given, this falls through and returns
            # ALL snapshots. Preserved as-is; confirm whether that is intended.

        if snapshot_identifier:
            if snapshot_identifier in self.snapshots:
                return [self.snapshots[snapshot_identifier]]
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        return self.snapshots.values()

    def delete_cluster_snapshot(self, snapshot_identifier):
        """Remove a snapshot, mark it ``"deleted"`` and return it.

        :raises ClusterSnapshotNotFoundError: if the identifier is unknown.
        """
        if snapshot_identifier not in self.snapshots:
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        deleted_snapshot = self.snapshots.pop(snapshot_identifier)
        deleted_snapshot.status = "deleted"
        return deleted_snapshot

    def restore_from_cluster_snapshot(self, **kwargs):
        """Create a new cluster from the settings captured in a snapshot.

        Settings copied from the snapshot are used as defaults; any remaining
        ``kwargs`` override them and are passed on to ``create_cluster``.

        :raises ClusterSnapshotNotFoundError: if the snapshot does not exist.
        """
        snapshot_identifier = kwargs.pop("snapshot_identifier")
        snapshot = self.describe_cluster_snapshots(
            snapshot_identifier=snapshot_identifier
        )[0]
        create_kwargs = {
            "node_type": snapshot.cluster.node_type,
            "master_username": snapshot.cluster.master_username,
            "master_user_password": snapshot.cluster.master_user_password,
            "db_name": snapshot.cluster.db_name,
            "cluster_type": "multi-node"
            if snapshot.cluster.number_of_nodes > 1
            else "single-node",
            "availability_zone": snapshot.cluster.availability_zone,
            "port": snapshot.cluster.port,
            "cluster_version": snapshot.cluster.cluster_version,
            "number_of_nodes": snapshot.cluster.number_of_nodes,
            "encrypted": snapshot.cluster.encrypted,
            "tags": snapshot.cluster.tags,
            "restored_from_snapshot": True,
            "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing,
        }
        create_kwargs.update(kwargs)
        return self.create_cluster(**create_kwargs)

    def create_snapshot_copy_grant(self, **kwargs):
        """Create and store a snapshot copy grant.

        :raises SnapshotCopyGrantAlreadyExistsFaultError: if the name is in use.
        """
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        kms_key_id = kwargs["kms_key_id"]
        if snapshot_copy_grant_name not in self.snapshot_copy_grants:
            # Pass the region so the grant's mixin state (region/tags) is set.
            snapshot_copy_grant = SnapshotCopyGrant(
                snapshot_copy_grant_name, kms_key_id, self.region
            )
            self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant
            return snapshot_copy_grant
        raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name)

    def delete_snapshot_copy_grant(self, **kwargs):
        """Remove and return a snapshot copy grant.

        :raises SnapshotCopyGrantNotFoundFaultError: if the name is unknown.
        """
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name in self.snapshot_copy_grants:
            return self.snapshot_copy_grants.pop(snapshot_copy_grant_name)
        raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)

    def describe_snapshot_copy_grants(self, **kwargs):
        """Return all grants, or a one-element list for the named grant.

        :raises SnapshotCopyGrantNotFoundFaultError: if the name is unknown.
        """
        copy_grants = self.snapshot_copy_grants.values()
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name:
            if snapshot_copy_grant_name in self.snapshot_copy_grants:
                return [self.snapshot_copy_grants[snapshot_copy_grant_name]]
            else:
                raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)
        return copy_grants

    def _get_resource_from_arn(self, arn):
        """Resolve an ARN to the stored resource it names.

        :raises ResourceNotFoundFaultError: if the resource type is not in
            ``RESOURCE_TYPE_MAP`` or no resource has that id.
        """
        try:
            arn_breakdown = arn.split(":")
            resource_type = arn_breakdown[5]
            if resource_type == "snapshot":
                # Snapshot ids have the form "<cluster>/<snapshot>"; the store
                # is keyed by the snapshot part only.
                resource_id = arn_breakdown[6].split("/")[1]
            else:
                resource_id = arn_breakdown[6]
        except IndexError:
            # Too few segments: use the raw string so the lookup below fails
            # with the "not supported" message.
            resource_type = resource_id = arn
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if resources is None:
            message = (
                "Tagging is not supported for this type of resource: '{0}' "
                "(the ARN is potentially malformed, please check the ARN "
                "documentation for more information)".format(resource_type)
            )
            raise ResourceNotFoundFaultError(message=message)
        try:
            resource = resources[resource_id]
        except KeyError:
            raise ResourceNotFoundFaultError(resource_type, resource_id)
        else:
            return resource

    @staticmethod
    def _describe_tags_for_resources(resources):
        """Return one entry per (resource, tag) pair across ``resources``."""
        tagged_resources = []
        for resource in resources:
            for tag in resource.tags:
                data = {
                    "ResourceName": resource.arn,
                    "ResourceType": resource.resource_type,
                    "Tag": {"Key": tag["Key"], "Value": tag["Value"]},
                }
                tagged_resources.append(data)
        return tagged_resources

    def _describe_tags_for_resource_type(self, resource_type):
        """Return tag entries for every stored resource of ``resource_type``.

        :raises ResourceNotFoundFaultError: if the type is unknown or has no
            stored resources.
        """
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if not resources:
            raise ResourceNotFoundFaultError(resource_type=resource_type)
        return self._describe_tags_for_resources(resources.values())

    def _describe_tags_for_resource_name(self, resource_name):
        """Return tag entries for the single resource named by an ARN."""
        resource = self._get_resource_from_arn(resource_name)
        return self._describe_tags_for_resources([resource])

    def create_tags(self, resource_name, tags):
        """Add ``tags`` to the resource named by the ARN ``resource_name``."""
        resource = self._get_resource_from_arn(resource_name)
        resource.create_tags(tags)

    def describe_tags(self, resource_name, resource_type):
        """Return tag entries filtered by ARN, by resource type, or unfiltered.

        :raises InvalidParameterValueError: if both filters are supplied.
        """
        if resource_name and resource_type:
            raise InvalidParameterValueError(
                "You cannot filter a list of resources using an Amazon "
                "Resource Name (ARN) and a resource type together in the "
                "same request. Retry the request using either an ARN or "
                "a resource type, but not both."
            )
        if resource_type:
            return self._describe_tags_for_resource_type(resource_type.lower())
        if resource_name:
            return self._describe_tags_for_resource_name(resource_name)
        # If name and type are not specified, return all tagged resources.
        # TODO: Implement aws marker pagination
        tagged_resources = []
        for resource_type in self.RESOURCE_TYPE_MAP:
            # A type with no stored resources raises; skip it when listing all.
            try:
                tagged_resources += self._describe_tags_for_resource_type(resource_type)
            except ResourceNotFoundFaultError:
                pass
        return tagged_resources

    def delete_tags(self, resource_name, tag_keys):
        """Remove ``tag_keys`` from the resource named by the ARN ``resource_name``."""
        resource = self._get_resource_from_arn(resource_name)
        resource.delete_tags(tag_keys)

    def get_cluster_credentials(
        self, cluster_identifier, db_user, auto_create, duration_seconds
    ):
        """Return temporary credentials for a database user on a cluster.

        :returns: dict with ``DbUser`` (prefixed ``"IAM:"`` when
            ``auto_create`` is ``False``, otherwise ``"IAMA:"``), a random
            32-character ``DbPassword`` and an ``Expiration`` datetime.
        :raises InvalidParameterValueError: if ``duration_seconds`` is outside
            900..3600.
        :raises ClusterNotFoundError: if the cluster does not exist.
        """
        if duration_seconds < 900 or duration_seconds > 3600:
            raise InvalidParameterValueError(
                "Token duration must be between 900 and 3600 seconds"
            )
        # NOTE(review): expiration is computed from naive local time, unlike
        # the UTC create_time stamps elsewhere in this module; confirm.
        expiration = datetime.datetime.now() + datetime.timedelta(0, duration_seconds)
        if cluster_identifier in self.clusters:
            user_prefix = "IAM:" if auto_create is False else "IAMA:"
            db_user = user_prefix + db_user
            return {
                "DbUser": db_user,
                "DbPassword": random_string(32),
                "Expiration": expiration,
            }
        else:
            raise ClusterNotFoundError(cluster_identifier)


# One backend per region that boto3 lists for Redshift, across the default,
# aws-us-gov and aws-cn partitions.
redshift_backends = {}
for region in Session().get_available_regions("redshift"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)
for region in Session().get_available_regions("redshift", partition_name="aws-us-gov"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)
for region in Session().get_available_regions("redshift", partition_name="aws-cn"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)