1"""EMRContainersBackend class with methods for supported APIs."""
2import re
3from datetime import datetime
4from boto3 import Session
5
6from moto.core import BaseBackend, BaseModel, ACCOUNT_ID
7from moto.core.utils import iso_8601_datetime_without_milliseconds
8
9from .utils import random_cluster_id, random_job_id, get_partition, paginated_list
10from .exceptions import ResourceNotFoundException
11
12from ..config.exceptions import ValidationException
13
# ARN layout for a virtual cluster. The account id is baked in here; the
# {partition}, {region} and {virtual_cluster_id} placeholders are filled in
# later through str.format().
VIRTUAL_CLUSTER_ARN_TEMPLATE = "".join(
    (
        "arn:{partition}:emr-containers:{region}:",
        str(ACCOUNT_ID),
        ":/virtualclusters/{virtual_cluster_id}",
    )
)

# ARN layout for a job run; same placeholders as above plus {job_id}.
JOB_ARN_TEMPLATE = "".join(
    (
        "arn:{partition}:emr-containers:{region}:",
        str(ACCOUNT_ID),
        ":/virtualclusters/{virtual_cluster_id}/jobruns/{job_id}",
    )
)

# Initial states assigned to newly created virtual clusters and job runs.
VIRTUAL_CLUSTER_STATUS = "RUNNING"
JOB_STATUS = "RUNNING"
29
30
class FakeCluster(BaseModel):
    """In-memory model of a virtual cluster.

    The object can be serialized either by iterating over it (yields
    ``(key, value)`` pairs) or by calling :meth:`to_dict`; both expose the
    same keys in the same order.
    """

    # (response key, instance attribute) pairs, in serialization order.
    _RESPONSE_FIELDS = (
        ("id", "id"),
        ("name", "name"),
        ("arn", "arn"),
        ("state", "state"),
        ("containerProvider", "container_provider"),
        ("createdAt", "creation_date"),
        ("tags", "tags"),
    )

    def __init__(
        self,
        name,
        container_provider,
        client_token,
        region_name,
        aws_partition,
        tags=None,
        virtual_cluster_id=None,
    ):
        """Build the cluster record.

        ``container_provider`` must be a mapping with an ``"id"`` entry and
        an ``["info"]["eksInfo"]["namespace"]`` entry; both are copied onto
        the instance. A random id is generated unless ``virtual_cluster_id``
        is given (and truthy).
        """
        if virtual_cluster_id:
            self.id = virtual_cluster_id
        else:
            self.id = random_cluster_id()

        self.name = name
        self.client_token = client_token
        self.arn = VIRTUAL_CLUSTER_ARN_TEMPLATE.format(
            partition=aws_partition,
            region=region_name,
            virtual_cluster_id=self.id,
        )
        self.state = VIRTUAL_CLUSTER_STATUS

        self.container_provider = container_provider
        self.container_provider_id = container_provider["id"]
        eks_info = container_provider["info"]["eksInfo"]
        self.namespace = eks_info["namespace"]

        # The creation timestamp is truncated to midnight of the current day.
        start_of_today = datetime.today().replace(
            hour=0, minute=0, second=0, microsecond=0
        )
        self.creation_date = iso_8601_datetime_without_milliseconds(start_of_today)
        self.tags = tags

    def __iter__(self):
        """Yield ``(response key, current value)`` pairs one at a time."""
        for response_key, attribute in self._RESPONSE_FIELDS:
            yield response_key, getattr(self, attribute)

    def to_dict(self):
        """Return the cluster as a plain dict.

        Format for summary
        https://docs.aws.amazon.com/emr-on-eks/latest/APIReference/API_DescribeVirtualCluster.html
        (response syntax section)
        """
        return {
            response_key: getattr(self, attribute)
            for response_key, attribute in self._RESPONSE_FIELDS
        }
79
80
class FakeJob(BaseModel):
    """In-memory model of a job run belonging to a virtual cluster.

    The object can be serialized either by iterating over it (yields
    ``(key, value)`` pairs) or by calling :meth:`to_dict`.
    """

    def __init__(
        self,
        name,
        virtual_cluster_id,
        client_token,
        execution_role_arn,
        release_label,
        job_driver,
        configuration_overrides,
        region_name,
        aws_partition,
        tags,
    ):
        """Build the job run record with a freshly generated random id.

        The job starts in ``JOB_STATUS``; ``created_by``, ``finished_at``,
        ``state_details`` and ``failure_reason`` start out as ``None``.
        """
        self.id = random_job_id()
        self.name = name
        self.virtual_cluster_id = virtual_cluster_id
        self.arn = JOB_ARN_TEMPLATE.format(
            partition=aws_partition,
            region=region_name,
            virtual_cluster_id=self.virtual_cluster_id,
            job_id=self.id,
        )
        self.state = JOB_STATUS
        self.client_token = client_token
        self.execution_role_arn = execution_role_arn
        self.release_label = release_label
        self.job_driver = job_driver
        self.configuration_overrides = configuration_overrides
        # The creation timestamp is truncated to midnight of the current day.
        self.created_at = iso_8601_datetime_without_milliseconds(
            datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
        )
        self.created_by = None
        self.finished_at = None
        self.state_details = None
        self.failure_reason = None
        self.tags = tags

    def __iter__(self):
        """Yield ``(response key, current value)`` pairs one at a time."""
        yield "id", self.id
        yield "name", self.name
        yield "virtualClusterId", self.virtual_cluster_id
        yield "arn", self.arn
        yield "state", self.state
        yield "clientToken", self.client_token
        yield "executionRoleArn", self.execution_role_arn
        yield "releaseLabel", self.release_label
        # Previously this yielded the release label under this key.
        yield "configurationOverrides", self.configuration_overrides
        yield "jobDriver", self.job_driver
        yield "createdAt", self.created_at
        yield "createdBy", self.created_by
        yield "finishedAt", self.finished_at
        yield "stateDetails", self.state_details
        yield "failureReason", self.failure_reason
        yield "tags", self.tags

    def to_dict(self):
        """Return the job run as a plain dict.

        Format for summary
        https://docs.aws.amazon.com/emr-on-eks/latest/APIReference/API_DescribeJobRun.html
        (response syntax section)
        """
        # Built from __iter__ so both serializations always agree.
        return dict(iter(self))
158
159
class EMRContainersBackend(BaseBackend):
    """Implementation of EMRContainers APIs.

    Virtual clusters and job runs are held in dictionaries keyed by their
    id. Deleting a virtual cluster only flips its state to ``"TERMINATED"``;
    the record itself stays in ``virtual_clusters``.
    """

    # A job run id must begin with 19 alphanumeric characters. Only the start
    # of the id is checked (re.match semantics). The commas that used to sit
    # inside the character class were literal and let "," pass as an id
    # character.
    _JOB_RUN_ID_PATTERN = re.compile(r"[a-zA-Z0-9]{19}")

    # Job runs in one of these states are rejected by cancel_job_run.
    _NOT_CANCELLABLE_STATES = ("FAILED", "CANCELLED", "CANCEL_PENDING", "COMPLETED")

    def __init__(self, region_name=None):
        super().__init__()
        self.virtual_clusters = {}
        self.virtual_cluster_count = 0
        self.jobs = {}
        self.job_count = 0
        self.region_name = region_name
        self.partition = get_partition(region_name)

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_virtual_cluster(self, name, container_provider, client_token, tags=None):
        """Create, store and return a new ``FakeCluster``.

        Raises:
            ValidationException: if a virtual cluster that has not been
                terminated already uses the same namespace on the same
                container provider.
        """
        provider_id = container_provider["id"]
        namespace = container_provider["info"]["eksInfo"]["namespace"]

        # A namespace is only occupied by a live cluster on the same container
        # provider: the same namespace name on another provider is a different
        # namespace, and a terminated cluster no longer holds its namespace.
        namespace_taken = any(
            existing.state != "TERMINATED"
            and existing.container_provider_id == provider_id
            and existing.namespace == namespace
            for existing in self.virtual_clusters.values()
        )
        if namespace_taken:
            raise ValidationException(
                "A virtual cluster already exists in the given namespace"
            )

        virtual_cluster = FakeCluster(
            name=name,
            container_provider=container_provider,
            client_token=client_token,
            tags=tags,
            region_name=self.region_name,
            aws_partition=self.partition,
        )

        self.virtual_clusters[virtual_cluster.id] = virtual_cluster
        self.virtual_cluster_count += 1
        return virtual_cluster

    def delete_virtual_cluster(self, id):
        """Mark the virtual cluster as ``"TERMINATED"`` and return it.

        Raises:
            ValidationException: if no virtual cluster has this id.
        """
        if id not in self.virtual_clusters:
            raise ValidationException("VirtualCluster does not exist")

        virtual_cluster = self.virtual_clusters[id]
        virtual_cluster.state = "TERMINATED"
        return virtual_cluster

    def describe_virtual_cluster(self, id):
        """Return the dict form of the virtual cluster with this id.

        Raises:
            ValidationException: if no virtual cluster has this id.
        """
        if id not in self.virtual_clusters:
            raise ValidationException(f"Virtual cluster {id} doesn't exist.")

        return self.virtual_clusters[id].to_dict()

    def list_virtual_clusters(
        self,
        container_provider_id,
        container_provider_type,
        created_after,
        created_before,
        states,
        max_results,
        next_token,
    ):
        """List virtual clusters (as dicts) matching the given filters.

        A filter whose value is falsy (``None``, empty) is skipped. The
        filtered list is handed to ``paginated_list`` with ``"name"`` as the
        sort key, and that helper's result is returned unchanged.
        NOTE(review): the exact return shape is defined by ``paginated_list``
        in ``.utils`` -- confirm there.
        """
        virtual_clusters = [
            virtual_cluster.to_dict()
            for virtual_cluster in self.virtual_clusters.values()
        ]

        if container_provider_id:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["containerProvider"]["id"] == container_provider_id
            ]

        if container_provider_type:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["containerProvider"]["type"]
                == container_provider_type
            ]

        # createdAt is compared directly against the bounds (both inclusive).
        if created_after:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["createdAt"] >= created_after
            ]

        if created_before:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["createdAt"] <= created_before
            ]

        if states:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["state"] in states
            ]

        return paginated_list(virtual_clusters, "name", max_results, next_token)

    def start_job_run(
        self,
        name,
        virtual_cluster_id,
        client_token,
        execution_role_arn,
        release_label,
        job_driver,
        configuration_overrides,
        tags,
    ):
        """Create, store and return a new ``FakeJob``.

        Raises:
            ResourceNotFoundException: if ``virtual_cluster_id`` is unknown,
                or if ``release_label`` does not start with something of the
                form ``emr-<d>.<d[d]>.0-latest`` / ``emr-<d>.<d[d]>.0-<8 digits>``.
        """
        if virtual_cluster_id not in self.virtual_clusters:
            raise ResourceNotFoundException(
                f"Virtual cluster {virtual_cluster_id} doesn't exist."
            )

        if not re.match(
            r"emr-[0-9]{1}\.[0-9]{1,2}\.0-(latest|[0-9]{8})", release_label
        ):
            raise ResourceNotFoundException(f"Release {release_label} doesn't exist.")

        job = FakeJob(
            name=name,
            virtual_cluster_id=virtual_cluster_id,
            client_token=client_token,
            execution_role_arn=execution_role_arn,
            release_label=release_label,
            job_driver=job_driver,
            configuration_overrides=configuration_overrides,
            tags=tags,
            region_name=self.region_name,
            aws_partition=self.partition,
        )

        self.jobs[job.id] = job
        self.job_count += 1
        return job

    def _get_job_run(self, job_id, virtual_cluster_id):
        """Validate ``job_id`` and return the matching job of that cluster.

        Raises:
            ValidationException: if ``job_id`` does not look like a job run id.
            ResourceNotFoundException: if the job is unknown or belongs to a
                different virtual cluster.
        """
        if not self._JOB_RUN_ID_PATTERN.match(job_id):
            raise ValidationException("Invalid job run short id")

        job = self.jobs.get(job_id)
        if job is None or virtual_cluster_id != job.virtual_cluster_id:
            raise ResourceNotFoundException(f"Job run {job_id} doesn't exist.")

        return job

    def cancel_job_run(self, id, virtual_cluster_id):
        """Move the job run to ``"CANCELLED"`` and return it.

        Raises:
            ValidationException: if the id is malformed or the job is in a
                state that can no longer be cancelled.
            ResourceNotFoundException: if the job is unknown or belongs to a
                different virtual cluster.
        """
        job = self._get_job_run(id, virtual_cluster_id)

        if job.state in self._NOT_CANCELLABLE_STATES:
            raise ValidationException(f"Job run {id} is not in a cancellable state")

        job.state = "CANCELLED"
        # The finish timestamp is pinned to 00:01 of the current day.
        job.finished_at = iso_8601_datetime_without_milliseconds(
            datetime.today().replace(hour=0, minute=1, second=0, microsecond=0)
        )
        job.state_details = "JobRun CANCELLED successfully."

        return job

    def list_job_runs(
        self,
        virtual_cluster_id,
        created_before,
        created_after,
        name,
        states,
        max_results,
        next_token,
    ):
        """List job runs (as dicts) of one virtual cluster.

        ``created_after``, ``created_before``, ``states`` and ``name`` are
        optional filters and are skipped when falsy. ``name`` may be a single
        job name (exact match) or a collection of names. The filtered list is
        handed to ``paginated_list`` with ``"id"`` as the sort key, and that
        helper's result is returned unchanged.
        NOTE(review): the exact return shape is defined by ``paginated_list``
        in ``.utils`` -- confirm there.
        """
        jobs = [
            job
            for job in (stored.to_dict() for stored in self.jobs.values())
            if job["virtualClusterId"] == virtual_cluster_id
        ]

        if created_after:
            jobs = [job for job in jobs if job["createdAt"] >= created_after]

        if created_before:
            jobs = [job for job in jobs if job["createdAt"] <= created_before]

        if states:
            jobs = [job for job in jobs if job["state"] in states]

        if name:
            # `x in "some string"` is a substring test, which would let a job
            # called "job" match the filter "my-job-2". Wrap a single name so
            # that the membership test below is an exact comparison.
            wanted_names = [name] if isinstance(name, str) else name
            jobs = [job for job in jobs if job["name"] in wanted_names]

        return paginated_list(jobs, "id", max_results, next_token)

    def describe_job_run(self, id, virtual_cluster_id):
        """Return the dict form of a job run of the given virtual cluster.

        Raises:
            ValidationException: if the id is malformed.
            ResourceNotFoundException: if the job is unknown or belongs to a
                different virtual cluster.
        """
        return self._get_job_run(id, virtual_cluster_id).to_dict()
375
376
# One backend per region, keyed by region name. Regions are collected first
# for boto3's default partition, then for aws-us-gov, then for aws-cn.
emrcontainers_backends = {}
for _partition_kwargs in (
    {},
    {"partition_name": "aws-us-gov"},
    {"partition_name": "aws-cn"},
):
    for available_region in Session().get_available_regions(
        "emr-containers", **_partition_kwargs
    ):
        emrcontainers_backends[available_region] = EMRContainersBackend(
            available_region
        )
388