1import re 2from itertools import cycle 3import datetime 4import time 5import uuid 6import logging 7import docker 8import threading 9import dateutil.parser 10from boto3 import Session 11 12from moto.core import BaseBackend, BaseModel, CloudFormationModel 13from moto.iam import iam_backends 14from moto.ec2 import ec2_backends 15from moto.ecs import ecs_backends 16from moto.logs import logs_backends 17 18from .exceptions import InvalidParameterValueException, ClientException, ValidationError 19from .utils import ( 20 make_arn_for_compute_env, 21 make_arn_for_job_queue, 22 make_arn_for_task_def, 23 lowercase_first_key, 24) 25from moto.ec2.exceptions import InvalidSubnetIdError 26from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES 27from moto.iam.exceptions import IAMNotFoundException 28from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID 29from moto.core.utils import unix_time_millis 30from moto.utilities.docker_utilities import DockerModel, parse_image_ref 31from ..utilities.tagging_service import TaggingService 32 33logger = logging.getLogger(__name__) 34COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( 35 r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" 36) 37 38 39def datetime2int_milliseconds(date): 40 """ 41 AWS returns timestamps in milliseconds 42 We don't use milliseconds timestamps internally, 43 this method should be used only in describe() method 44 """ 45 return int(date.timestamp() * 1000) 46 47 48def datetime2int(date): 49 return int(time.mktime(date.timetuple())) 50 51 52class ComputeEnvironment(CloudFormationModel): 53 def __init__( 54 self, 55 compute_environment_name, 56 _type, 57 state, 58 compute_resources, 59 service_role, 60 region_name, 61 ): 62 self.name = compute_environment_name 63 self.env_type = _type 64 self.state = state 65 self.compute_resources = compute_resources 66 self.service_role = service_role 67 self.arn = make_arn_for_compute_env( 68 DEFAULT_ACCOUNT_ID, compute_environment_name, region_name 69 ) 70 71 self.instances = [] 72 self.ecs_arn = None 73 self.ecs_name = None 74 75 def add_instance(self, instance): 76 self.instances.append(instance) 77 78 def set_ecs(self, arn, name): 79 self.ecs_arn = arn 80 self.ecs_name = name 81 82 @property 83 def physical_resource_id(self): 84 return self.arn 85 86 @staticmethod 87 def cloudformation_name_type(): 88 return "ComputeEnvironmentName" 89 90 @staticmethod 91 def cloudformation_type(): 92 # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-computeenvironment.html 93 return "AWS::Batch::ComputeEnvironment" 94 95 @classmethod 96 def create_from_cloudformation_json( 97 cls, resource_name, cloudformation_json, region_name, **kwargs 98 ): 99 backend = batch_backends[region_name] 100 properties = cloudformation_json["Properties"] 101 102 env = backend.create_compute_environment( 103 resource_name, 104 properties["Type"], 105 properties.get("State", "ENABLED"), 106 lowercase_first_key(properties["ComputeResources"]), 107 properties["ServiceRole"], 108 ) 109 arn = env[1] 110 111 return backend.get_compute_environment_by_arn(arn) 112 113 114class JobQueue(CloudFormationModel): 115 def __init__( 116 self, name, priority, state, environments, env_order_json, region_name 117 ): 118 """ 119 :param name: Job queue name 120 :type name: str 121 :param priority: Job queue priority 122 :type priority: int 123 :param state: Either ENABLED or DISABLED 124 :type state: str 125 :param environments: Compute Environments 126 :type environments: list of ComputeEnvironment 127 :param env_order_json: Compute Environments JSON for use when describing 128 :type env_order_json: list of dict 129 :param region_name: Region name 130 :type region_name: str 131 """ 132 self.name = name 133 self.priority = priority 134 self.state = state 135 self.environments = environments 136 self.env_order_json = env_order_json 137 self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name) 138 self.status = "VALID" 139 140 self.jobs = [] 141 142 def describe(self): 143 result = { 144 "computeEnvironmentOrder": self.env_order_json, 145 "jobQueueArn": self.arn, 146 "jobQueueName": self.name, 147 "priority": self.priority, 148 "state": self.state, 149 "status": self.status, 150 } 151 152 return result 153 154 @property 155 def physical_resource_id(self): 156 return self.arn 157 158 @staticmethod 159 def cloudformation_name_type(): 160 return "JobQueueName" 161 162 @staticmethod 163 def cloudformation_type(): 164 # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobqueue.html 165 return "AWS::Batch::JobQueue" 166 167 @classmethod 168 def create_from_cloudformation_json( 169 cls, resource_name, cloudformation_json, region_name, **kwargs 170 ): 171 backend = batch_backends[region_name] 172 properties = cloudformation_json["Properties"] 173 174 # Need to deal with difference case from cloudformation compute_resources, e.g. instanceRole vs InstanceRole 175 # Hacky fix to normalise keys, is making me think I want to start spamming cAsEiNsEnSiTiVe dictionaries 176 compute_envs = [ 177 lowercase_first_key(dict_item) 178 for dict_item in properties["ComputeEnvironmentOrder"] 179 ] 180 181 queue = backend.create_job_queue( 182 queue_name=resource_name, 183 priority=properties["Priority"], 184 state=properties.get("State", "ENABLED"), 185 compute_env_order=compute_envs, 186 ) 187 arn = queue[1] 188 189 return backend.get_job_queue_by_arn(arn) 190 191 192class JobDefinition(CloudFormationModel): 193 def __init__( 194 self, 195 name, 196 parameters, 197 _type, 198 container_properties, 199 region_name, 200 tags={}, 201 revision=0, 202 retry_strategy=0, 203 ): 204 self.name = name 205 self.retries = retry_strategy 206 self.type = _type 207 self.revision = revision 208 self._region = region_name 209 self.container_properties = container_properties 210 self.arn = None 211 self.status = "ACTIVE" 212 self.tagger = TaggingService() 213 if parameters is None: 214 parameters = {} 215 self.parameters = parameters 216 217 self._validate() 218 self._update_arn() 219 220 tags = self._format_tags(tags) 221 # Validate the tags before proceeding. 222 errmsg = self.tagger.validate_tags(tags or []) 223 if errmsg: 224 raise ValidationError(errmsg) 225 226 self.tagger.tag_resource(self.arn, tags or []) 227 228 def _format_tags(self, tags): 229 return [{"Key": k, "Value": v} for k, v in tags.items()] 230 231 def _update_arn(self): 232 self.revision += 1 233 self.arn = make_arn_for_task_def( 234 DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region 235 ) 236 237 def _get_resource_requirement(self, req_type, default=None): 238 """ 239 Get resource requirement from container properties. 240 241 Resource requirements like "memory" and "vcpus" are now specified in 242 "resourceRequirements". This function retrieves a resource requirement 243 from either container_properties.resourceRequirements (preferred) or 244 directly from container_properties (deprecated). 245 246 :param req_type: The type of resource requirement to retrieve. 247 :type req_type: ["gpu", "memory", "vcpus"] 248 249 :param default: The default value to return if the resource requirement is not found. 250 :type default: any, default=None 251 252 :return: The value of the resource requirement, or None. 253 :rtype: any 254 """ 255 resource_reqs = self.container_properties.get("resourceRequirements", []) 256 257 # Filter the resource requirements by the specified type. 258 # Note that VCPUS are specified in resourceRequirements without the 259 # trailing "s", so we strip that off in the comparison below. 260 required_resource = list( 261 filter( 262 lambda req: req["type"].lower() == req_type.lower().rstrip("s"), 263 resource_reqs, 264 ) 265 ) 266 267 if required_resource: 268 return required_resource[0]["value"] 269 else: 270 return self.container_properties.get(req_type, default) 271 272 def _validate(self): 273 if self.type not in ("container",): 274 raise ClientException('type must be one of "container"') 275 276 # For future use when containers arnt the only thing in batch 277 if self.type != "container": 278 raise NotImplementedError() 279 280 if not isinstance(self.parameters, dict): 281 raise ClientException("parameters must be a string to string map") 282 283 if "image" not in self.container_properties: 284 raise ClientException("containerProperties must contain image") 285 286 memory = self._get_resource_requirement("memory") 287 if memory is None: 288 raise ClientException("containerProperties must contain memory") 289 if memory < 4: 290 raise ClientException("container memory limit must be greater than 4") 291 292 vcpus = self._get_resource_requirement("vcpus") 293 if vcpus is None: 294 raise ClientException("containerProperties must contain vcpus") 295 if vcpus < 1: 296 raise ClientException("container vcpus limit must be greater than 0") 297 298 def update(self, parameters, _type, container_properties, retry_strategy): 299 if parameters is None: 300 parameters = self.parameters 301 302 if _type is None: 303 _type = self.type 304 305 if container_properties is None: 306 container_properties = self.container_properties 307 308 if retry_strategy is None: 309 retry_strategy = self.retries 310 311 return JobDefinition( 312 self.name, 313 parameters, 314 _type, 315 container_properties, 316 region_name=self._region, 317 revision=self.revision, 318 retry_strategy=retry_strategy, 319 ) 320 321 def describe(self): 322 result = { 323 "jobDefinitionArn": self.arn, 324 "jobDefinitionName": self.name, 325 "parameters": self.parameters, 326 "revision": self.revision, 327 "status": self.status, 328 "type": self.type, 329 "tags": self.tagger.get_tag_dict_for_resource(self.arn), 330 } 331 if self.container_properties is not None: 332 result["containerProperties"] = self.container_properties 333 if self.retries is not None and self.retries > 0: 334 result["retryStrategy"] = {"attempts": self.retries} 335 336 return result 337 338 @property 339 def physical_resource_id(self): 340 return self.arn 341 342 @staticmethod 343 def cloudformation_name_type(): 344 return "JobDefinitionName" 345 346 @staticmethod 347 def cloudformation_type(): 348 # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-batch-jobdefinition.html 349 return "AWS::Batch::JobDefinition" 350 351 @classmethod 352 def create_from_cloudformation_json( 353 cls, resource_name, cloudformation_json, region_name, **kwargs 354 ): 355 backend = batch_backends[region_name] 356 properties = cloudformation_json["Properties"] 357 res = backend.register_job_definition( 358 def_name=resource_name, 359 parameters=lowercase_first_key(properties.get("Parameters", {})), 360 _type="container", 361 tags=lowercase_first_key(properties.get("Tags", {})), 362 retry_strategy=lowercase_first_key(properties["RetryStrategy"]), 363 container_properties=lowercase_first_key(properties["ContainerProperties"]), 364 ) 365 arn = res[1] 366 367 return backend.get_job_definition_by_arn(arn) 368 369 370class Job(threading.Thread, BaseModel, DockerModel): 371 def __init__( 372 self, 373 name, 374 job_def, 375 job_queue, 376 log_backend, 377 container_overrides, 378 depends_on, 379 all_jobs, 380 ): 381 """ 382 Docker Job 383 384 :param name: Job Name 385 :param job_def: Job definition 386 :type: job_def: JobDefinition 387 :param job_queue: Job Queue 388 :param log_backend: Log backend 389 :type log_backend: moto.logs.models.LogsBackend 390 """ 391 threading.Thread.__init__(self) 392 DockerModel.__init__(self) 393 394 self.job_name = name 395 self.job_id = str(uuid.uuid4()) 396 self.job_definition = job_def 397 self.container_overrides = container_overrides or {} 398 self.job_queue = job_queue 399 self.job_state = "SUBMITTED" # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED 400 self.job_queue.jobs.append(self) 401 self.job_created_at = datetime.datetime.now() 402 self.job_started_at = datetime.datetime(1970, 1, 1) 403 self.job_stopped_at = datetime.datetime(1970, 1, 1) 404 self.job_stopped = False 405 self.job_stopped_reason = None 406 self.depends_on = depends_on 407 self.all_jobs = all_jobs 408 409 self.stop = False 410 411 self.daemon = True 412 self.name = "MOTO-BATCH-" + self.job_id 413 414 self._log_backend = log_backend 415 self.log_stream_name = None 416 417 def describe(self): 418 result = { 419 "jobDefinition": self.job_definition.arn, 420 "jobId": self.job_id, 421 "jobName": self.job_name, 422 "jobQueue": self.job_queue.arn, 423 "status": self.job_state, 424 "dependsOn": self.depends_on if self.depends_on else [], 425 "createdAt": datetime2int_milliseconds(self.job_created_at), 426 } 427 if result["status"] not in ["SUBMITTED", "PENDING", "RUNNABLE", "STARTING"]: 428 result["startedAt"] = datetime2int_milliseconds(self.job_started_at) 429 if self.job_stopped: 430 result["stoppedAt"] = datetime2int_milliseconds(self.job_stopped_at) 431 result["container"] = {} 432 result["container"]["command"] = self._get_container_property("command", []) 433 result["container"]["privileged"] = self._get_container_property( 434 "privileged", False 435 ) 436 result["container"][ 437 "readonlyRootFilesystem" 438 ] = self._get_container_property("readonlyRootFilesystem", False) 439 result["container"]["ulimits"] = self._get_container_property("ulimits", {}) 440 result["container"]["vcpus"] = self._get_container_property("vcpus", 1) 441 result["container"]["memory"] = self._get_container_property("memory", 512) 442 result["container"]["volumes"] = self._get_container_property("volumes", []) 443 result["container"]["environment"] = self._get_container_property( 444 "environment", [] 445 ) 446 result["container"]["logStreamName"] = self.log_stream_name 447 if self.job_stopped_reason is not None: 448 result["statusReason"] = self.job_stopped_reason 449 return result 450 451 def _get_container_property(self, p, default): 452 if p == "environment": 453 job_env = self.container_overrides.get(p, default) 454 jd_env = self.job_definition.container_properties.get(p, default) 455 456 job_env_dict = {_env["name"]: _env["value"] for _env in job_env} 457 jd_env_dict = {_env["name"]: _env["value"] for _env in jd_env} 458 459 for key in jd_env_dict.keys(): 460 if key not in job_env_dict.keys(): 461 job_env.append({"name": key, "value": jd_env_dict[key]}) 462 463 job_env.append({"name": "AWS_BATCH_JOB_ID", "value": self.job_id}) 464 465 return job_env 466 467 if p in ["vcpus", "memory"]: 468 return self.container_overrides.get( 469 p, self.job_definition._get_resource_requirement(p, default) 470 ) 471 472 return self.container_overrides.get( 473 p, self.job_definition.container_properties.get(p, default) 474 ) 475 476 def run(self): 477 """ 478 Run the container. 479 480 Logic is as follows: 481 Generate container info (eventually from task definition) 482 Start container 483 Loop whilst not asked to stop and the container is running. 484 Get all logs from container between the last time I checked and now. 485 Convert logs into cloudwatch format 486 Put logs into cloudwatch 487 488 :return: 489 """ 490 try: 491 self.job_state = "PENDING" 492 493 if self.depends_on and not self._wait_for_dependencies(): 494 return 495 496 image = self.job_definition.container_properties.get( 497 "image", "alpine:latest" 498 ) 499 privileged = self.job_definition.container_properties.get( 500 "privileged", False 501 ) 502 cmd = self._get_container_property( 503 "command", 504 '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"', 505 ) 506 environment = { 507 e["name"]: e["value"] 508 for e in self._get_container_property("environment", []) 509 } 510 volumes = { 511 v["name"]: v["host"] 512 for v in self._get_container_property("volumes", []) 513 } 514 mounts = [ 515 docker.types.Mount( 516 m["containerPath"], 517 volumes[m["sourceVolume"]]["sourcePath"], 518 type="bind", 519 read_only=m["readOnly"], 520 ) 521 for m in self._get_container_property("mountPoints", []) 522 ] 523 name = "{0}-{1}".format(self.job_name, self.job_id) 524 525 self.job_state = "RUNNABLE" 526 # TODO setup ecs container instance 527 528 self.job_started_at = datetime.datetime.now() 529 530 log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) 531 image_repository, image_tag = parse_image_ref(image) 532 # avoid explicit pulling here, to allow using cached images 533 # self.docker_client.images.pull(image_repository, image_tag) 534 self.job_state = "STARTING" 535 container = self.docker_client.containers.run( 536 image, 537 cmd, 538 detach=True, 539 name=name, 540 log_config=log_config, 541 environment=environment, 542 mounts=mounts, 543 privileged=privileged, 544 ) 545 self.job_state = "RUNNING" 546 try: 547 container.reload() 548 while container.status == "running" and not self.stop: 549 container.reload() 550 time.sleep(0.5) 551 552 # Container should be stopped by this point... unless asked to stop 553 if container.status == "running": 554 container.kill() 555 556 # Log collection 557 logs_stdout = [] 558 logs_stderr = [] 559 logs_stderr.extend( 560 container.logs( 561 stdout=False, 562 stderr=True, 563 timestamps=True, 564 since=datetime2int(self.job_started_at), 565 ) 566 .decode() 567 .split("\n") 568 ) 569 logs_stdout.extend( 570 container.logs( 571 stdout=True, 572 stderr=False, 573 timestamps=True, 574 since=datetime2int(self.job_started_at), 575 ) 576 .decode() 577 .split("\n") 578 ) 579 580 # Process logs 581 logs_stdout = [x for x in logs_stdout if len(x) > 0] 582 logs_stderr = [x for x in logs_stderr if len(x) > 0] 583 logs = [] 584 for line in logs_stdout + logs_stderr: 585 date, line = line.split(" ", 1) 586 date_obj = ( 587 dateutil.parser.parse(date) 588 .astimezone(datetime.timezone.utc) 589 .replace(tzinfo=None) 590 ) 591 date = unix_time_millis(date_obj) 592 logs.append({"timestamp": date, "message": line.strip()}) 593 594 # Send to cloudwatch 595 log_group = "/aws/batch/job" 596 stream_name = "{0}/default/{1}".format( 597 self.job_definition.name, self.job_id 598 ) 599 self.log_stream_name = stream_name 600 self._log_backend.ensure_log_group(log_group, None) 601 self._log_backend.create_log_stream(log_group, stream_name) 602 self._log_backend.put_log_events(log_group, stream_name, logs, None) 603 604 result = container.wait() or {} 605 job_failed = self.stop or result.get("StatusCode", 0) > 0 606 self._mark_stopped(success=not job_failed) 607 608 except Exception as err: 609 logger.error( 610 "Failed to run AWS Batch container {0}. Error {1}".format( 611 self.name, err 612 ) 613 ) 614 self._mark_stopped(success=False) 615 container.kill() 616 finally: 617 container.remove() 618 except Exception as err: 619 logger.error( 620 "Failed to run AWS Batch container {0}. Error {1}".format( 621 self.name, err 622 ) 623 ) 624 self._mark_stopped(success=False) 625 626 def _mark_stopped(self, success=True): 627 # Ensure that job_stopped/job_stopped_at-attributes are set first 628 # The describe-method needs them immediately when job_state is set 629 self.job_stopped = True 630 self.job_stopped_at = datetime.datetime.now() 631 self.job_state = "SUCCEEDED" if success else "FAILED" 632 633 def terminate(self, reason): 634 if not self.stop: 635 self.stop = True 636 self.job_stopped_reason = reason 637 638 def _wait_for_dependencies(self): 639 dependent_ids = [dependency["jobId"] for dependency in self.depends_on] 640 successful_dependencies = set() 641 while len(successful_dependencies) != len(dependent_ids): 642 for dependent_id in dependent_ids: 643 if dependent_id in self.all_jobs: 644 dependent_job = self.all_jobs[dependent_id] 645 if dependent_job.job_state == "SUCCEEDED": 646 successful_dependencies.add(dependent_id) 647 if dependent_job.job_state == "FAILED": 648 logger.error( 649 "Terminating job {0} due to failed dependency {1}".format( 650 self.name, dependent_job.name 651 ) 652 ) 653 self._mark_stopped(success=False) 654 return False 655 656 time.sleep(1) 657 658 return True 659 660 661class BatchBackend(BaseBackend): 662 def __init__(self, region_name=None): 663 super(BatchBackend, self).__init__() 664 self.region_name = region_name 665 666 self._compute_environments = {} 667 self._job_queues = {} 668 self._job_definitions = {} 669 self._jobs = {} 670 671 @property 672 def iam_backend(self): 673 """ 674 :return: IAM Backend 675 :rtype: moto.iam.models.IAMBackend 676 """ 677 return iam_backends["global"] 678 679 @property 680 def ec2_backend(self): 681 """ 682 :return: EC2 Backend 683 :rtype: moto.ec2.models.EC2Backend 684 """ 685 return ec2_backends[self.region_name] 686 687 @property 688 def ecs_backend(self): 689 """ 690 :return: ECS Backend 691 :rtype: moto.ecs.models.EC2ContainerServiceBackend 692 """ 693 return ecs_backends[self.region_name] 694 695 @property 696 def logs_backend(self): 697 """ 698 :return: ECS Backend 699 :rtype: moto.logs.models.LogsBackend 700 """ 701 return logs_backends[self.region_name] 702 703 def reset(self): 704 region_name = self.region_name 705 706 for job in self._jobs.values(): 707 if job.job_state not in ("FAILED", "SUCCEEDED"): 708 job.stop = True 709 # Try to join 710 job.join(0.2) 711 712 self.__dict__ = {} 713 self.__init__(region_name) 714 715 def get_compute_environment_by_arn(self, arn): 716 return self._compute_environments.get(arn) 717 718 def get_compute_environment_by_name(self, name): 719 for comp_env in self._compute_environments.values(): 720 if comp_env.name == name: 721 return comp_env 722 return None 723 724 def get_compute_environment(self, identifier): 725 """ 726 Get compute environment by name or ARN 727 :param identifier: Name or ARN 728 :type identifier: str 729 730 :return: Compute Environment or None 731 :rtype: ComputeEnvironment or None 732 """ 733 env = self.get_compute_environment_by_arn(identifier) 734 if env is None: 735 env = self.get_compute_environment_by_name(identifier) 736 return env 737 738 def get_job_queue_by_arn(self, arn): 739 return self._job_queues.get(arn) 740 741 def get_job_queue_by_name(self, name): 742 for comp_env in self._job_queues.values(): 743 if comp_env.name == name: 744 return comp_env 745 return None 746 747 def get_job_queue(self, identifier): 748 """ 749 Get job queue by name or ARN 750 :param identifier: Name or ARN 751 :type identifier: str 752 753 :return: Job Queue or None 754 :rtype: JobQueue or None 755 """ 756 env = self.get_job_queue_by_arn(identifier) 757 if env is None: 758 env = self.get_job_queue_by_name(identifier) 759 return env 760 761 def get_job_definition_by_arn(self, arn): 762 return self._job_definitions.get(arn) 763 764 def get_job_definition_by_name(self, name): 765 latest_revision = -1 766 latest_job = None 767 for job_def in self._job_definitions.values(): 768 if job_def.name == name and job_def.revision > latest_revision: 769 latest_job = job_def 770 latest_revision = job_def.revision 771 return latest_job 772 773 def get_job_definition_by_name_revision(self, name, revision): 774 for job_def in self._job_definitions.values(): 775 if job_def.name == name and job_def.revision == int(revision): 776 return job_def 777 return None 778 779 def get_job_definition(self, identifier): 780 """ 781 Get job definitions by name or ARN 782 :param identifier: Name or ARN 783 :type identifier: str 784 785 :return: Job definition or None 786 :rtype: JobDefinition or None 787 """ 788 job_def = self.get_job_definition_by_arn(identifier) 789 if job_def is None: 790 if ":" in identifier: 791 job_def = self.get_job_definition_by_name_revision( 792 *identifier.split(":", 1) 793 ) 794 else: 795 job_def = self.get_job_definition_by_name(identifier) 796 return job_def 797 798 def get_job_definitions(self, identifier): 799 """ 800 Get job definitions by name or ARN 801 :param identifier: Name or ARN 802 :type identifier: str 803 804 :return: Job definition or None 805 :rtype: list of JobDefinition 806 """ 807 result = [] 808 env = self.get_job_definition_by_arn(identifier) 809 if env is not None: 810 result.append(env) 811 else: 812 for value in self._job_definitions.values(): 813 if value.name == identifier: 814 result.append(value) 815 816 return result 817 818 def get_job_by_id(self, identifier): 819 """ 820 Get job by id 821 :param identifier: Job ID 822 :type identifier: str 823 824 :return: Job 825 :rtype: Job 826 """ 827 try: 828 return self._jobs[identifier] 829 except KeyError: 830 return None 831 832 def describe_compute_environments( 833 self, environments=None, max_results=None, next_token=None 834 ): 835 envs = set() 836 if environments is not None: 837 envs = set(environments) 838 839 result = [] 840 for arn, environment in self._compute_environments.items(): 841 # Filter shortcut 842 if len(envs) > 0 and arn not in envs and environment.name not in envs: 843 continue 844 845 json_part = { 846 "computeEnvironmentArn": arn, 847 "computeEnvironmentName": environment.name, 848 "ecsClusterArn": environment.ecs_arn, 849 "serviceRole": environment.service_role, 850 "state": environment.state, 851 "type": environment.env_type, 852 "status": "VALID", 853 "statusReason": "Compute environment is available", 854 } 855 if environment.env_type == "MANAGED": 856 json_part["computeResources"] = environment.compute_resources 857 858 result.append(json_part) 859 860 return result 861 862 def create_compute_environment( 863 self, compute_environment_name, _type, state, compute_resources, service_role 864 ): 865 # Validate 866 if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: 867 raise InvalidParameterValueException( 868 "Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" 869 ) 870 871 if self.get_compute_environment_by_name(compute_environment_name) is not None: 872 raise InvalidParameterValueException( 873 "A compute environment already exists with the name {0}".format( 874 compute_environment_name 875 ) 876 ) 877 878 # Look for IAM role 879 try: 880 self.iam_backend.get_role_by_arn(service_role) 881 except IAMNotFoundException: 882 raise InvalidParameterValueException( 883 "Could not find IAM role {0}".format(service_role) 884 ) 885 886 if _type not in ("MANAGED", "UNMANAGED"): 887 raise InvalidParameterValueException( 888 "type {0} must be one of MANAGED | UNMANAGED".format(service_role) 889 ) 890 891 if state is not None and state not in ("ENABLED", "DISABLED"): 892 raise InvalidParameterValueException( 893 "state {0} must be one of ENABLED | DISABLED".format(state) 894 ) 895 896 if compute_resources is None and _type == "MANAGED": 897 raise InvalidParameterValueException( 898 "computeResources must be specified when creating a {0} environment".format( 899 state 900 ) 901 ) 902 elif compute_resources is not None: 903 self._validate_compute_resources(compute_resources) 904 905 # By here, all values except SPOT ones have been validated 906 new_comp_env = ComputeEnvironment( 907 compute_environment_name, 908 _type, 909 state, 910 compute_resources, 911 service_role, 912 region_name=self.region_name, 913 ) 914 self._compute_environments[new_comp_env.arn] = new_comp_env 915 916 # Ok by this point, everything is legit, so if its Managed then start some instances 917 if _type == "MANAGED" and "FARGATE" not in compute_resources["type"]: 918 cpus = int( 919 compute_resources.get("desiredvCpus", compute_resources["minvCpus"]) 920 ) 921 instance_types = compute_resources["instanceTypes"] 922 needed_instance_types = self.find_min_instances_to_meet_vcpus( 923 instance_types, cpus 924 ) 925 # Create instances 926 927 # Will loop over and over so we get decent subnet coverage 928 subnet_cycle = cycle(compute_resources["subnets"]) 929 930 for instance_type in needed_instance_types: 931 reservation = self.ec2_backend.add_instances( 932 image_id="ami-03cf127a", # Todo import AMIs 933 count=1, 934 user_data=None, 935 security_group_names=[], 936 instance_type=instance_type, 937 region_name=self.region_name, 938 subnet_id=next(subnet_cycle), 939 key_name=compute_resources.get("ec2KeyPair", "AWS_OWNED"), 940 security_group_ids=compute_resources["securityGroupIds"], 941 ) 942 943 new_comp_env.add_instance(reservation.instances[0]) 944 945 # Create ECS cluster 946 # Should be of format P2OnDemand_Batch_UUID 947 cluster_name = "OnDemand_Batch_" + str(uuid.uuid4()) 948 ecs_cluster = self.ecs_backend.create_cluster(cluster_name) 949 new_comp_env.set_ecs(ecs_cluster.arn, cluster_name) 950 951 return compute_environment_name, new_comp_env.arn 952 953 def _validate_compute_resources(self, cr): 954 """ 955 Checks contents of sub dictionary for managed clusters 956 957 :param cr: computeResources 958 :type cr: dict 959 """ 960 if int(cr["maxvCpus"]) < 0: 961 raise InvalidParameterValueException("maxVCpus must be positive") 962 if "FARGATE" not in cr["type"]: 963 # Most parameters are not applicable to jobs that are running on Fargate resources: 964 # non exhaustive list: minvCpus, instanceTypes, imageId, ec2KeyPair, instanceRole, tags 965 for profile in self.iam_backend.get_instance_profiles(): 966 if profile.arn == cr["instanceRole"]: 967 break 968 else: 969 raise InvalidParameterValueException( 970 "could not find instanceRole {0}".format(cr["instanceRole"]) 971 ) 972 973 if int(cr["minvCpus"]) < 0: 974 raise InvalidParameterValueException("minvCpus must be positive") 975 if int(cr["maxvCpus"]) < int(cr["minvCpus"]): 976 raise InvalidParameterValueException( 977 "maxVCpus must be greater than minvCpus" 978 ) 979 980 if len(cr["instanceTypes"]) == 0: 981 raise InvalidParameterValueException( 982 "At least 1 instance type must be provided" 983 ) 984 for instance_type in cr["instanceTypes"]: 985 if instance_type == "optimal": 986 pass # Optimal should pick from latest of current gen 987 elif instance_type not in EC2_INSTANCE_TYPES: 988 raise InvalidParameterValueException( 989 "Instance type {0} does not exist".format(instance_type) 990 ) 991 992 for sec_id in cr["securityGroupIds"]: 993 if self.ec2_backend.get_security_group_from_id(sec_id) is None: 994 raise InvalidParameterValueException( 995 "security group {0} does not exist".format(sec_id) 996 ) 997 if len(cr["securityGroupIds"]) == 0: 998 raise InvalidParameterValueException( 999 "At least 1 security group must be provided" 1000 ) 1001 1002 for subnet_id in cr["subnets"]: 1003 try: 1004 self.ec2_backend.get_subnet(subnet_id) 1005 except InvalidSubnetIdError: 1006 raise InvalidParameterValueException( 1007 "subnet {0} does not exist".format(subnet_id) 1008 ) 1009 if len(cr["subnets"]) == 0: 1010 raise InvalidParameterValueException("At least 1 subnet must be provided") 1011 1012 if cr["type"] not in {"EC2", "SPOT", "FARGATE", "FARGATE_SPOT"}: 1013 raise InvalidParameterValueException( 1014 "computeResources.type must be either EC2 | SPOT | FARGATE | FARGATE_SPOT" 1015 ) 1016 1017 @staticmethod 1018 def find_min_instances_to_meet_vcpus(instance_types, target): 1019 """ 1020 Finds the minimum needed instances to meed a vcpu target 1021 1022 :param instance_types: Instance types, like ['t2.medium', 't2.small'] 1023 :type instance_types: list of str 1024 :param target: VCPU target 1025 :type target: float 1026 :return: List of instance types 1027 :rtype: list of str 1028 """ 1029 # vcpus = [ (vcpus, instance_type), (vcpus, instance_type), ... ] 1030 instance_vcpus = [] 1031 instances = [] 1032 1033 for instance_type in instance_types: 1034 if instance_type == "optimal": 1035 instance_type = "m4.4xlarge" 1036 1037 instance_vcpus.append( 1038 ( 1039 EC2_INSTANCE_TYPES[instance_type]["VCpuInfo"]["DefaultVCpus"], 1040 instance_type, 1041 ) 1042 ) 1043 1044 instance_vcpus = sorted(instance_vcpus, key=lambda item: item[0], reverse=True) 1045 # Loop through, 1046 # if biggest instance type smaller than target, and len(instance_types)> 1, then use biggest type 1047 # if biggest instance type bigger than target, and len(instance_types)> 1, then remove it and move on 1048 1049 # if biggest instance type bigger than target and len(instan_types) == 1 then add instance and finish 1050 # if biggest instance type smaller than target and len(instan_types) == 1 then loop adding instances until target == 0 1051 # ^^ boils down to keep adding last till target vcpus is negative 1052 # #Algorithm ;-) ... Could probably be done better with some quality lambdas 1053 while target > 0: 1054 current_vcpu, current_instance = instance_vcpus[0] 1055 1056 if len(instance_vcpus) > 1: 1057 if current_vcpu <= target: 1058 target -= current_vcpu 1059 instances.append(current_instance) 1060 else: 1061 # try next biggest instance 1062 instance_vcpus.pop(0) 1063 else: 1064 # Were on the last instance 1065 target -= current_vcpu 1066 instances.append(current_instance) 1067 1068 return instances 1069 1070 def delete_compute_environment(self, compute_environment_name): 1071 if compute_environment_name is None: 1072 raise InvalidParameterValueException("Missing computeEnvironment parameter") 1073 1074 compute_env = self.get_compute_environment(compute_environment_name) 1075 1076 if compute_env is not None: 1077 # Pop ComputeEnvironment 1078 self._compute_environments.pop(compute_env.arn) 1079 1080 # Delete ECS cluster 1081 self.ecs_backend.delete_cluster(compute_env.ecs_name) 1082 1083 if compute_env.env_type == "MANAGED": 1084 # Delete compute environment 1085 instance_ids = [instance.id for instance in compute_env.instances] 1086 if instance_ids: 1087 self.ec2_backend.terminate_instances(instance_ids) 1088 1089 def update_compute_environment( 1090 self, compute_environment_name, state, compute_resources, service_role 1091 ): 1092 # Validate 1093 compute_env = self.get_compute_environment(compute_environment_name) 1094 if compute_env is None: 1095 raise ClientException("Compute environment {0} does not exist") 1096 1097 # Look for IAM role 1098 if service_role is not None: 1099 try: 1100 role = self.iam_backend.get_role_by_arn(service_role) 1101 except IAMNotFoundException: 1102 raise InvalidParameterValueException( 1103 "Could not find IAM role {0}".format(service_role) 1104 ) 1105 1106 compute_env.service_role = role 1107 1108 if state is not None: 1109 if state not in ("ENABLED", "DISABLED"): 1110 raise InvalidParameterValueException( 1111 "state {0} must be one of ENABLED | DISABLED".format(state) 1112 ) 1113 1114 compute_env.state = state 1115 1116 if compute_resources is not None: 1117 # TODO Implement resizing of instances based on changing vCpus 1118 # compute_resources CAN contain desiredvCpus, maxvCpus, minvCpus, and can contain none of them. 1119 pass 1120 1121 return compute_env.name, compute_env.arn 1122 1123 def create_job_queue(self, queue_name, priority, state, compute_env_order): 1124 """ 1125 Create a job queue 1126 1127 :param queue_name: Queue name 1128 :type queue_name: str 1129 :param priority: Queue priority 1130 :type priority: int 1131 :param state: Queue state 1132 :type state: string 1133 :param compute_env_order: Compute environment list 1134 :type compute_env_order: list of dict 1135 :return: Tuple of Name, ARN 1136 :rtype: tuple of str 1137 """ 1138 for variable, var_name in ( 1139 (queue_name, "jobQueueName"), 1140 (priority, "priority"), 1141 (state, "state"), 1142 (compute_env_order, "computeEnvironmentOrder"), 1143 ): 1144 if variable is None: 1145 raise ClientException("{0} must be provided".format(var_name)) 1146 1147 if state not in ("ENABLED", "DISABLED"): 1148 raise ClientException( 1149 "state {0} must be one of ENABLED | DISABLED".format(state) 1150 ) 1151 if self.get_job_queue_by_name(queue_name) is not None: 1152 raise ClientException("Job queue {0} already exists".format(queue_name)) 1153 1154 if len(compute_env_order) == 0: 1155 raise ClientException("At least 1 compute environment must be provided") 1156 try: 1157 # orders and extracts computeEnvironment names 1158 ordered_compute_environments = [ 1159 item["computeEnvironment"] 1160 for item in sorted(compute_env_order, key=lambda x: x["order"]) 1161 ] 1162 env_objects = [] 1163 # Check each ARN exists, then make a list of compute env's 1164 for arn in ordered_compute_environments: 1165 env = self.get_compute_environment_by_arn(arn) 1166 if env is None: 1167 raise ClientException( 1168 "Compute environment {0} does not exist".format(arn) 1169 ) 1170 env_objects.append(env) 1171 except Exception: 1172 raise ClientException("computeEnvironmentOrder is malformed") 1173 1174 # Create new Job Queue 1175 queue = JobQueue( 1176 queue_name, 1177 priority, 1178 state, 1179 env_objects, 1180 compute_env_order, 1181 self.region_name, 1182 ) 1183 self._job_queues[queue.arn] = queue 1184 1185 return queue_name, queue.arn 1186 1187 def describe_job_queues(self, job_queues=None, max_results=None, next_token=None): 1188 envs = set() 1189 if job_queues is not None: 1190 envs = set(job_queues) 1191 1192 result = [] 1193 for arn, job_queue in self._job_queues.items(): 1194 # Filter shortcut 1195 if len(envs) > 0 and arn not in envs and job_queue.name not in envs: 1196 continue 1197 1198 result.append(job_queue.describe()) 1199 1200 return result 1201 1202 def update_job_queue(self, queue_name, priority, state, compute_env_order): 1203 """ 1204 Update a job queue 1205 1206 :param queue_name: Queue name 1207 :type queue_name: str 1208 :param priority: Queue priority 1209 :type priority: int 1210 :param state: Queue state 1211 :type state: string 1212 :param compute_env_order: Compute environment list 1213 :type compute_env_order: list of dict 1214 :return: Tuple of Name, ARN 1215 :rtype: tuple of str 1216 """ 1217 if queue_name is None: 1218 raise ClientException("jobQueueName must be provided") 1219 1220 job_queue = self.get_job_queue(queue_name) 1221 if job_queue is None: 1222 raise ClientException("Job queue {0} does not exist".format(queue_name)) 1223 1224 if state is not None: 1225 if state not in ("ENABLED", "DISABLED"): 1226 raise ClientException( 1227 "state {0} must be one of ENABLED | DISABLED".format(state) 1228 ) 1229 1230 job_queue.state = state 1231 1232 if compute_env_order is not None: 1233 if len(compute_env_order) == 0: 1234 raise ClientException("At least 1 compute environment must be provided") 1235 try: 1236 # orders and extracts computeEnvironment names 1237 ordered_compute_environments = [ 1238 item["computeEnvironment"] 1239 for item in sorted(compute_env_order, key=lambda x: x["order"]) 1240 ] 1241 env_objects = [] 1242 # Check each ARN exists, then make a list of compute env's 1243 for arn in ordered_compute_environments: 1244 env = self.get_compute_environment_by_arn(arn) 1245 if env is None: 1246 raise ClientException( 1247 "Compute environment {0} does not exist".format(arn) 1248 ) 1249 env_objects.append(env) 1250 except Exception: 1251 raise ClientException("computeEnvironmentOrder is malformed") 1252 1253 job_queue.env_order_json = compute_env_order 1254 job_queue.environments = env_objects 1255 1256 if priority is not None: 1257 job_queue.priority = priority 1258 1259 return queue_name, job_queue.arn 1260 1261 def delete_job_queue(self, queue_name): 1262 job_queue = self.get_job_queue(queue_name) 1263 1264 if job_queue is not None: 1265 del self._job_queues[job_queue.arn] 1266 1267 def register_job_definition( 1268 self, def_name, parameters, _type, tags, retry_strategy, container_properties 1269 ): 1270 if def_name is None: 1271 raise ClientException("jobDefinitionName must be provided") 1272 1273 job_def = self.get_job_definition_by_name(def_name) 1274 if retry_strategy is not None: 1275 try: 1276 retry_strategy = retry_strategy["attempts"] 1277 except Exception: 1278 raise ClientException("retryStrategy is malformed") 1279 if job_def is None: 1280 if not tags: 1281 tags = {} 1282 job_def = JobDefinition( 1283 def_name, 1284 parameters, 1285 _type, 1286 container_properties, 1287 tags=tags, 1288 region_name=self.region_name, 1289 retry_strategy=retry_strategy, 1290 ) 1291 else: 1292 # Make new jobdef 1293 job_def = job_def.update( 1294 parameters, _type, container_properties, retry_strategy 1295 ) 1296 1297 self._job_definitions[job_def.arn] = job_def 1298 1299 return def_name, job_def.arn, job_def.revision 1300 1301 def deregister_job_definition(self, def_name): 1302 job_def = self.get_job_definition_by_arn(def_name) 1303 if job_def is None and ":" in def_name: 1304 name, revision = def_name.split(":", 1) 1305 job_def = self.get_job_definition_by_name_revision(name, revision) 1306 1307 if job_def is not None: 1308 del self._job_definitions[job_def.arn] 1309 1310 def describe_job_definitions( 1311 self, 1312 job_def_name=None, 1313 job_def_list=None, 1314 status=None, 1315 max_results=None, 1316 next_token=None, 1317 ): 1318 jobs = [] 1319 1320 # As a job name can reference multiple revisions, we get a list of them 1321 if job_def_name is not None: 1322 job_def = self.get_job_definitions(job_def_name) 1323 if job_def is not None: 1324 jobs.extend(job_def) 1325 elif job_def_list is not None: 1326 for job in job_def_list: 1327 job_def = self.get_job_definitions(job) 1328 if job_def is not None: 1329 jobs.extend(job_def) 1330 else: 1331 jobs.extend(self._job_definitions.values()) 1332 1333 # Got all the job defs were after, filter then by status 1334 if status is not None: 1335 return [job for job in jobs if job.status == status] 1336 for job in jobs: 1337 job.describe() 1338 return jobs 1339 1340 def submit_job( 1341 self, 1342 job_name, 1343 job_def_id, 1344 job_queue, 1345 parameters=None, 1346 retries=None, 1347 depends_on=None, 1348 container_overrides=None, 1349 ): 1350 # TODO parameters, retries (which is a dict raw from request), job dependencies and container overrides are ignored for now 1351 1352 # Look for job definition 1353 job_def = self.get_job_definition(job_def_id) 1354 if job_def is None: 1355 raise ClientException( 1356 "Job definition {0} does not exist".format(job_def_id) 1357 ) 1358 1359 queue = self.get_job_queue(job_queue) 1360 if queue is None: 1361 raise ClientException("Job queue {0} does not exist".format(job_queue)) 1362 1363 job = Job( 1364 job_name, 1365 job_def, 1366 queue, 1367 log_backend=self.logs_backend, 1368 container_overrides=container_overrides, 1369 depends_on=depends_on, 1370 all_jobs=self._jobs, 1371 ) 1372 self._jobs[job.job_id] = job 1373 1374 # Here comes the fun 1375 job.start() 1376 1377 return job_name, job.job_id 1378 1379 def describe_jobs(self, jobs): 1380 job_filter = set() 1381 if jobs is not None: 1382 job_filter = set(jobs) 1383 1384 result = [] 1385 for key, job in self._jobs.items(): 1386 if len(job_filter) > 0 and key not in job_filter: 1387 continue 1388 1389 result.append(job.describe()) 1390 1391 return result 1392 1393 def list_jobs(self, job_queue, job_status=None, max_results=None, next_token=None): 1394 jobs = [] 1395 1396 job_queue = self.get_job_queue(job_queue) 1397 if job_queue is None: 1398 raise ClientException("Job queue {0} does not exist".format(job_queue)) 1399 1400 if job_status is not None and job_status not in ( 1401 "SUBMITTED", 1402 "PENDING", 1403 "RUNNABLE", 1404 "STARTING", 1405 "RUNNING", 1406 "SUCCEEDED", 1407 "FAILED", 1408 ): 1409 raise ClientException( 1410 "Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED" 1411 ) 1412 1413 for job in job_queue.jobs: 1414 if job_status is not None and job.job_state != job_status: 1415 continue 1416 1417 jobs.append(job) 1418 1419 return jobs 1420 1421 def cancel_job(self, job_id, reason): 1422 job = self.get_job_by_id(job_id) 1423 if job.job_state in ["SUBMITTED", "PENDING", "RUNNABLE"]: 1424 job.terminate(reason) 1425 # No-Op for jobs that have already started - user has to explicitly terminate those 1426 1427 def terminate_job(self, job_id, reason): 1428 if job_id is None: 1429 raise ClientException("Job ID does not exist") 1430 if reason is None: 1431 raise ClientException("Reason does not exist") 1432 1433 job = self.get_job_by_id(job_id) 1434 if job is None: 1435 raise ClientException("Job not found") 1436 1437 job.terminate(reason) 1438 1439 1440batch_backends = {} 1441for region in Session().get_available_regions("batch"): 1442 batch_backends[region] = BatchBackend(region) 1443for region in Session().get_available_regions("batch", partition_name="aws-us-gov"): 1444 batch_backends[region] = BatchBackend(region) 1445for region in Session().get_available_regions("batch", partition_name="aws-cn"): 1446 batch_backends[region] = BatchBackend(region) 1447