import re
from itertools import cycle
import datetime
import time
import uuid
import logging
import docker
import threading
import dateutil.parser
from boto3 import Session

from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.iam import iam_backends
from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
from moto.logs import logs_backends

from .exceptions import InvalidParameterValueException, ClientException, ValidationError
from .utils import (
    make_arn_for_compute_env,
    make_arn_for_job_queue,
    make_arn_for_task_def,
    lowercase_first_key,
)
from moto.ec2.exceptions import InvalidSubnetIdError
from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES
from moto.iam.exceptions import IAMNotFoundException
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID
from moto.core.utils import unix_time_millis
from moto.utilities.docker_utilities import DockerModel, parse_image_ref
from moto.utilities.tagging_service import TaggingService

logger = logging.getLogger(__name__)
COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(
    r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$"
)
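
# Illustrative matches for the name regex above (a sketch, not original code):
# the pattern requires 3-128 characters, alphanumeric first and last.
#
#   COMPUTE_ENVIRONMENT_NAME_REGEX.match("my-env_1")    # matches
#   COMPUTE_ENVIRONMENT_NAME_REGEX.match("-bad-start")  # None: must start alphanumeric
#   COMPUTE_ENVIRONMENT_NAME_REGEX.match("ab")          # None: too short (minimum 3 chars)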


def datetime2int_milliseconds(date):
    """
    AWS returns timestamps in milliseconds.
    We don't use millisecond timestamps internally,
    so this helper should only be used in describe() methods.
    """
    return int(date.timestamp() * 1000)


def datetime2int(date):
    return int(time.mktime(date.timetuple()))

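# Illustrative behaviour of the two helpers above (note both interpret naive
# datetimes in the host's local timezone, so the exact values are host-dependent):
#
#   dt = datetime.datetime(2021, 1, 1)
#   datetime2int_milliseconds(dt)   # 1609459200000 on a UTC host
#   datetime2int(dt)                # 1609459200 on a UTC host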

class ComputeEnvironment(CloudFormationModel):
    def __init__(
        self,
        compute_environment_name,
        _type,
        state,
        compute_resources,
        service_role,
        region_name,
    ):
        self.name = compute_environment_name
        self.env_type = _type
        self.state = state
        self.compute_resources = compute_resources
        self.service_role = service_role
        self.arn = make_arn_for_compute_env(
            DEFAULT_ACCOUNT_ID, compute_environment_name, region_name
        )

        self.instances = []
        self.ecs_arn = None
        self.ecs_name = None

    def add_instance(self, instance):
        self.instances.append(instance)

    def set_ecs(self, arn, name):
        self.ecs_arn = arn
        self.ecs_name = name

    @property
    def physical_resource_id(self):
        return self.arn

    @staticmethod
    def cloudformation_name_type():
        return "ComputeEnvironmentName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html
        return "AWS::Batch::ComputeEnvironment"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        backend = batch_backends[region_name]
        properties = cloudformation_json["Properties"]

        env = backend.create_compute_environment(
            resource_name,
            properties["Type"],
            properties.get("State", "ENABLED"),
            lowercase_first_key(properties["ComputeResources"]),
            properties["ServiceRole"],
        )
        arn = env[1]

        return backend.get_compute_environment_by_arn(arn)


class JobQueue(CloudFormationModel):
    def __init__(
        self, name, priority, state, environments, env_order_json, region_name
    ):
        """
        :param name: Job queue name
        :type name: str
        :param priority: Job queue priority
        :type priority: int
        :param state: Either ENABLED or DISABLED
        :type state: str
        :param environments: Compute Environments
        :type environments: list of ComputeEnvironment
        :param env_order_json: Compute Environments JSON for use when describing
        :type env_order_json: list of dict
        :param region_name: Region name
        :type region_name: str
        """
        self.name = name
        self.priority = priority
        self.state = state
        self.environments = environments
        self.env_order_json = env_order_json
        self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name)
        self.status = "VALID"

        self.jobs = []

    def describe(self):
        result = {
            "computeEnvironmentOrder": self.env_order_json,
            "jobQueueArn": self.arn,
            "jobQueueName": self.name,
            "priority": self.priority,
            "state": self.state,
            "status": self.status,
        }

        return result

    @property
    def physical_resource_id(self):
        return self.arn

    @staticmethod
    def cloudformation_name_type():
        return "JobQueueName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html
        return "AWS::Batch::JobQueue"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        backend = batch_backends[region_name]
        properties = cloudformation_json["Properties"]

        # Need to handle the case difference between CloudFormation and the
        # API for compute resources, e.g. InstanceRole vs instanceRole.
        # Hacky fix to normalise keys; makes me think I want to start spamming
        # cAsEiNsEnSiTiVe dictionaries
        compute_envs = [
            lowercase_first_key(dict_item)
            for dict_item in properties["ComputeEnvironmentOrder"]
        ]

        queue = backend.create_job_queue(
            queue_name=resource_name,
            priority=properties["Priority"],
            state=properties.get("State", "ENABLED"),
            compute_env_order=compute_envs,
        )
        arn = queue[1]

        return backend.get_job_queue_by_arn(arn)


class JobDefinition(CloudFormationModel):
    def __init__(
        self,
        name,
        parameters,
        _type,
        container_properties,
        region_name,
        tags=None,
        revision=0,
        retry_strategy=0,
    ):
        self.name = name
        self.retries = retry_strategy
        self.type = _type
        self.revision = revision
        self._region = region_name
        self.container_properties = container_properties
        self.arn = None
        self.status = "ACTIVE"
        self.tagger = TaggingService()
        if parameters is None:
            parameters = {}
        self.parameters = parameters

        self._validate()
        self._update_arn()

        # Avoid a mutable default argument and normalise to the tagger's list format
        tags = self._format_tags(tags or {})
        # Validate the tags before proceeding.
        errmsg = self.tagger.validate_tags(tags or [])
        if errmsg:
            raise ValidationError(errmsg)

        self.tagger.tag_resource(self.arn, tags or [])

    def _format_tags(self, tags):
        return [{"Key": k, "Value": v} for k, v in tags.items()]

    def _update_arn(self):
        self.revision += 1
        self.arn = make_arn_for_task_def(
            DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region
        )

    def _get_resource_requirement(self, req_type, default=None):
        """
        Get resource requirement from container properties.

        Resource requirements like "memory" and "vcpus" are now specified in
        "resourceRequirements". This function retrieves a resource requirement
        from either container_properties.resourceRequirements (preferred) or
        directly from container_properties (deprecated).

        :param req_type: The type of resource requirement to retrieve.
        :type req_type: ["gpu", "memory", "vcpus"]

        :param default: The default value to return if the resource requirement is not found.
        :type default: any, default=None

        :return: The value of the resource requirement, or None.
        :rtype: any
        """
        resource_reqs = self.container_properties.get("resourceRequirements", [])

        # Filter the resource requirements by the specified type.
        # Note that VCPUS are specified in resourceRequirements without the
        # trailing "s", so we strip that off in the comparison below.
        required_resource = list(
            filter(
                lambda req: req["type"].lower() == req_type.lower().rstrip("s"),
                resource_reqs,
            )
        )

        if required_resource:
            return required_resource[0]["value"]
        else:
            return self.container_properties.get(req_type, default)
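
    # Illustrative lookup (a sketch, not part of the original code): given
    # container properties such as
    #
    #   {"image": "busybox",
    #    "resourceRequirements": [{"type": "MEMORY", "value": "512"},
    #                             {"type": "VCPU", "value": "2"}]}
    #
    # _get_resource_requirement("memory") returns "512" and
    # _get_resource_requirement("vcpus") returns "2" (the trailing "s" is
    # stripped to match the "VCPU" type). With the deprecated flat layout
    # {"image": "busybox", "memory": 512, "vcpus": 2} the same calls fall
    # through to container_properties directly.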

    def _validate(self):
        if self.type not in ("container",):
            raise ClientException('type must be one of "container"')

        # For future use, when containers aren't the only thing in Batch
        if self.type != "container":
            raise NotImplementedError()

        if not isinstance(self.parameters, dict):
            raise ClientException("parameters must be a string to string map")

        if "image" not in self.container_properties:
            raise ClientException("containerProperties must contain image")

        memory = self._get_resource_requirement("memory")
        if memory is None:
            raise ClientException("containerProperties must contain memory")
        if memory < 4:
            raise ClientException("container memory limit must be greater than 4")

        vcpus = self._get_resource_requirement("vcpus")
        if vcpus is None:
            raise ClientException("containerProperties must contain vcpus")
        if vcpus < 1:
            raise ClientException("container vcpus limit must be greater than 0")

    def update(self, parameters, _type, container_properties, retry_strategy):
        if parameters is None:
            parameters = self.parameters

        if _type is None:
            _type = self.type

        if container_properties is None:
            container_properties = self.container_properties

        if retry_strategy is None:
            retry_strategy = self.retries

        return JobDefinition(
            self.name,
            parameters,
            _type,
            container_properties,
            region_name=self._region,
            revision=self.revision,
            retry_strategy=retry_strategy,
        )

    def describe(self):
        result = {
            "jobDefinitionArn": self.arn,
            "jobDefinitionName": self.name,
            "parameters": self.parameters,
            "revision": self.revision,
            "status": self.status,
            "type": self.type,
            "tags": self.tagger.get_tag_dict_for_resource(self.arn),
        }
        if self.container_properties is not None:
            result["containerProperties"] = self.container_properties
        if self.retries is not None and self.retries > 0:
            result["retryStrategy"] = {"attempts": self.retries}

        return result

    @property
    def physical_resource_id(self):
        return self.arn

    @staticmethod
    def cloudformation_name_type():
        return "JobDefinitionName"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html
        return "AWS::Batch::JobDefinition"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        backend = batch_backends[region_name]
        properties = cloudformation_json["Properties"]
        res = backend.register_job_definition(
            def_name=resource_name,
            parameters=lowercase_first_key(properties.get("Parameters", {})),
            _type="container",
            tags=lowercase_first_key(properties.get("Tags", {})),
            retry_strategy=lowercase_first_key(properties["RetryStrategy"]),
            container_properties=lowercase_first_key(properties["ContainerProperties"]),
        )
        arn = res[1]

        return backend.get_job_definition_by_arn(arn)


class Job(threading.Thread, BaseModel, DockerModel):
    def __init__(
        self,
        name,
        job_def,
        job_queue,
        log_backend,
        container_overrides,
        depends_on,
        all_jobs,
    ):
        """
        Docker Job

        :param name: Job Name
        :param job_def: Job definition
        :type job_def: JobDefinition
        :param job_queue: Job Queue
        :param log_backend: Log backend
        :type log_backend: moto.logs.models.LogsBackend
        """
        threading.Thread.__init__(self)
        DockerModel.__init__(self)

        self.job_name = name
        self.job_id = str(uuid.uuid4())
        self.job_definition = job_def
        self.container_overrides = container_overrides or {}
        self.job_queue = job_queue
        self.job_state = "SUBMITTED"  # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED
        self.job_queue.jobs.append(self)
        self.job_created_at = datetime.datetime.now()
        self.job_started_at = datetime.datetime(1970, 1, 1)
        self.job_stopped_at = datetime.datetime(1970, 1, 1)
        self.job_stopped = False
        self.job_stopped_reason = None
        self.depends_on = depends_on
        self.all_jobs = all_jobs

        self.stop = False

        self.daemon = True
        self.name = "MOTO-BATCH-" + self.job_id

        self._log_backend = log_backend
        self.log_stream_name = None

    def describe(self):
        result = {
            "jobDefinition": self.job_definition.arn,
            "jobId": self.job_id,
            "jobName": self.job_name,
            "jobQueue": self.job_queue.arn,
            "status": self.job_state,
            "dependsOn": self.depends_on if self.depends_on else [],
            "createdAt": datetime2int_milliseconds(self.job_created_at),
        }
        if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]:
            result["startedAt"] = datetime2int_milliseconds(self.job_started_at)
        if self.job_stopped:
            result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at)
            result["container"] = {}
            result["container"]["command"] = self._get_container_property("command", [])
            result["container"]["privileged"] = self._get_container_property(
                "privileged", False
            )
            result["container"][
                "readonlyRootFilesystem"
            ] = self._get_container_property("readonlyRootFilesystem", False)
            result["container"]["ulimits"] = self._get_container_property("ulimits", {})
            result["container"]["vcpus"] = self._get_container_property("vcpus", 1)
            result["container"]["memory"] = self._get_container_property("memory", 512)
            result["container"]["volumes"] = self._get_container_property("volumes", [])
            result["container"]["environment"] = self._get_container_property(
                "environment", []
            )
            result["container"]["logStreamName"] = self.log_stream_name
        if self.job_stopped_reason is not None:
            result["statusReason"] = self.job_stopped_reason
        return result
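
    # The payload for a stopped job looks roughly like this (illustrative
    # sketch with elided values, not original code):
    #
    #   {"jobId": "<uuid>", "jobName": "sleep10", "status": "SUCCEEDED",
    #    "jobQueue": "arn:aws:batch:...", "jobDefinition": "arn:aws:batch:...",
    #    "dependsOn": [], "createdAt": 1609459200000,
    #    "startedAt": 1609459201000, "stoppedAt": 1609459211000,
    #    "container": {"command": [...], "vcpus": 1, "memory": 512,
    #                  "logStreamName": "<job-def-name>/default/<job-id>"}}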

    def _get_container_property(self, p, default):
        if p == "environment":
            # Copy the override list so repeated describe() calls don't keep
            # appending to the stored container_overrides
            job_env = list(self.container_overrides.get(p, default))
            jd_env = self.job_definition.container_properties.get(p, default)

            job_env_dict = {_env["name"]: _env["value"] for _env in job_env}
            jd_env_dict = {_env["name"]: _env["value"] for _env in jd_env}

            for key in jd_env_dict.keys():
                if key not in job_env_dict.keys():
                    job_env.append({"name": key, "value": jd_env_dict[key]})

            job_env.append({"name": "AWS_BATCH_JOB_ID", "value": self.job_id})

            return job_env

        if p in ["vcpus", "memory"]:
            return self.container_overrides.get(
                p, self.job_definition._get_resource_requirement(p, default)
            )

        return self.container_overrides.get(
            p, self.job_definition.container_properties.get(p, default)
        )

    def run(self):
        """
        Run the container.

        Logic is as follows:
        Generate container info (eventually from task definition)
        Start container
        Loop whilst not asked to stop and the container is running
        Fetch all logs the container produced since it started
        Convert logs into cloudwatch format
        Put logs into cloudwatch

        :return:
        """
        try:
            self.job_state = "PENDING"

            if self.depends_on and not self._wait_for_dependencies():
                return

            image = self.job_definition.container_properties.get(
                "image", "alpine:latest"
            )
            privileged = self.job_definition.container_properties.get(
                "privileged", False
            )
            cmd = self._get_container_property(
                "command",
                '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"',
            )
            environment = {
                e["name"]: e["value"]
                for e in self._get_container_property("environment", [])
            }
            volumes = {
                v["name"]: v["host"]
                for v in self._get_container_property("volumes", [])
            }
            mounts = [
                docker.types.Mount(
                    m["containerPath"],
                    volumes[m["sourceVolume"]]["sourcePath"],
                    type="bind",
                    read_only=m["readOnly"],
                )
                for m in self._get_container_property("mountPoints", [])
            ]
            name = "{0}-{1}".format(self.job_name, self.job_id)

            self.job_state = "RUNNABLE"
            # TODO setup ecs container instance

            self.job_started_at = datetime.datetime.now()

            log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
            image_repository, image_tag = parse_image_ref(image)
            # avoid explicit pulling here, to allow using cached images
            # self.docker_client.images.pull(image_repository, image_tag)
            self.job_state = "STARTING"
            container = self.docker_client.containers.run(
                image,
                cmd,
                detach=True,
                name=name,
                log_config=log_config,
                environment=environment,
                mounts=mounts,
                privileged=privileged,
            )
            self.job_state = "RUNNING"
            try:
                container.reload()
                while container.status == "running" and not self.stop:
                    container.reload()
                    time.sleep(0.5)

                # Container should be stopped by this point... unless asked to stop
                if container.status == "running":
                    container.kill()

                # Log collection
                logs_stdout = []
                logs_stderr = []
                logs_stderr.extend(
                    container.logs(
                        stdout=False,
                        stderr=True,
                        timestamps=True,
                        since=datetime2int(self.job_started_at),
                    )
                    .decode()
                    .split("\n")
                )
                logs_stdout.extend(
                    container.logs(
                        stdout=True,
                        stderr=False,
                        timestamps=True,
                        since=datetime2int(self.job_started_at),
                    )
                    .decode()
                    .split("\n")
                )

                # Process logs
                logs_stdout = [x for x in logs_stdout if len(x) > 0]
                logs_stderr = [x for x in logs_stderr if len(x) > 0]
                logs = []
                for line in logs_stdout + logs_stderr:
                    date, line = line.split(" ", 1)
                    date_obj = (
                        dateutil.parser.parse(date)
                        .astimezone(datetime.timezone.utc)
                        .replace(tzinfo=None)
                    )
                    date = unix_time_millis(date_obj)
                    logs.append({"timestamp": date, "message": line.strip()})

                # Send to cloudwatch
                log_group = "/aws/batch/job"
                stream_name = "{0}/default/{1}".format(
                    self.job_definition.name, self.job_id
                )
                self.log_stream_name = stream_name
                self._log_backend.ensure_log_group(log_group, None)
                self._log_backend.create_log_stream(log_group, stream_name)
                self._log_backend.put_log_events(log_group, stream_name, logs, None)

                result = container.wait() or {}
                job_failed = self.stop or result.get("StatusCode", 0) > 0
                self._mark_stopped(success=not job_failed)

            except Exception as err:
                logger.error(
                    "Failed to run AWS Batch container {0}. Error {1}".format(
                        self.name, err
                    )
                )
                self._mark_stopped(success=False)
                container.kill()
            finally:
                container.remove()
        except Exception as err:
            logger.error(
                "Failed to run AWS Batch container {0}. Error {1}".format(
                    self.name, err
                )
            )
            self._mark_stopped(success=False)

    def _mark_stopped(self, success=True):
        # Ensure that the job_stopped/job_stopped_at attributes are set first,
        # as the describe() method needs them as soon as job_state is set
        self.job_stopped = True
        self.job_stopped_at = datetime.datetime.now()
        self.job_state = "SUCCEEDED" if success else "FAILED"

    def terminate(self, reason):
        if not self.stop:
            self.stop = True
            self.job_stopped_reason = reason

    def _wait_for_dependencies(self):
        dependent_ids = [dependency["jobId"] for dependency in self.depends_on]
        successful_dependencies = set()
        while len(successful_dependencies) != len(dependent_ids):
            for dependent_id in dependent_ids:
                if dependent_id in self.all_jobs:
                    dependent_job = self.all_jobs[dependent_id]
                    if dependent_job.job_state == "SUCCEEDED":
                        successful_dependencies.add(dependent_id)
                    if dependent_job.job_state == "FAILED":
                        logger.error(
                            "Terminating job {0} due to failed dependency {1}".format(
                                self.name, dependent_job.name
                            )
                        )
                        self._mark_stopped(success=False)
                        return False

            time.sleep(1)

        return True
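
    # depends_on, when present, is the raw submitJob "dependsOn" list, e.g.
    # (illustrative): [{"jobId": "11f0ab31-..."}]. _wait_for_dependencies()
    # polls those jobs once per second and only lets run() proceed once every
    # one of them has SUCCEEDED; a single FAILED dependency fails this job.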


class BatchBackend(BaseBackend):
    def __init__(self, region_name=None):
        super(BatchBackend, self).__init__()
        self.region_name = region_name

        self._compute_environments = {}
        self._job_queues = {}
        self._job_definitions = {}
        self._jobs = {}

    @property
    def iam_backend(self):
        """
        :return: IAM Backend
        :rtype: moto.iam.models.IAMBackend
        """
        return iam_backends["global"]

    @property
    def ec2_backend(self):
        """
        :return: EC2 Backend
        :rtype: moto.ec2.models.EC2Backend
        """
        return ec2_backends[self.region_name]

    @property
    def ecs_backend(self):
        """
        :return: ECS Backend
        :rtype: moto.ecs.models.EC2ContainerServiceBackend
        """
        return ecs_backends[self.region_name]

    @property
    def logs_backend(self):
        """
        :return: Logs Backend
        :rtype: moto.logs.models.LogsBackend
        """
        return logs_backends[self.region_name]

    def reset(self):
        region_name = self.region_name

        for job in self._jobs.values():
            if job.job_state not in ("FAILED", "SUCCEEDED"):
                job.stop = True
                # Try to join
                job.join(0.2)

        self.__dict__ = {}
        self.__init__(region_name)

    def get_compute_environment_by_arn(self, arn):
        return self._compute_environments.get(arn)

    def get_compute_environment_by_name(self, name):
        for comp_env in self._compute_environments.values():
            if comp_env.name == name:
                return comp_env
        return None

    def get_compute_environment(self, identifier):
        """
        Get compute environment by name or ARN
        :param identifier: Name or ARN
        :type identifier: str

        :return: Compute Environment or None
        :rtype: ComputeEnvironment or None
        """
        env = self.get_compute_environment_by_arn(identifier)
        if env is None:
            env = self.get_compute_environment_by_name(identifier)
        return env

    def get_job_queue_by_arn(self, arn):
        return self._job_queues.get(arn)

    def get_job_queue_by_name(self, name):
        for job_queue in self._job_queues.values():
            if job_queue.name == name:
                return job_queue
        return None

    def get_job_queue(self, identifier):
        """
        Get job queue by name or ARN
        :param identifier: Name or ARN
        :type identifier: str

        :return: Job Queue or None
        :rtype: JobQueue or None
        """
        queue = self.get_job_queue_by_arn(identifier)
        if queue is None:
            queue = self.get_job_queue_by_name(identifier)
        return queue

    def get_job_definition_by_arn(self, arn):
        return self._job_definitions.get(arn)

    def get_job_definition_by_name(self, name):
        latest_revision = -1
        latest_job = None
        for job_def in self._job_definitions.values():
            if job_def.name == name and job_def.revision > latest_revision:
                latest_job = job_def
                latest_revision = job_def.revision
        return latest_job

    def get_job_definition_by_name_revision(self, name, revision):
        for job_def in self._job_definitions.values():
            if job_def.name == name and job_def.revision == int(revision):
                return job_def
        return None

    def get_job_definition(self, identifier):
        """
        Get a job definition by name or ARN
        :param identifier: Name or ARN
        :type identifier: str

        :return: Job definition or None
        :rtype: JobDefinition or None
        """
        job_def = self.get_job_definition_by_arn(identifier)
        if job_def is None:
            if ":" in identifier:
                job_def = self.get_job_definition_by_name_revision(
                    *identifier.split(":", 1)
                )
            else:
                job_def = self.get_job_definition_by_name(identifier)
        return job_def
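
    # get_job_definition() accepts any of these identifier forms
    # (illustrative examples, using moto's default account id):
    #
    #   arn:aws:batch:us-east-1:123456789012:job-definition/sleep10:1   full ARN
    #   sleep10:1                                                       name:revision
    #   sleep10                                                         name (latest revision)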

    def get_job_definitions(self, identifier):
        """
        Get job definitions by name or ARN
        :param identifier: Name or ARN
        :type identifier: str

        :return: Matching job definitions (possibly empty)
        :rtype: list of JobDefinition
        """
        result = []
        job_def = self.get_job_definition_by_arn(identifier)
        if job_def is not None:
            result.append(job_def)
        else:
            for value in self._job_definitions.values():
                if value.name == identifier:
                    result.append(value)

        return result

    def get_job_by_id(self, identifier):
        """
        Get job by id
        :param identifier: Job ID
        :type identifier: str

        :return: Job or None
        :rtype: Job or None
        """
        return self._jobs.get(identifier)

    def describe_compute_environments(
        self, environments=None, max_results=None, next_token=None
    ):
        envs = set()
        if environments is not None:
            envs = set(environments)

        result = []
        for arn, environment in self._compute_environments.items():
            # Filter shortcut
            if len(envs) > 0 and arn not in envs and environment.name not in envs:
                continue

            json_part = {
                "computeEnvironmentArn": arn,
                "computeEnvironmentName": environment.name,
                "ecsClusterArn": environment.ecs_arn,
                "serviceRole": environment.service_role,
                "state": environment.state,
                "type": environment.env_type,
                "status": "VALID",
                "statusReason": "Compute environment is available",
            }
            if environment.env_type == "MANAGED":
                json_part["computeResources"] = environment.compute_resources

            result.append(json_part)

        return result

    def create_compute_environment(
        self, compute_environment_name, _type, state, compute_resources, service_role
    ):
        # Validate
        if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None:
            raise InvalidParameterValueException(
                "Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$"
            )

        if self.get_compute_environment_by_name(compute_environment_name) is not None:
            raise InvalidParameterValueException(
                "A compute environment already exists with the name {0}".format(
                    compute_environment_name
                )
            )

        # Look for IAM role
        try:
            self.iam_backend.get_role_by_arn(service_role)
        except IAMNotFoundException:
            raise InvalidParameterValueException(
                "Could not find IAM role {0}".format(service_role)
            )

        if _type not in ("MANAGED", "UNMANAGED"):
            raise InvalidParameterValueException(
                "type {0} must be one of MANAGED | UNMANAGED".format(_type)
            )

        if state is not None and state not in ("ENABLED", "DISABLED"):
            raise InvalidParameterValueException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )

        if compute_resources is None and _type == "MANAGED":
            raise InvalidParameterValueException(
                "computeResources must be specified when creating a {0} environment".format(
                    _type
                )
            )
        elif compute_resources is not None:
            self._validate_compute_resources(compute_resources)

        # By here, all values except SPOT ones have been validated
        new_comp_env = ComputeEnvironment(
            compute_environment_name,
            _type,
            state,
            compute_resources,
            service_role,
            region_name=self.region_name,
        )
        self._compute_environments[new_comp_env.arn] = new_comp_env

        # By this point everything is legit, so if it's MANAGED (and not Fargate), start some instances
        if _type == "MANAGED" and "FARGATE" not in compute_resources["type"]:
            cpus = int(
                compute_resources.get("desiredvCpus", compute_resources["minvCpus"])
            )
            instance_types = compute_resources["instanceTypes"]
            needed_instance_types = self.find_min_instances_to_meet_vcpus(
                instance_types, cpus
            )
            # Create instances

            # Will loop over and over so we get decent subnet coverage
            subnet_cycle = cycle(compute_resources["subnets"])

            for instance_type in needed_instance_types:
                reservation = self.ec2_backend.add_instances(
                    image_id="ami-03cf127a",  # Todo import AMIs
                    count=1,
                    user_data=None,
                    security_group_names=[],
                    instance_type=instance_type,
                    region_name=self.region_name,
                    subnet_id=next(subnet_cycle),
                    key_name=compute_resources.get("ec2KeyPair", "AWS_OWNED"),
                    security_group_ids=compute_resources["securityGroupIds"],
                )

                new_comp_env.add_instance(reservation.instances[0])

        # Create ECS cluster, named OnDemand_Batch_<UUID>
        cluster_name = "OnDemand_Batch_" + str(uuid.uuid4())
        ecs_cluster = self.ecs_backend.create_cluster(cluster_name)
        new_comp_env.set_ecs(ecs_cluster.arn, cluster_name)

        return compute_environment_name, new_comp_env.arn

    def _validate_compute_resources(self, cr):
        """
        Checks contents of sub dictionary for managed clusters

        :param cr: computeResources
        :type cr: dict
        """
        if int(cr["maxvCpus"]) < 0:
            raise InvalidParameterValueException("maxVCpus must be positive")
        if "FARGATE" not in cr["type"]:
            # Most parameters are not applicable to jobs that are running on Fargate resources;
            # non-exhaustive list: minvCpus, instanceTypes, imageId, ec2KeyPair, instanceRole, tags
            for profile in self.iam_backend.get_instance_profiles():
                if profile.arn == cr["instanceRole"]:
                    break
            else:
                raise InvalidParameterValueException(
                    "could not find instanceRole {0}".format(cr["instanceRole"])
                )

            if int(cr["minvCpus"]) < 0:
                raise InvalidParameterValueException("minvCpus must be positive")
            if int(cr["maxvCpus"]) < int(cr["minvCpus"]):
                raise InvalidParameterValueException(
                    "maxVCpus must be greater than minvCpus"
                )

            if len(cr["instanceTypes"]) == 0:
                raise InvalidParameterValueException(
                    "At least 1 instance type must be provided"
                )
            for instance_type in cr["instanceTypes"]:
                if instance_type == "optimal":
                    pass  # Optimal should pick from latest of current gen
                elif instance_type not in EC2_INSTANCE_TYPES:
                    raise InvalidParameterValueException(
                        "Instance type {0} does not exist".format(instance_type)
                    )

        if len(cr["securityGroupIds"]) == 0:
            raise InvalidParameterValueException(
                "At least 1 security group must be provided"
            )
        for sec_id in cr["securityGroupIds"]:
            if self.ec2_backend.get_security_group_from_id(sec_id) is None:
                raise InvalidParameterValueException(
                    "security group {0} does not exist".format(sec_id)
                )

        if len(cr["subnets"]) == 0:
            raise InvalidParameterValueException("At least 1 subnet must be provided")
        for subnet_id in cr["subnets"]:
            try:
                self.ec2_backend.get_subnet(subnet_id)
            except InvalidSubnetIdError:
                raise InvalidParameterValueException(
                    "subnet {0} does not exist".format(subnet_id)
                )

        if cr["type"] not in {"EC2", "SPOT", "FARGATE", "FARGATE_SPOT"}:
            raise InvalidParameterValueException(
                "computeResources.type must be either EC2 | SPOT | FARGATE | FARGATE_SPOT"
            )

    @staticmethod
    def find_min_instances_to_meet_vcpus(instance_types, target):
        """
        Finds the minimum needed instances to meet a vCPU target

        :param instance_types: Instance types, like ['t2.medium', 't2.small']
        :type instance_types: list of str
        :param target: vCPU target
        :type target: float
        :return: List of instance types
        :rtype: list of str
        """
        # vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ]
        instance_vcpus = []
        instances = []

        for instance_type in instance_types:
            if instance_type == "optimal":
                instance_type = "m4.4xlarge"

            instance_vcpus.append(
                (
                    EC2_INSTANCE_TYPES[instance_type]["VCpuInfo"]["DefaultVCpus"],
                    instance_type,
                )
            )

        instance_vcpus = sorted(instance_vcpus, key=lambda item: item[0], reverse=True)
        # Loop through:
        #   if the biggest instance type is smaller than the target, and len(instance_types) > 1, use the biggest type
        #   if the biggest instance type is bigger than the target, and len(instance_types) > 1, remove it and move on

        #   if the biggest instance type is bigger than the target and len(instance_types) == 1, add the instance and finish
        #   if the biggest instance type is smaller than the target and len(instance_types) == 1, keep adding instances until the target reaches 0
        #   ^^ boils down to: keep adding the last type until the vCPU target goes non-positive
        #   Algorithm ;-) ... Could probably be done better with some quality lambdas
        while target > 0:
            current_vcpu, current_instance = instance_vcpus[0]

            if len(instance_vcpus) > 1:
                if current_vcpu <= target:
                    target -= current_vcpu
                    instances.append(current_instance)
                else:
                    # try next biggest instance
                    instance_vcpus.pop(0)
            else:
                # We're on the last instance
                target -= current_vcpu
                instances.append(current_instance)

        return instances
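
    # Worked example (a sketch): find_min_instances_to_meet_vcpus(
    #     ["t2.medium", "t2.small"], 5)
    # sorts the candidates by vCPUs descending -> [(2, "t2.medium"), (1, "t2.small")],
    # greedily takes t2.medium while it fits (target 5 -> 3 -> 1), pops it once
    # 2 > 1, then finishes on the smallest type:
    #     ["t2.medium", "t2.medium", "t2.small"]   # 2 + 2 + 1 = 5 vCPUs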

    def delete_compute_environment(self, compute_environment_name):
        if compute_environment_name is None:
            raise InvalidParameterValueException("Missing computeEnvironment parameter")

        compute_env = self.get_compute_environment(compute_environment_name)

        if compute_env is not None:
            # Pop ComputeEnvironment
            self._compute_environments.pop(compute_env.arn)

            # Delete ECS cluster
            self.ecs_backend.delete_cluster(compute_env.ecs_name)

            if compute_env.env_type == "MANAGED":
                # Terminate the instances the environment started
                instance_ids = [instance.id for instance in compute_env.instances]
                if instance_ids:
                    self.ec2_backend.terminate_instances(instance_ids)

    def update_compute_environment(
        self, compute_environment_name, state, compute_resources, service_role
    ):
        # Validate
        compute_env = self.get_compute_environment(compute_environment_name)
        if compute_env is None:
            raise ClientException(
                "Compute environment {0} does not exist".format(
                    compute_environment_name
                )
            )

        # Look for IAM role
        if service_role is not None:
            try:
                role = self.iam_backend.get_role_by_arn(service_role)
            except IAMNotFoundException:
                raise InvalidParameterValueException(
                    "Could not find IAM role {0}".format(service_role)
                )

            compute_env.service_role = role

        if state is not None:
            if state not in ("ENABLED", "DISABLED"):
                raise InvalidParameterValueException(
                    "state {0} must be one of ENABLED | DISABLED".format(state)
                )

            compute_env.state = state

        if compute_resources is not None:
            # TODO Implement resizing of instances based on changing vCpus
            # compute_resources CAN contain desiredvCpus, maxvCpus, minvCpus, and can contain none of them.
            pass

        return compute_env.name, compute_env.arn

    def create_job_queue(self, queue_name, priority, state, compute_env_order):
        """
        Create a job queue

        :param queue_name: Queue name
        :type queue_name: str
        :param priority: Queue priority
        :type priority: int
        :param state: Queue state
        :type state: string
        :param compute_env_order: Compute environment list
        :type compute_env_order: list of dict
        :return: Tuple of Name, ARN
        :rtype: tuple of str
        """
        for variable, var_name in (
            (queue_name, "jobQueueName"),
            (priority, "priority"),
            (state, "state"),
            (compute_env_order, "computeEnvironmentOrder"),
        ):
            if variable is None:
                raise ClientException("{0} must be provided".format(var_name))

        if state not in ("ENABLED", "DISABLED"):
            raise ClientException(
                "state {0} must be one of ENABLED | DISABLED".format(state)
            )
        if self.get_job_queue_by_name(queue_name) is not None:
            raise ClientException("Job queue {0} already exists".format(queue_name))

        if len(compute_env_order) == 0:
            raise ClientException("At least 1 compute environment must be provided")
        try:
            # orders and extracts computeEnvironment ARNs
            ordered_compute_environments = [
                item["computeEnvironment"]
                for item in sorted(compute_env_order, key=lambda x: x["order"])
            ]
        except Exception:
            raise ClientException("computeEnvironmentOrder is malformed")

        # Check each ARN exists, then make a list of compute env's.
        # Done outside the try/except so the specific "does not exist" error
        # isn't swallowed by the malformed-order catch-all.
        env_objects = []
        for arn in ordered_compute_environments:
            env = self.get_compute_environment_by_arn(arn)
            if env is None:
                raise ClientException(
                    "Compute environment {0} does not exist".format(arn)
                )
            env_objects.append(env)

        # Create new Job Queue
        queue = JobQueue(
            queue_name,
            priority,
            state,
            env_objects,
            compute_env_order,
            self.region_name,
        )
        self._job_queues[queue.arn] = queue

        return queue_name, queue.arn
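
    # computeEnvironmentOrder is expected in the API's wire shape
    # (illustrative):
    #
    #   [{"order": 1, "computeEnvironment": "arn:aws:batch:...:compute-environment/env1"},
    #    {"order": 2, "computeEnvironment": "arn:aws:batch:...:compute-environment/env2"}]
    #
    # Entries are sorted by "order" and every ARN must resolve to an existing
    # compute environment before the queue is created.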

    def describe_job_queues(self, job_queues=None, max_results=None, next_token=None):
        envs = set()
        if job_queues is not None:
            envs = set(job_queues)

        result = []
        for arn, job_queue in self._job_queues.items():
            # Filter shortcut
            if len(envs) > 0 and arn not in envs and job_queue.name not in envs:
                continue

            result.append(job_queue.describe())

        return result

    def update_job_queue(self, queue_name, priority, state, compute_env_order):
        """
        Update a job queue

        :param queue_name: Queue name
        :type queue_name: str
        :param priority: Queue priority
        :type priority: int
        :param state: Queue state
        :type state: string
        :param compute_env_order: Compute environment list
        :type compute_env_order: list of dict
        :return: Tuple of Name, ARN
        :rtype: tuple of str
        """
        if queue_name is None:
            raise ClientException("jobQueueName must be provided")

        job_queue = self.get_job_queue(queue_name)
        if job_queue is None:
            raise ClientException("Job queue {0} does not exist".format(queue_name))

        if state is not None:
            if state not in ("ENABLED", "DISABLED"):
                raise ClientException(
                    "state {0} must be one of ENABLED | DISABLED".format(state)
                )

            job_queue.state = state

        if compute_env_order is not None:
            if len(compute_env_order) == 0:
                raise ClientException("At least 1 compute environment must be provided")
            try:
                # orders and extracts computeEnvironment ARNs
                ordered_compute_environments = [
                    item["computeEnvironment"]
                    for item in sorted(compute_env_order, key=lambda x: x["order"])
                ]
            except Exception:
                raise ClientException("computeEnvironmentOrder is malformed")

            # Check each ARN exists, then make a list of compute env's
            # (outside the try/except, as in create_job_queue)
            env_objects = []
            for arn in ordered_compute_environments:
                env = self.get_compute_environment_by_arn(arn)
                if env is None:
                    raise ClientException(
                        "Compute environment {0} does not exist".format(arn)
                    )
                env_objects.append(env)

            job_queue.env_order_json = compute_env_order
            job_queue.environments = env_objects

        if priority is not None:
            job_queue.priority = priority

        return queue_name, job_queue.arn

    def delete_job_queue(self, queue_name):
        job_queue = self.get_job_queue(queue_name)

        if job_queue is not None:
            del self._job_queues[job_queue.arn]

    def register_job_definition(
        self, def_name, parameters, _type, tags, retry_strategy, container_properties
    ):
        if def_name is None:
            raise ClientException("jobDefinitionName must be provided")

        job_def = self.get_job_definition_by_name(def_name)
        if retry_strategy is not None:
            try:
                retry_strategy = retry_strategy["attempts"]
            except Exception:
                raise ClientException("retryStrategy is malformed")
        if job_def is None:
            if not tags:
                tags = {}
            job_def = JobDefinition(
                def_name,
                parameters,
                _type,
                container_properties,
                tags=tags,
                region_name=self.region_name,
                retry_strategy=retry_strategy,
            )
        else:
            # Make a new revision of the existing job definition
            job_def = job_def.update(
                parameters, _type, container_properties, retry_strategy
            )

        self._job_definitions[job_def.arn] = job_def

        return def_name, job_def.arn, job_def.revision

    def deregister_job_definition(self, def_name):
        job_def = self.get_job_definition_by_arn(def_name)
        if job_def is None and ":" in def_name:
            name, revision = def_name.split(":", 1)
            job_def = self.get_job_definition_by_name_revision(name, revision)

        if job_def is not None:
            del self._job_definitions[job_def.arn]

    def describe_job_definitions(
        self,
        job_def_name=None,
        job_def_list=None,
        status=None,
        max_results=None,
        next_token=None,
    ):
        jobs = []

        # As a job definition name can reference multiple revisions, we get a list of them
        if job_def_name is not None:
            job_def = self.get_job_definitions(job_def_name)
            if job_def is not None:
                jobs.extend(job_def)
        elif job_def_list is not None:
            for job in job_def_list:
                job_def = self.get_job_definitions(job)
                if job_def is not None:
                    jobs.extend(job_def)
        else:
            jobs.extend(self._job_definitions.values())

        # Got all the job definitions we're after; filter them by status
        if status is not None:
            return [job for job in jobs if job.status == status]
        return jobs

    def submit_job(
        self,
        job_name,
        job_def_id,
        job_queue,
        parameters=None,
        retries=None,
        depends_on=None,
        container_overrides=None,
    ):
        # TODO parameters and retries (a dict raw from the request) are ignored for now

        # Look for job definition
        job_def = self.get_job_definition(job_def_id)
        if job_def is None:
            raise ClientException(
                "Job definition {0} does not exist".format(job_def_id)
            )

        queue = self.get_job_queue(job_queue)
        if queue is None:
            raise ClientException("Job queue {0} does not exist".format(job_queue))

        job = Job(
            job_name,
            job_def,
            queue,
            log_backend=self.logs_backend,
            container_overrides=container_overrides,
            depends_on=depends_on,
            all_jobs=self._jobs,
        )
        self._jobs[job.job_id] = job

        # Here comes the fun
        job.start()

        return job_name, job.job_id

    def describe_jobs(self, jobs):
        job_filter = set()
        if jobs is not None:
            job_filter = set(jobs)

        result = []
        for key, job in self._jobs.items():
            if len(job_filter) > 0 and key not in job_filter:
                continue

            result.append(job.describe())

        return result

    def list_jobs(self, job_queue, job_status=None, max_results=None, next_token=None):
        jobs = []

        # Keep the requested identifier around so the error message can report it
        queue = self.get_job_queue(job_queue)
        if queue is None:
            raise ClientException("Job queue {0} does not exist".format(job_queue))

        if job_status is not None and job_status not in (
            "SUBMITTED",
            "PENDING",
            "RUNNABLE",
            "STARTING",
            "RUNNING",
            "SUCCEEDED",
            "FAILED",
        ):
            raise ClientException(
                "Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED"
            )

        for job in queue.jobs:
            if job_status is not None and job.job_state != job_status:
                continue

            jobs.append(job)

        return jobs

    def cancel_job(self, job_id, reason):
        job = self.get_job_by_id(job_id)
        if job is None:
            return
        if job.job_state in ["SUBMITTED", "PENDING", "RUNNABLE"]:
            job.terminate(reason)
        # No-Op for jobs that have already started - user has to explicitly terminate those

    def terminate_job(self, job_id, reason):
        if job_id is None:
            raise ClientException("Job ID does not exist")
        if reason is None:
            raise ClientException("Reason does not exist")

        job = self.get_job_by_id(job_id)
        if job is None:
            raise ClientException("Job not found")

        job.terminate(reason)


batch_backends = {}
for region in Session().get_available_regions("batch"):
    batch_backends[region] = BatchBackend(region)
for region in Session().get_available_regions("batch", partition_name="aws-us-gov"):
    batch_backends[region] = BatchBackend(region)
for region in Session().get_available_regions("batch", partition_name="aws-cn"):
    batch_backends[region] = BatchBackend(region)
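
# End-to-end usage sketch (illustrative, assuming moto's public mock_batch and
# mock_iam decorators plus boto3; not part of this module):
#
#   import boto3
#   from moto import mock_batch, mock_iam
#
#   @mock_iam
#   @mock_batch
#   def test_create_unmanaged_environment():
#       iam = boto3.client("iam", region_name="us-east-1")
#       role_arn = iam.create_role(
#           RoleName="BatchServiceRole", AssumeRolePolicyDocument="{}"
#       )["Role"]["Arn"]
#       batch = boto3.client("batch", region_name="us-east-1")
#       resp = batch.create_compute_environment(
#           computeEnvironmentName="test-env",
#           type="UNMANAGED",
#           state="ENABLED",
#           serviceRole=role_arn,
#       )
#       assert resp["computeEnvironmentName"] == "test-env"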