import copy
import datetime

from boto3 import Session

from collections import OrderedDict
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.utilities.utils import random_string
from moto.ec2 import ec2_backends
from .exceptions import (
    ClusterAlreadyExistsFaultError,
    ClusterNotFoundError,
    ClusterParameterGroupNotFoundError,
    ClusterSecurityGroupNotFoundError,
    ClusterSnapshotAlreadyExistsError,
    ClusterSnapshotNotFoundError,
    ClusterSubnetGroupNotFoundError,
    InvalidParameterCombinationError,
    InvalidParameterValueError,
    InvalidSubnetError,
    ResourceNotFoundFaultError,
    SnapshotCopyAlreadyDisabledFaultError,
    SnapshotCopyAlreadyEnabledFaultError,
    SnapshotCopyDisabledFaultError,
    SnapshotCopyGrantAlreadyExistsFaultError,
    SnapshotCopyGrantNotFoundFaultError,
    UnknownSnapshotCopyRegionFaultError,
    ClusterSecurityGroupNotFoundFaultError,
)


from moto.core import ACCOUNT_ID


class TaggableResourceMixin(object):

    resource_type = None

    def __init__(self, region_name, tags):
        self.region = region_name
        self.tags = tags or []

    @property
    def resource_id(self):
        return None

    @property
    def arn(self):
        return "arn:aws:redshift:{region}:{account_id}:{resource_type}:{resource_id}".format(
            region=self.region,
            account_id=ACCOUNT_ID,
            resource_type=self.resource_type,
            resource_id=self.resource_id,
        )

    def create_tags(self, tags):
        new_keys = [tag_set["Key"] for tag_set in tags]
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
        self.tags.extend(tags)
        return self.tags

    def delete_tags(self, tag_keys):
        self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
        return self.tags

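# Minimal sketch of the tag semantics above (hypothetical values): create_tags()
# upserts by "Key", so calling it again with an existing key keeps only the newest value.
#
#   resource = TaggableResourceMixin("us-east-1", [{"Key": "env", "Value": "dev"}])
#   resource.create_tags([{"Key": "env", "Value": "prod"}])  # -> [{"Key": "env", "Value": "prod"}]
#   resource.delete_tags(["env"])                            # -> []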

class Cluster(TaggableResourceMixin, CloudFormationModel):

    resource_type = "cluster"

    def __init__(
        self,
        redshift_backend,
        cluster_identifier,
        node_type,
        master_username,
        master_user_password,
        db_name,
        cluster_type,
        cluster_security_groups,
        vpc_security_group_ids,
        cluster_subnet_group_name,
        availability_zone,
        preferred_maintenance_window,
        cluster_parameter_group_name,
        automated_snapshot_retention_period,
        port,
        cluster_version,
        allow_version_upgrade,
        number_of_nodes,
        publicly_accessible,
        encrypted,
        region_name,
        tags=None,
        iam_roles_arn=None,
        enhanced_vpc_routing=None,
        restored_from_snapshot=False,
        kms_key_id=None,
    ):
        super(Cluster, self).__init__(region_name, tags)
        self.redshift_backend = redshift_backend
        self.cluster_identifier = cluster_identifier
        self.create_time = iso_8601_datetime_with_milliseconds(
            datetime.datetime.utcnow()
        )
        self.status = "available"
        self.node_type = node_type
        self.master_username = master_username
        self.master_user_password = master_user_password
        self.db_name = db_name if db_name else "dev"
        self.vpc_security_group_ids = vpc_security_group_ids
        self.enhanced_vpc_routing = (
            enhanced_vpc_routing if enhanced_vpc_routing is not None else False
        )
        self.cluster_subnet_group_name = cluster_subnet_group_name
        self.publicly_accessible = publicly_accessible
        self.encrypted = encrypted

        self.allow_version_upgrade = (
            allow_version_upgrade if allow_version_upgrade is not None else True
        )
        self.cluster_version = cluster_version if cluster_version else "1.0"
        self.port = int(port) if port else 5439
        self.automated_snapshot_retention_period = (
            int(automated_snapshot_retention_period)
            if automated_snapshot_retention_period
            else 1
        )
        self.preferred_maintenance_window = (
            preferred_maintenance_window
            if preferred_maintenance_window
            else "Mon:03:00-Mon:03:30"
        )

        if cluster_parameter_group_name:
            self.cluster_parameter_group_name = [cluster_parameter_group_name]
        else:
            self.cluster_parameter_group_name = ["default.redshift-1.0"]

        if cluster_security_groups:
            self.cluster_security_groups = cluster_security_groups
        else:
            self.cluster_security_groups = ["Default"]

        if availability_zone:
            self.availability_zone = availability_zone
        else:
            # This could probably be smarter, but there doesn't appear to be a
            # way to pull AZs for a region in boto
            self.availability_zone = region_name + "a"

        if cluster_type == "single-node":
            self.number_of_nodes = 1
        elif number_of_nodes:
            self.number_of_nodes = int(number_of_nodes)
        else:
            self.number_of_nodes = 1

        self.iam_roles_arn = iam_roles_arn or []
        self.restored_from_snapshot = restored_from_snapshot
        self.kms_key_id = kms_key_id

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-cluster.html
        return "AWS::Redshift::Cluster"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        if "ClusterSubnetGroupName" in properties:
            subnet_group_name = properties[
                "ClusterSubnetGroupName"
            ].cluster_subnet_group_name
        else:
            subnet_group_name = None

        cluster = redshift_backend.create_cluster(
            cluster_identifier=resource_name,
            node_type=properties.get("NodeType"),
            master_username=properties.get("MasterUsername"),
            master_user_password=properties.get("MasterUserPassword"),
            db_name=properties.get("DBName"),
            cluster_type=properties.get("ClusterType"),
            cluster_security_groups=properties.get("ClusterSecurityGroups", []),
            vpc_security_group_ids=properties.get("VpcSecurityGroupIds", []),
            cluster_subnet_group_name=subnet_group_name,
            availability_zone=properties.get("AvailabilityZone"),
            preferred_maintenance_window=properties.get("PreferredMaintenanceWindow"),
            cluster_parameter_group_name=properties.get("ClusterParameterGroupName"),
            automated_snapshot_retention_period=properties.get(
                "AutomatedSnapshotRetentionPeriod"
            ),
            port=properties.get("Port"),
            cluster_version=properties.get("ClusterVersion"),
            allow_version_upgrade=properties.get("AllowVersionUpgrade"),
            enhanced_vpc_routing=properties.get("EnhancedVpcRouting"),
            number_of_nodes=properties.get("NumberOfNodes"),
            publicly_accessible=properties.get("PubliclyAccessible"),
            encrypted=properties.get("Encrypted"),
            region_name=region_name,
            kms_key_id=properties.get("KmsKeyId"),
        )
        return cluster

    @classmethod
    def has_cfn_attr(cls, attribute):
        return attribute in ["Endpoint.Address", "Endpoint.Port"]

    def get_cfn_attribute(self, attribute_name):
        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException

        if attribute_name == "Endpoint.Address":
            return self.endpoint
        elif attribute_name == "Endpoint.Port":
            return self.port
        raise UnformattedGetAttTemplateException()

    @property
    def endpoint(self):
        return "{0}.cg034hpkmmjt.{1}.redshift.amazonaws.com".format(
            self.cluster_identifier, self.region
        )

    @property
    def security_groups(self):
        return [
            security_group
            for security_group in self.redshift_backend.describe_cluster_security_groups()
            if security_group.cluster_security_group_name
            in self.cluster_security_groups
        ]

    @property
    def vpc_security_groups(self):
        return [
            security_group
            for security_group in self.redshift_backend.ec2_backend.describe_security_groups()
            if security_group.id in self.vpc_security_group_ids
        ]

    @property
    def parameter_groups(self):
        return [
            parameter_group
            for parameter_group in self.redshift_backend.describe_cluster_parameter_groups()
            if parameter_group.cluster_parameter_group_name
            in self.cluster_parameter_group_name
        ]

    @property
    def resource_id(self):
        return self.cluster_identifier

    def to_json(self):
        json_response = {
            "MasterUsername": self.master_username,
            "MasterUserPassword": "****",
            "ClusterVersion": self.cluster_version,
            "VpcSecurityGroups": [
                {"Status": "active", "VpcSecurityGroupId": group.id}
                for group in self.vpc_security_groups
            ],
            "ClusterSubnetGroupName": self.cluster_subnet_group_name,
            "AvailabilityZone": self.availability_zone,
            "ClusterStatus": self.status,
            "NumberOfNodes": self.number_of_nodes,
            "AutomatedSnapshotRetentionPeriod": self.automated_snapshot_retention_period,
            "PubliclyAccessible": self.publicly_accessible,
            "Encrypted": self.encrypted,
            "DBName": self.db_name,
            "PreferredMaintenanceWindow": self.preferred_maintenance_window,
            "ClusterParameterGroups": [
                {
                    "ParameterApplyStatus": "in-sync",
                    "ParameterGroupName": group.cluster_parameter_group_name,
                }
                for group in self.parameter_groups
            ],
            "ClusterSecurityGroups": [
                {
                    "Status": "active",
                    "ClusterSecurityGroupName": group.cluster_security_group_name,
                }
                for group in self.security_groups
            ],
            "Port": self.port,
            "NodeType": self.node_type,
            "ClusterIdentifier": self.cluster_identifier,
            "AllowVersionUpgrade": self.allow_version_upgrade,
            "Endpoint": {"Address": self.endpoint, "Port": self.port},
            "ClusterCreateTime": self.create_time,
            "PendingModifiedValues": [],
            "Tags": self.tags,
            "EnhancedVpcRouting": self.enhanced_vpc_routing,
            "IamRoles": [
                {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn}
                for iam_role_arn in self.iam_roles_arn
            ],
            "KmsKeyId": self.kms_key_id,
        }
        if self.restored_from_snapshot:
            json_response["RestoreStatus"] = {
                "Status": "completed",
                "CurrentRestoreRateInMegaBytesPerSecond": 123.0,
                "SnapshotSizeInMegaBytes": 123,
                "ProgressInMegaBytes": 123,
                "ElapsedTimeInSeconds": 123,
                "EstimatedTimeToCompletionInSeconds": 123,
            }
        try:
            json_response[
                "ClusterSnapshotCopyStatus"
            ] = self.cluster_snapshot_copy_status
        except AttributeError:
            pass
        return json_response

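# Illustration only (hypothetical identifier "my-dw", region "us-east-1"): a Cluster
# created with those values resolves to
#   endpoint -> "my-dw.cg034hpkmmjt.us-east-1.redshift.amazonaws.com"  ("cg034hpkmmjt" is a fixed mock suffix)
#   arn      -> "arn:aws:redshift:us-east-1:{ACCOUNT_ID}:cluster:my-dw"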

class SnapshotCopyGrant(TaggableResourceMixin, BaseModel):

    resource_type = "snapshotcopygrant"

    def __init__(self, snapshot_copy_grant_name, kms_key_id):
        self.snapshot_copy_grant_name = snapshot_copy_grant_name
        self.kms_key_id = kms_key_id

    def to_json(self):
        return {
            "SnapshotCopyGrantName": self.snapshot_copy_grant_name,
            "KmsKeyId": self.kms_key_id,
        }


class SubnetGroup(TaggableResourceMixin, CloudFormationModel):

    resource_type = "subnetgroup"

    def __init__(
        self,
        ec2_backend,
        cluster_subnet_group_name,
        description,
        subnet_ids,
        region_name,
        tags=None,
    ):
        super(SubnetGroup, self).__init__(region_name, tags)
        self.ec2_backend = ec2_backend
        self.cluster_subnet_group_name = cluster_subnet_group_name
        self.description = description
        self.subnet_ids = subnet_ids
        if not self.subnets:
            raise InvalidSubnetError(subnet_ids)

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clustersubnetgroup.html
        return "AWS::Redshift::ClusterSubnetGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        subnet_group = redshift_backend.create_cluster_subnet_group(
            cluster_subnet_group_name=resource_name,
            description=properties.get("Description"),
            subnet_ids=properties.get("SubnetIds", []),
            region_name=region_name,
        )
        return subnet_group

    @property
    def subnets(self):
        return self.ec2_backend.get_all_subnets(filters={"subnet-id": self.subnet_ids})

    @property
    def vpc_id(self):
        return self.subnets[0].vpc_id

    @property
    def resource_id(self):
        return self.cluster_subnet_group_name

    def to_json(self):
        return {
            "VpcId": self.vpc_id,
            "Description": self.description,
            "ClusterSubnetGroupName": self.cluster_subnet_group_name,
            "SubnetGroupStatus": "Complete",
            "Subnets": [
                {
                    "SubnetStatus": "Active",
                    "SubnetIdentifier": subnet.id,
                    "SubnetAvailabilityZone": {"Name": subnet.availability_zone},
                }
                for subnet in self.subnets
            ],
            "Tags": self.tags,
        }


class SecurityGroup(TaggableResourceMixin, BaseModel):

    resource_type = "securitygroup"

    def __init__(
        self, cluster_security_group_name, description, region_name, tags=None
    ):
        super(SecurityGroup, self).__init__(region_name, tags)
        self.cluster_security_group_name = cluster_security_group_name
        self.description = description
        self.ingress_rules = []

    @property
    def resource_id(self):
        return self.cluster_security_group_name

    def to_json(self):
        return {
            "EC2SecurityGroups": [],
            "IPRanges": [],
            "Description": self.description,
            "ClusterSecurityGroupName": self.cluster_security_group_name,
            "Tags": self.tags,
        }


class ParameterGroup(TaggableResourceMixin, CloudFormationModel):

    resource_type = "parametergroup"

    def __init__(
        self,
        cluster_parameter_group_name,
        group_family,
        description,
        region_name,
        tags=None,
    ):
        super(ParameterGroup, self).__init__(region_name, tags)
        self.cluster_parameter_group_name = cluster_parameter_group_name
        self.group_family = group_family
        self.description = description

    @staticmethod
    def cloudformation_name_type():
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-redshift-clusterparametergroup.html
        return "AWS::Redshift::ClusterParameterGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        redshift_backend = redshift_backends[region_name]
        properties = cloudformation_json["Properties"]

        parameter_group = redshift_backend.create_cluster_parameter_group(
            cluster_parameter_group_name=resource_name,
            description=properties.get("Description"),
            group_family=properties.get("ParameterGroupFamily"),
            region_name=region_name,
        )
        return parameter_group

    @property
    def resource_id(self):
        return self.cluster_parameter_group_name

    def to_json(self):
        return {
            "ParameterGroupFamily": self.group_family,
            "Description": self.description,
            "ParameterGroupName": self.cluster_parameter_group_name,
            "Tags": self.tags,
        }


class Snapshot(TaggableResourceMixin, BaseModel):

    resource_type = "snapshot"

    def __init__(
        self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None
    ):
        super(Snapshot, self).__init__(region_name, tags)
        self.cluster = copy.copy(cluster)
        self.snapshot_identifier = snapshot_identifier
        self.snapshot_type = "manual"
        self.status = "available"
        self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
        self.iam_roles_arn = iam_roles_arn or []

    @property
    def resource_id(self):
        return "{cluster_id}/{snapshot_id}".format(
            cluster_id=self.cluster.cluster_identifier,
            snapshot_id=self.snapshot_identifier,
        )

    def to_json(self):
        return {
            "SnapshotIdentifier": self.snapshot_identifier,
            "ClusterIdentifier": self.cluster.cluster_identifier,
            "SnapshotCreateTime": self.create_time,
            "Status": self.status,
            "Port": self.cluster.port,
            "AvailabilityZone": self.cluster.availability_zone,
            "MasterUsername": self.cluster.master_username,
            "ClusterVersion": self.cluster.cluster_version,
            "SnapshotType": self.snapshot_type,
            "NodeType": self.cluster.node_type,
            "NumberOfNodes": self.cluster.number_of_nodes,
            "DBName": self.cluster.db_name,
            "Tags": self.tags,
            "EnhancedVpcRouting": self.cluster.enhanced_vpc_routing,
            "IamRoles": [
                {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn}
                for iam_role_arn in self.iam_roles_arn
            ],
        }


class RedshiftBackend(BaseBackend):
    def __init__(self, ec2_backend, region_name):
        self.region = region_name
        self.clusters = {}
        self.subnet_groups = {}
        self.security_groups = {
            "Default": SecurityGroup(
                "Default", "Default Redshift Security Group", self.region
            )
        }
        self.parameter_groups = {
            "default.redshift-1.0": ParameterGroup(
                "default.redshift-1.0",
                "redshift-1.0",
                "Default Redshift parameter group",
                self.region,
            )
        }
        self.ec2_backend = ec2_backend
        self.snapshots = OrderedDict()
        self.RESOURCE_TYPE_MAP = {
            "cluster": self.clusters,
            "parametergroup": self.parameter_groups,
            "securitygroup": self.security_groups,
            "snapshot": self.snapshots,
            "subnetgroup": self.subnet_groups,
        }
        self.snapshot_copy_grants = {}

    def reset(self):
        ec2_backend = self.ec2_backend
        region_name = self.region
        self.__dict__ = {}
        self.__init__(ec2_backend, region_name)

    @staticmethod
    def default_vpc_endpoint_service(service_region, zones):
        """Default VPC endpoint service."""
        return BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "redshift"
        ) + BaseBackend.default_vpc_endpoint_service_factory(
            service_region, zones, "redshift-data", policy_supported=False
        )

    def enable_snapshot_copy(self, **kwargs):
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self.clusters[cluster_identifier]
        if not hasattr(cluster, "cluster_snapshot_copy_status"):
            if (
                cluster.encrypted == "true"
                and kwargs["snapshot_copy_grant_name"] is None
            ):
                raise InvalidParameterValueError(
                    "SnapshotCopyGrantName is required for Snapshot Copy on KMS encrypted clusters."
                )
            if kwargs["destination_region"] == self.region:
                raise UnknownSnapshotCopyRegionFaultError(
                    "Invalid region {}".format(self.region)
                )
            status = {
                "DestinationRegion": kwargs["destination_region"],
                "RetentionPeriod": kwargs["retention_period"],
                "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"],
            }
            cluster.cluster_snapshot_copy_status = status
            return cluster
        else:
            raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier)

    def disable_snapshot_copy(self, **kwargs):
        cluster_identifier = kwargs["cluster_identifier"]
        cluster = self.clusters[cluster_identifier]
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            del cluster.cluster_snapshot_copy_status
            return cluster
        else:
            raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier)

    def modify_snapshot_copy_retention_period(
        self, cluster_identifier, retention_period
    ):
        cluster = self.clusters[cluster_identifier]
        if hasattr(cluster, "cluster_snapshot_copy_status"):
            cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period
            return cluster
        else:
            raise SnapshotCopyDisabledFaultError(cluster_identifier)

    def create_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs["cluster_identifier"]
        if cluster_identifier in self.clusters:
            raise ClusterAlreadyExistsFaultError()
        cluster = Cluster(self, **cluster_kwargs)
        self.clusters[cluster_identifier] = cluster
        return cluster

    def describe_clusters(self, cluster_identifier=None):
        clusters = self.clusters.values()
        if cluster_identifier:
            if cluster_identifier in self.clusters:
                return [self.clusters[cluster_identifier]]
            else:
                raise ClusterNotFoundError(cluster_identifier)
        return clusters

    def modify_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None)

        cluster_type = cluster_kwargs.get("cluster_type")
        if cluster_type and cluster_type not in ["multi-node", "single-node"]:
            raise InvalidParameterValueError(
                "Invalid cluster type. Cluster type can be one of multi-node or single-node"
            )
        if cluster_type == "single-node":
            # AWS will always silently override this value for single-node clusters.
            cluster_kwargs["number_of_nodes"] = 1
        elif cluster_type == "multi-node":
            if cluster_kwargs.get("number_of_nodes", 0) < 2:
                raise InvalidParameterCombinationError(
                    "Number of nodes for cluster type multi-node must be greater than or equal to 2"
                )

        cluster = self.describe_clusters(cluster_identifier)[0]

        for key, value in cluster_kwargs.items():
            setattr(cluster, key, value)

        if new_cluster_identifier:
            dic = {
                "cluster_identifier": cluster_identifier,
                "skip_final_snapshot": True,
                "final_cluster_snapshot_identifier": None,
            }
            self.delete_cluster(**dic)
            cluster.cluster_identifier = new_cluster_identifier
            self.clusters[new_cluster_identifier] = cluster

        return cluster
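    # Sketch of the rename path above (hypothetical identifiers): calling
    # modify_cluster(cluster_identifier="old-dw", new_cluster_identifier="new-dw", ...)
    # removes the "old-dw" entry via delete_cluster and re-registers the same Cluster
    # object under "new-dw".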

    def delete_cluster(self, **cluster_kwargs):
        cluster_identifier = cluster_kwargs.pop("cluster_identifier")
        cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot")
        cluster_snapshot_identifer = cluster_kwargs.pop(
            "final_cluster_snapshot_identifier"
        )

        if cluster_identifier in self.clusters:
            if (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifer is None
            ):
                raise InvalidParameterCombinationError(
                    "FinalClusterSnapshotIdentifier is required unless SkipFinalClusterSnapshot is specified."
                )
            elif (
                cluster_skip_final_snapshot is False
                and cluster_snapshot_identifer is not None
            ):  # create snapshot
                cluster = self.describe_clusters(cluster_identifier)[0]
                self.create_cluster_snapshot(
                    cluster_identifier,
                    cluster_snapshot_identifer,
                    cluster.region,
                    cluster.tags,
                )

            return self.clusters.pop(cluster_identifier)
        raise ClusterNotFoundError(cluster_identifier)

    def create_cluster_subnet_group(
        self, cluster_subnet_group_name, description, subnet_ids, region_name, tags=None
    ):
        subnet_group = SubnetGroup(
            self.ec2_backend,
            cluster_subnet_group_name,
            description,
            subnet_ids,
            region_name,
            tags,
        )
        self.subnet_groups[cluster_subnet_group_name] = subnet_group
        return subnet_group

    def describe_cluster_subnet_groups(self, subnet_identifier=None):
        subnet_groups = self.subnet_groups.values()
        if subnet_identifier:
            if subnet_identifier in self.subnet_groups:
                return [self.subnet_groups[subnet_identifier]]
            else:
                raise ClusterSubnetGroupNotFoundError(subnet_identifier)
        return subnet_groups

    def delete_cluster_subnet_group(self, subnet_identifier):
        if subnet_identifier in self.subnet_groups:
            return self.subnet_groups.pop(subnet_identifier)
        raise ClusterSubnetGroupNotFoundError(subnet_identifier)

    def create_cluster_security_group(
        self, cluster_security_group_name, description, region_name, tags=None
    ):
        security_group = SecurityGroup(
            cluster_security_group_name, description, region_name, tags
        )
        self.security_groups[cluster_security_group_name] = security_group
        return security_group

    def describe_cluster_security_groups(self, security_group_name=None):
        security_groups = self.security_groups.values()
        if security_group_name:
            if security_group_name in self.security_groups:
                return [self.security_groups[security_group_name]]
            else:
                raise ClusterSecurityGroupNotFoundError(security_group_name)
        return security_groups

    def delete_cluster_security_group(self, security_group_identifier):
        if security_group_identifier in self.security_groups:
            return self.security_groups.pop(security_group_identifier)
        raise ClusterSecurityGroupNotFoundError(security_group_identifier)

    def authorize_cluster_security_group_ingress(self, security_group_name, cidr_ip):
        security_group = self.security_groups.get(security_group_name)
        if not security_group:
            raise ClusterSecurityGroupNotFoundFaultError()

        # just adding the cidr_ip as ingress rule for now as there is no security rule
        security_group.ingress_rules.append(cidr_ip)

        return security_group

    def create_cluster_parameter_group(
        self,
        cluster_parameter_group_name,
        group_family,
        description,
        region_name,
        tags=None,
    ):
        parameter_group = ParameterGroup(
            cluster_parameter_group_name, group_family, description, region_name, tags
        )
        self.parameter_groups[cluster_parameter_group_name] = parameter_group

        return parameter_group

    def describe_cluster_parameter_groups(self, parameter_group_name=None):
        parameter_groups = self.parameter_groups.values()
        if parameter_group_name:
            if parameter_group_name in self.parameter_groups:
                return [self.parameter_groups[parameter_group_name]]
            else:
                raise ClusterParameterGroupNotFoundError(parameter_group_name)
        return parameter_groups

    def delete_cluster_parameter_group(self, parameter_group_name):
        if parameter_group_name in self.parameter_groups:
            return self.parameter_groups.pop(parameter_group_name)
        raise ClusterParameterGroupNotFoundError(parameter_group_name)

    def create_cluster_snapshot(
        self, cluster_identifier, snapshot_identifier, region_name, tags
    ):
        cluster = self.clusters.get(cluster_identifier)
        if not cluster:
            raise ClusterNotFoundError(cluster_identifier)
        if self.snapshots.get(snapshot_identifier) is not None:
            raise ClusterSnapshotAlreadyExistsError(snapshot_identifier)
        snapshot = Snapshot(cluster, snapshot_identifier, region_name, tags)
        self.snapshots[snapshot_identifier] = snapshot
        return snapshot

    def describe_cluster_snapshots(
        self, cluster_identifier=None, snapshot_identifier=None
    ):
        if cluster_identifier:
            cluster_snapshots = []
            for snapshot in self.snapshots.values():
                if snapshot.cluster.cluster_identifier == cluster_identifier:
                    cluster_snapshots.append(snapshot)
            if cluster_snapshots:
                return cluster_snapshots

        if snapshot_identifier:
            if snapshot_identifier in self.snapshots:
                return [self.snapshots[snapshot_identifier]]
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        return self.snapshots.values()

    def delete_cluster_snapshot(self, snapshot_identifier):
        if snapshot_identifier not in self.snapshots:
            raise ClusterSnapshotNotFoundError(snapshot_identifier)

        deleted_snapshot = self.snapshots.pop(snapshot_identifier)
        deleted_snapshot.status = "deleted"
        return deleted_snapshot

    def restore_from_cluster_snapshot(self, **kwargs):
        snapshot_identifier = kwargs.pop("snapshot_identifier")
        snapshot = self.describe_cluster_snapshots(
            snapshot_identifier=snapshot_identifier
        )[0]
        create_kwargs = {
            "node_type": snapshot.cluster.node_type,
            "master_username": snapshot.cluster.master_username,
            "master_user_password": snapshot.cluster.master_user_password,
            "db_name": snapshot.cluster.db_name,
            "cluster_type": "multi-node"
            if snapshot.cluster.number_of_nodes > 1
            else "single-node",
            "availability_zone": snapshot.cluster.availability_zone,
            "port": snapshot.cluster.port,
            "cluster_version": snapshot.cluster.cluster_version,
            "number_of_nodes": snapshot.cluster.number_of_nodes,
            "encrypted": snapshot.cluster.encrypted,
            "tags": snapshot.cluster.tags,
            "restored_from_snapshot": True,
            "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing,
        }
        create_kwargs.update(kwargs)
        return self.create_cluster(**create_kwargs)

    def create_snapshot_copy_grant(self, **kwargs):
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        kms_key_id = kwargs["kms_key_id"]
        if snapshot_copy_grant_name not in self.snapshot_copy_grants:
            snapshot_copy_grant = SnapshotCopyGrant(
                snapshot_copy_grant_name, kms_key_id
            )
            self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant
            return snapshot_copy_grant
        raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name)

    def delete_snapshot_copy_grant(self, **kwargs):
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name in self.snapshot_copy_grants:
            return self.snapshot_copy_grants.pop(snapshot_copy_grant_name)
        raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)

    def describe_snapshot_copy_grants(self, **kwargs):
        copy_grants = self.snapshot_copy_grants.values()
        snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"]
        if snapshot_copy_grant_name:
            if snapshot_copy_grant_name in self.snapshot_copy_grants:
                return [self.snapshot_copy_grants[snapshot_copy_grant_name]]
            else:
                raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name)
        return copy_grants

    def _get_resource_from_arn(self, arn):
        try:
            arn_breakdown = arn.split(":")
            resource_type = arn_breakdown[5]
            if resource_type == "snapshot":
                resource_id = arn_breakdown[6].split("/")[1]
            else:
                resource_id = arn_breakdown[6]
        except IndexError:
            resource_type = resource_id = arn
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if resources is None:
            message = (
                "Tagging is not supported for this type of resource: '{0}' "
                "(the ARN is potentially malformed, please check the ARN "
                "documentation for more information)".format(resource_type)
            )
            raise ResourceNotFoundFaultError(message=message)
        try:
            resource = resources[resource_id]
        except KeyError:
            raise ResourceNotFoundFaultError(resource_type, resource_id)
        else:
            return resource
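    # Illustration (hypothetical ARNs) of how _get_resource_from_arn splits on ":":
    #   "arn:aws:redshift:us-east-1:123456789012:cluster:my-dw"          -> ("cluster", "my-dw")
    #   "arn:aws:redshift:us-east-1:123456789012:snapshot:my-dw/my-snap" -> ("snapshot", "my-snap")
    # ARNs without a recognised resource type raise ResourceNotFoundFaultError.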

    @staticmethod
    def _describe_tags_for_resources(resources):
        tagged_resources = []
        for resource in resources:
            for tag in resource.tags:
                data = {
                    "ResourceName": resource.arn,
                    "ResourceType": resource.resource_type,
                    "Tag": {"Key": tag["Key"], "Value": tag["Value"]},
                }
                tagged_resources.append(data)
        return tagged_resources

    def _describe_tags_for_resource_type(self, resource_type):
        resources = self.RESOURCE_TYPE_MAP.get(resource_type)
        if not resources:
            raise ResourceNotFoundFaultError(resource_type=resource_type)
        return self._describe_tags_for_resources(resources.values())

    def _describe_tags_for_resource_name(self, resource_name):
        resource = self._get_resource_from_arn(resource_name)
        return self._describe_tags_for_resources([resource])

    def create_tags(self, resource_name, tags):
        resource = self._get_resource_from_arn(resource_name)
        resource.create_tags(tags)

    def describe_tags(self, resource_name, resource_type):
        if resource_name and resource_type:
            raise InvalidParameterValueError(
                "You cannot filter a list of resources using an Amazon "
                "Resource Name (ARN) and a resource type together in the "
                "same request. Retry the request using either an ARN or "
                "a resource type, but not both."
            )
        if resource_type:
            return self._describe_tags_for_resource_type(resource_type.lower())
        if resource_name:
            return self._describe_tags_for_resource_name(resource_name)
        # If name and type are not specified, return all tagged resources.
        # TODO: Implement aws marker pagination
        tagged_resources = []
        for resource_type in self.RESOURCE_TYPE_MAP:
            try:
                tagged_resources += self._describe_tags_for_resource_type(resource_type)
            except ResourceNotFoundFaultError:
                pass
        return tagged_resources

    def delete_tags(self, resource_name, tag_keys):
        resource = self._get_resource_from_arn(resource_name)
        resource.delete_tags(tag_keys)

    def get_cluster_credentials(
        self, cluster_identifier, db_user, auto_create, duration_seconds
    ):
        if duration_seconds < 900 or duration_seconds > 3600:
            raise InvalidParameterValueError(
                "Token duration must be between 900 and 3600 seconds"
            )
        expiration = datetime.datetime.now() + datetime.timedelta(0, duration_seconds)
        if cluster_identifier in self.clusters:
            user_prefix = "IAM:" if auto_create is False else "IAMA:"
            db_user = user_prefix + db_user
            return {
                "DbUser": db_user,
                "DbPassword": random_string(32),
                "Expiration": expiration,
            }
        else:
            raise ClusterNotFoundError(cluster_identifier)
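    # Illustration of the DbUser prefixing above (hypothetical user name "analyst"):
    #   get_cluster_credentials(..., db_user="analyst", auto_create=False, ...) -> "IAM:analyst"
    #   get_cluster_credentials(..., db_user="analyst", auto_create=True, ...)  -> "IAMA:analyst"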


redshift_backends = {}
for region in Session().get_available_regions("redshift"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)
for region in Session().get_available_regions("redshift", partition_name="aws-us-gov"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)
for region in Session().get_available_regions("redshift", partition_name="aws-cn"):
    redshift_backends[region] = RedshiftBackend(ec2_backends[region], region)
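# Minimal usage sketch (not part of this module): with moto's "mock_redshift" decorator,
# boto3 calls are routed to the RedshiftBackend registered above for the client's region.
# The identifiers below ("my-dw", "dc2.large", etc.) are hypothetical.
#
#   from moto import mock_redshift
#   import boto3
#
#   @mock_redshift
#   def test_create_cluster():
#       client = boto3.client("redshift", region_name="us-east-1")
#       client.create_cluster(
#           ClusterIdentifier="my-dw",
#           NodeType="dc2.large",
#           MasterUsername="admin",
#           MasterUserPassword="Password1",
#       )
#       cluster = client.describe_clusters(ClusterIdentifier="my-dw")["Clusters"][0]
#       assert cluster["ClusterStatus"] == "available"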