1from boto3 import Session
2from jinja2 import Template
3
4from moto.core import BaseBackend, CloudFormationModel
5from moto.ec2.models import ec2_backends
6from moto.rds.exceptions import UnformattedGetAttTemplateException
7from moto.rds2.models import rds2_backends
8
9
class Database(CloudFormationModel):
    """CloudFormation hooks, XML rendering and deletion for RDS DB instances.

    This class defines no ``__init__``. The attributes read by ``to_xml`` and
    ``get_cfn_attribute`` (``address``, ``port``, ``engine``, ...) are expected
    to be set on the instance elsewhere. ``create_from_cloudformation_json``
    returns whatever object the RDS backend produces.
    """

    @classmethod
    def has_cfn_attr(cls, attribute):
        """Return True if *attribute* is a supported ``Fn::GetAtt`` name."""
        supported_attributes = ("Endpoint.Address", "Endpoint.Port")
        return attribute in supported_attributes

    def get_cfn_attribute(self, attribute_name):
        """Resolve ``Fn::GetAtt`` for this instance.

        ``Endpoint.Address`` maps to ``self.address`` and ``Endpoint.Port``
        maps to ``self.port``.

        Raises:
            UnformattedGetAttTemplateException: for any other attribute name.
        """
        if attribute_name == "Endpoint.Address":
            return self.address
        if attribute_name == "Endpoint.Port":
            return self.port
        raise UnformattedGetAttTemplateException()

    @staticmethod
    def cloudformation_name_type():
        # CloudFormation property that carries this resource's name.
        return "DBInstanceIdentifier"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbinstance.html
        return "AWS::RDS::DBInstance"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create the DB instance (or read replica) described by a CFN resource.

        Items in ``DBSecurityGroups`` are read through ``.group_name``, and
        ``DBSubnetGroupName`` through ``.subnet_name``. These are presumably
        the already-resolved resource objects rather than plain strings;
        confirm against the CloudFormation parser.

        ``Port`` falls back to 3306 when absent. When
        ``SourceDBInstanceIdentifier`` is present, a replica of that instance
        is created instead of a standalone database.
        """
        props = cloudformation_json["Properties"]

        # A missing or empty DBSecurityGroups list means "no groups".
        security_group_names = [
            group.group_name for group in (props.get("DBSecurityGroups") or [])
        ]
        subnet_group = props.get("DBSubnetGroupName")
        subnet_group_name = subnet_group.subnet_name if subnet_group else None

        db_kwargs = {
            "auto_minor_version_upgrade": props.get("AutoMinorVersionUpgrade"),
            "allocated_storage": props.get("AllocatedStorage"),
            "availability_zone": props.get("AvailabilityZone"),
            "backup_retention_period": props.get("BackupRetentionPeriod"),
            "db_instance_class": props.get("DBInstanceClass"),
            "db_instance_identifier": resource_name,
            "db_name": props.get("DBName"),
            "db_subnet_group_name": subnet_group_name,
            "engine": props.get("Engine"),
            "engine_version": props.get("EngineVersion"),
            "iops": props.get("Iops"),
            "kms_key_id": props.get("KmsKeyId"),
            "master_password": props.get("MasterUserPassword"),
            "master_username": props.get("MasterUsername"),
            "multi_az": props.get("MultiAZ"),
            "port": props.get("Port", 3306),
            "publicly_accessible": props.get("PubliclyAccessible"),
            "copy_tags_to_snapshot": props.get("CopyTagsToSnapshot"),
            "region": region_name,
            "security_groups": security_group_names,
            "storage_encrypted": props.get("StorageEncrypted"),
            "storage_type": props.get("StorageType"),
            "tags": props.get("Tags"),
        }

        backend = rds_backends[region_name]
        source_identifier = props.get("SourceDBInstanceIdentifier")
        if not source_identifier:
            return backend.create_database(db_kwargs)

        # Read replica of an existing instance.
        db_kwargs["source_db_identifier"] = source_identifier
        return backend.create_database_replica(db_kwargs)

    def to_xml(self):
        """Render this instance as a ``<DBInstance>`` XML fragment."""
        xml_template = Template(
            """<DBInstance>
              <BackupRetentionPeriod>{{ database.backup_retention_period }}</BackupRetentionPeriod>
              <DBInstanceStatus>{{ database.status }}</DBInstanceStatus>
              <MultiAZ>{{ database.multi_az }}</MultiAZ>
              <VpcSecurityGroups/>
              <DBInstanceIdentifier>{{ database.db_instance_identifier }}</DBInstanceIdentifier>
              <PreferredBackupWindow>03:50-04:20</PreferredBackupWindow>
              <PreferredMaintenanceWindow>wed:06:38-wed:07:08</PreferredMaintenanceWindow>
              <ReadReplicaDBInstanceIdentifiers>
                {% for replica_id in database.replicas %}
                    <ReadReplicaDBInstanceIdentifier>{{ replica_id }}</ReadReplicaDBInstanceIdentifier>
                {% endfor %}
              </ReadReplicaDBInstanceIdentifiers>
              <StatusInfos>
                {% if database.is_replica %}
                <DBInstanceStatusInfo>
                    <StatusType>read replication</StatusType>
                    <Status>replicating</Status>
                    <Normal>true</Normal>
                    <Message></Message>
                </DBInstanceStatusInfo>
                {% endif %}
              </StatusInfos>
              {% if database.is_replica %}
              <ReadReplicaSourceDBInstanceIdentifier>{{ database.source_db_identifier }}</ReadReplicaSourceDBInstanceIdentifier>
              {% endif %}
              <Engine>{{ database.engine }}</Engine>
              <LicenseModel>{{ database.license_model }}</LicenseModel>
              <EngineVersion>{{ database.engine_version }}</EngineVersion>
              <DBParameterGroups>
              </DBParameterGroups>
              <OptionGroupMemberships>
              </OptionGroupMemberships>
              <DBSecurityGroups>
                {% for security_group in database.security_groups %}
                <DBSecurityGroup>
                  <Status>active</Status>
                  <DBSecurityGroupName>{{ security_group }}</DBSecurityGroupName>
                </DBSecurityGroup>
                {% endfor %}
              </DBSecurityGroups>
              {% if database.db_subnet_group %}
              <DBSubnetGroup>
                <DBSubnetGroupName>{{ database.db_subnet_group.subnet_name }}</DBSubnetGroupName>
                <DBSubnetGroupDescription>{{ database.db_subnet_group.description }}</DBSubnetGroupDescription>
                <SubnetGroupStatus>{{ database.db_subnet_group.status }}</SubnetGroupStatus>
                <Subnets>
                    {% for subnet in database.db_subnet_group.subnets %}
                    <Subnet>
                      <SubnetStatus>Active</SubnetStatus>
                      <SubnetIdentifier>{{ subnet.id }}</SubnetIdentifier>
                      <SubnetAvailabilityZone>
                        <Name>{{ subnet.availability_zone }}</Name>
                        <ProvisionedIopsCapable>false</ProvisionedIopsCapable>
                      </SubnetAvailabilityZone>
                    </Subnet>
                    {% endfor %}
                </Subnets>
                <VpcId>{{ database.db_subnet_group.vpc_id }}</VpcId>
              </DBSubnetGroup>
              {% endif %}
              <PubliclyAccessible>{{ database.publicly_accessible }}</PubliclyAccessible>
              <CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
              <AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
              <AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
              <StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>
              {% if database.kms_key_id %}
              <KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
              {% endif %}
              {% if database.iops %}
              <Iops>{{ database.iops }}</Iops>
              <StorageType>io1</StorageType>
              {% else %}
              <StorageType>{{ database.storage_type }}</StorageType>
              {% endif %}
              <DBInstanceClass>{{ database.db_instance_class }}</DBInstanceClass>
              <InstanceCreateTime>{{ database.instance_create_time }}</InstanceCreateTime>
              <MasterUsername>{{ database.master_username }}</MasterUsername>
              <Endpoint>
                <Address>{{ database.address }}</Address>
                <Port>{{ database.port }}</Port>
              </Endpoint>
              <DBInstanceArn>{{ database.db_instance_arn }}</DBInstanceArn>
            </DBInstance>"""
        )
        return xml_template.render(database=self)

    def delete(self, region_name):
        """Delete this instance from the RDS backend of *region_name*."""
        rds_backends[region_name].delete_database(self.db_instance_identifier)
171
172
class SecurityGroup(CloudFormationModel):
    """An RDS DB security group (``AWS::RDS::DBSecurityGroup``).

    Tracks authorized CIDR ranges and EC2 security groups and renders them as
    a ``<DBSecurityGroup>`` XML fragment.
    """

    def __init__(self, group_name, description):
        self.group_name = group_name
        self.description = description
        self.status = "authorized"
        # CIDR strings appended by authorize_cidr().
        self.ip_ranges = []
        # EC2 security-group objects appended by authorize_security_group().
        self.ec2_security_groups = []

    def to_xml(self):
        """Render this group as a ``<DBSecurityGroup>`` XML fragment.

        ``OwnerId`` is filled from an ``owner_id`` attribute if one has been
        set on the instance. This class itself does not set one.
        """
        # The inner loop variable is named ec2_group so that it does not shadow
        # the ``security_group`` context variable used after the loop.
        template = Template(
            """<DBSecurityGroup>
            <EC2SecurityGroups>
            {% for ec2_group in security_group.ec2_security_groups %}
                <EC2SecurityGroup>
                    <EC2SecurityGroupId>{{ ec2_group.id }}</EC2SecurityGroupId>
                    <EC2SecurityGroupName>{{ ec2_group.name }}</EC2SecurityGroupName>
                    <EC2SecurityGroupOwnerId>{{ ec2_group.owner_id }}</EC2SecurityGroupOwnerId>
                    <Status>authorized</Status>
                </EC2SecurityGroup>
            {% endfor %}
            </EC2SecurityGroups>

            <DBSecurityGroupDescription>{{ security_group.description }}</DBSecurityGroupDescription>
            <IPRanges>
            {% for ip_range in security_group.ip_ranges %}
                <IPRange>
                    <CIDRIP>{{ ip_range }}</CIDRIP>
                    <Status>authorized</Status>
                </IPRange>
            {% endfor %}
            </IPRanges>
            <OwnerId>{{ security_group.owner_id }}</OwnerId>
            <DBSecurityGroupName>{{ security_group.group_name }}</DBSecurityGroupName>
        </DBSecurityGroup>"""
        )
        return template.render(security_group=self)

    def authorize_cidr(self, cidr_ip):
        """Allow ingress from the CIDR range *cidr_ip*."""
        self.ip_ranges.append(cidr_ip)

    def authorize_security_group(self, security_group):
        """Allow ingress from the EC2 security group *security_group*."""
        self.ec2_security_groups.append(security_group)

    @staticmethod
    def cloudformation_name_type():
        # No CloudFormation property names this resource.
        return None

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbsecuritygroup.html
        return "AWS::RDS::DBSecurityGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create a DB security group from a CloudFormation resource definition.

        The group name is the lower-cased logical resource name. Each entry in
        ``DBSecurityGroupIngress`` may contain ``CIDRIP``,
        ``EC2SecurityGroupName`` and/or ``EC2SecurityGroupId`` keys. Each
        recognised key authorizes the corresponding source; unknown keys are
        ignored.

        Raises:
            KeyError: if ``GroupDescription`` is missing from ``Properties``.
        """
        properties = cloudformation_json["Properties"]
        group_name = resource_name.lower()
        description = properties["GroupDescription"]
        security_group_ingress_rules = properties.get("DBSecurityGroupIngress", [])
        tags = properties.get("Tags")

        ec2_backend = ec2_backends[region_name]
        rds_backend = rds_backends[region_name]
        security_group = rds_backend.create_security_group(
            group_name, description, tags
        )

        for security_group_ingress in security_group_ingress_rules:
            for ingress_type, ingress_value in security_group_ingress.items():
                if ingress_type == "CIDRIP":
                    security_group.authorize_cidr(ingress_value)
                elif ingress_type == "EC2SecurityGroupName":
                    ec2_group = ec2_backend.get_security_group_from_name(ingress_value)
                    security_group.authorize_security_group(ec2_group)
                elif ingress_type == "EC2SecurityGroupId":
                    ec2_group = ec2_backend.get_security_group_from_id(ingress_value)
                    security_group.authorize_security_group(ec2_group)
        return security_group

    def delete(self, region_name):
        """Delete this group from the RDS backend of *region_name*."""
        backend = rds_backends[region_name]
        backend.delete_security_group(self.group_name)
256
257
class SubnetGroup(CloudFormationModel):
    """An RDS DB subnet group (``AWS::RDS::DBSubnetGroup``).

    ``vpc_id`` is read from the first subnet, so an empty *subnets* sequence
    raises ``IndexError`` in the constructor.
    """

    def __init__(self, subnet_name, description, subnets):
        self.subnet_name = subnet_name
        self.description = description
        self.subnets = subnets
        self.status = "Complete"

        # Only the first subnet is consulted; presumably all subnets in the
        # group share one VPC (not validated here).
        first_subnet = subnets[0]
        self.vpc_id = first_subnet.vpc_id

    def to_xml(self):
        """Render this subnet group as a ``<DBSubnetGroup>`` XML fragment."""
        return Template(
            """<DBSubnetGroup>
              <VpcId>{{ subnet_group.vpc_id }}</VpcId>
              <SubnetGroupStatus>{{ subnet_group.status }}</SubnetGroupStatus>
              <DBSubnetGroupDescription>{{ subnet_group.description }}</DBSubnetGroupDescription>
              <DBSubnetGroupName>{{ subnet_group.subnet_name }}</DBSubnetGroupName>
              <Subnets>
                {% for subnet in subnet_group.subnets %}
                <Subnet>
                  <SubnetStatus>Active</SubnetStatus>
                  <SubnetIdentifier>{{ subnet.id }}</SubnetIdentifier>
                  <SubnetAvailabilityZone>
                    <Name>{{ subnet.availability_zone }}</Name>
                    <ProvisionedIopsCapable>false</ProvisionedIopsCapable>
                  </SubnetAvailabilityZone>
                </Subnet>
                {% endfor %}
              </Subnets>
            </DBSubnetGroup>"""
        ).render(subnet_group=self)

    @staticmethod
    def cloudformation_name_type():
        # CloudFormation property that carries this resource's name.
        return "DBSubnetGroupName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-rds-dbsubnetgroup.html
        return "AWS::RDS::DBSubnetGroup"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create a DB subnet group from a CloudFormation resource definition.

        The group name is the lower-cased logical resource name. Each ID in
        ``SubnetIds`` is resolved through the EC2 backend of *region_name*.

        Raises:
            KeyError: if ``DBSubnetGroupDescription`` or ``SubnetIds`` is
                missing from ``Properties``.
        """
        props = cloudformation_json["Properties"]
        group_name = resource_name.lower()
        group_description = props["DBSubnetGroupDescription"]
        requested_ids = props["SubnetIds"]
        group_tags = props.get("Tags")

        ec2_backend = ec2_backends[region_name]
        resolved_subnets = list(map(ec2_backend.get_subnet, requested_ids))
        return rds_backends[region_name].create_subnet_group(
            group_name, group_description, resolved_subnets, group_tags
        )

    def delete(self, region_name):
        """Delete this subnet group from the RDS backend of *region_name*."""
        rds_backends[region_name].delete_subnet_group(self.subnet_name)
320
321
class RDSBackend(BaseBackend):
    """Thin RDS (v1) backend that forwards to the rds2 backend of its region.

    Only ``region`` is stored on the instance. Any other attribute that normal
    lookup does not find is resolved on ``rds2_backends[self.region]``.
    """

    def __init__(self, region):
        self.region = region

    def __getattr__(self, attr):
        """Delegate a missing attribute to the rds2 backend.

        Raises:
            AttributeError: if ``region`` itself is missing, or if the rds2
                backend lacks *attr*.
        """
        # __getattr__ only runs when normal lookup fails. If ``region`` is not
        # set (e.g. on an object built via __new__ by copy/pickle before its
        # state is restored), delegating would call rds2_backend(), which reads
        # self.region and re-enters __getattr__ forever (RecursionError).
        # Raising AttributeError keeps hasattr()/copy/pickle working.
        if attr == "region":
            raise AttributeError(attr)
        return getattr(self.rds2_backend(), attr)

    def reset(self):
        """Reset the delegated rds2 backend, then re-initialise this wrapper."""
        # Save the region first: clearing __dict__ below discards it.
        region = self.region
        self.rds2_backend().reset()
        self.__dict__ = {}
        self.__init__(region)

    def rds2_backend(self):
        """Return the rds2 backend for this wrapper's region."""
        return rds2_backends[self.region]
338
339
# One RDS backend per region. boto3 lists regions one partition at a time,
# so collect them from each partition in turn ("aws" is boto3's default).
rds_backends = {}
for partition_name in ("aws", "aws-us-gov", "aws-cn"):
    partition_regions = Session().get_available_regions(
        "rds", partition_name=partition_name
    )
    for region in partition_regions:
        rds_backends[region] = RDSBackend(region)
347