"""EMRContainersBackend class with methods for supported APIs."""
import re
from datetime import datetime
from boto3 import Session

from moto.core import BaseBackend, BaseModel, ACCOUNT_ID
from moto.core.utils import iso_8601_datetime_without_milliseconds

from .utils import random_cluster_id, random_job_id, get_partition, paginated_list
from .exceptions import ResourceNotFoundException

from ..config.exceptions import ValidationException

VIRTUAL_CLUSTER_ARN_TEMPLATE = (
    "arn:{partition}:emr-containers:{region}:"
    + str(ACCOUNT_ID)
    + ":/virtualclusters/{virtual_cluster_id}"
)

JOB_ARN_TEMPLATE = (
    "arn:{partition}:emr-containers:{region}:"
    + str(ACCOUNT_ID)
    + ":/virtualclusters/{virtual_cluster_id}/jobruns/{job_id}"
)

# Initial states assigned to newly created virtual clusters and job runs.
VIRTUAL_CLUSTER_STATUS = "RUNNING"
JOB_STATUS = "RUNNING"

# A job run short id is exactly 19 alphanumeric characters. The pattern is
# applied with ``fullmatch`` so that trailing characters are rejected too.
JOB_ID_PATTERN = re.compile(r"[a-zA-Z0-9]{19}")

# Release labels are matched from the start of the string only (``match``).
RELEASE_LABEL_PATTERN = re.compile(r"emr-[0-9]{1}\.[0-9]{1,2}\.0-(latest|[0-9]{8})")

# Job states from which a cancel request is refused.
NON_CANCELLABLE_JOB_STATES = ("FAILED", "CANCELLED", "CANCEL_PENDING", "COMPLETED")


class FakeCluster(BaseModel):
    """In-memory representation of an EMR-on-EKS virtual cluster."""

    def __init__(
        self,
        name,
        container_provider,
        client_token,
        region_name,
        aws_partition,
        tags=None,
        virtual_cluster_id=None,
    ):
        """Create a virtual cluster in the ``RUNNING`` state.

        ``container_provider`` must be a mapping containing ``"id"`` and
        ``["info"]["eksInfo"]["namespace"]``; a ``KeyError`` is raised
        otherwise. When ``virtual_cluster_id`` is falsy an id is generated
        with ``random_cluster_id()``.
        """
        self.id = virtual_cluster_id or random_cluster_id()

        self.name = name
        self.client_token = client_token
        self.arn = VIRTUAL_CLUSTER_ARN_TEMPLATE.format(
            partition=aws_partition, region=region_name, virtual_cluster_id=self.id
        )
        self.state = VIRTUAL_CLUSTER_STATUS
        self.container_provider = container_provider
        self.container_provider_id = container_provider["id"]
        self.namespace = container_provider["info"]["eksInfo"]["namespace"]
        # The creation timestamp is truncated to midnight of the current day.
        self.creation_date = iso_8601_datetime_without_milliseconds(
            datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
        )
        self.tags = tags

    def __iter__(self):
        """Yield ``(key, value)`` pairs so that ``dict(cluster)`` works.

        Delegates to :meth:`to_dict` so both views always agree.
        """
        yield from self.to_dict().items()

    def to_dict(self):
        """Return the cluster as a dict keyed like the API response."""
        # Format for summary
        # https://docs.aws.amazon.com/emr-on-eks/latest/APIReference/API_DescribeVirtualCluster.html
        # (response syntax section)
        return {
            "id": self.id,
            "name": self.name,
            "arn": self.arn,
            "state": self.state,
            "containerProvider": self.container_provider,
            "createdAt": self.creation_date,
            "tags": self.tags,
        }


class FakeJob(BaseModel):
    """In-memory representation of an EMR-on-EKS job run."""

    def __init__(
        self,
        name,
        virtual_cluster_id,
        client_token,
        execution_role_arn,
        release_label,
        job_driver,
        configuration_overrides,
        region_name,
        aws_partition,
        tags,
    ):
        """Create a job run in the ``RUNNING`` state with a random id."""
        self.id = random_job_id()
        self.name = name
        self.virtual_cluster_id = virtual_cluster_id
        self.arn = JOB_ARN_TEMPLATE.format(
            partition=aws_partition,
            region=region_name,
            virtual_cluster_id=self.virtual_cluster_id,
            job_id=self.id,
        )
        self.state = JOB_STATUS
        self.client_token = client_token
        self.execution_role_arn = execution_role_arn
        self.release_label = release_label
        self.job_driver = job_driver
        self.configuration_overrides = configuration_overrides
        # The creation timestamp is truncated to midnight of the current day.
        self.created_at = iso_8601_datetime_without_milliseconds(
            datetime.today().replace(hour=0, minute=0, second=0, microsecond=0)
        )
        self.created_by = None
        self.finished_at = None
        self.state_details = None
        self.failure_reason = None
        self.tags = tags

    def __iter__(self):
        """Yield ``(key, value)`` pairs so that ``dict(job)`` works.

        Delegates to :meth:`to_dict`; the previous hand-written version
        yielded the release label under ``"configurationOverrides"``.
        """
        yield from self.to_dict().items()

    def to_dict(self):
        """Return the job run as a dict keyed like the API response."""
        # Format for summary
        # https://docs.aws.amazon.com/emr-on-eks/latest/APIReference/API_DescribeJobRun.html
        # (response syntax section)
        return {
            "id": self.id,
            "name": self.name,
            "virtualClusterId": self.virtual_cluster_id,
            "arn": self.arn,
            "state": self.state,
            "clientToken": self.client_token,
            "executionRoleArn": self.execution_role_arn,
            "releaseLabel": self.release_label,
            "configurationOverrides": self.configuration_overrides,
            "jobDriver": self.job_driver,
            "createdAt": self.created_at,
            "createdBy": self.created_by,
            "finishedAt": self.finished_at,
            "stateDetails": self.state_details,
            "failureReason": self.failure_reason,
            "tags": self.tags,
        }


class EMRContainersBackend(BaseBackend):
    """Implementation of EMRContainers APIs."""

    def __init__(self, region_name=None):
        """Initialise empty cluster/job stores for ``region_name``."""
        super().__init__()
        self.virtual_clusters = {}
        self.virtual_cluster_count = 0
        self.jobs = {}
        self.job_count = 0
        self.region_name = region_name
        self.partition = get_partition(region_name)

    def reset(self):
        """Re-initialize all attributes for this instance."""
        region_name = self.region_name
        self.__dict__ = {}
        self.__init__(region_name)

    def create_virtual_cluster(self, name, container_provider, client_token, tags=None):
        """Create and store a virtual cluster, returning the ``FakeCluster``.

        Raises:
            ValidationException: if any stored virtual cluster (regardless of
                its state) already uses the requested EKS namespace.
        """
        requested_namespace = container_provider["info"]["eksInfo"]["namespace"]
        namespace_taken = any(
            cluster.namespace == requested_namespace
            for cluster in self.virtual_clusters.values()
        )
        if namespace_taken:
            raise ValidationException(
                "A virtual cluster already exists in the given namespace"
            )

        virtual_cluster = FakeCluster(
            name=name,
            container_provider=container_provider,
            client_token=client_token,
            tags=tags,
            region_name=self.region_name,
            aws_partition=self.partition,
        )

        self.virtual_clusters[virtual_cluster.id] = virtual_cluster
        self.virtual_cluster_count += 1
        return virtual_cluster

    def delete_virtual_cluster(self, id):
        """Mark a virtual cluster as ``TERMINATED`` and return it.

        The cluster stays in the store; only its state changes.

        Raises:
            ValidationException: if ``id`` is unknown.
        """
        if id not in self.virtual_clusters:
            raise ValidationException("VirtualCluster does not exist")

        virtual_cluster = self.virtual_clusters[id]
        virtual_cluster.state = "TERMINATED"
        return virtual_cluster

    def describe_virtual_cluster(self, id):
        """Return the dict form of the virtual cluster with the given id.

        Raises:
            ValidationException: if ``id`` is unknown.
        """
        if id not in self.virtual_clusters:
            raise ValidationException(f"Virtual cluster {id} doesn't exist.")

        return self.virtual_clusters[id].to_dict()

    def list_virtual_clusters(
        self,
        container_provider_id,
        container_provider_type,
        created_after,
        created_before,
        states,
        max_results,
        next_token,
    ):
        """List virtual clusters (as dicts), applying every truthy filter.

        The filtered list is handed to ``paginated_list`` together with the
        sort key ``"name"``, ``max_results`` and ``next_token``; its result
        is returned unchanged.
        """
        virtual_clusters = [
            virtual_cluster.to_dict()
            for virtual_cluster in self.virtual_clusters.values()
        ]

        if container_provider_id:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["containerProvider"]["id"] == container_provider_id
            ]

        if container_provider_type:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["containerProvider"]["type"]
                == container_provider_type
            ]

        if created_after:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["createdAt"] >= created_after
            ]

        if created_before:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["createdAt"] <= created_before
            ]

        if states:
            virtual_clusters = [
                virtual_cluster
                for virtual_cluster in virtual_clusters
                if virtual_cluster["state"] in states
            ]

        sort_key = "name"
        return paginated_list(virtual_clusters, sort_key, max_results, next_token)

    def start_job_run(
        self,
        name,
        virtual_cluster_id,
        client_token,
        execution_role_arn,
        release_label,
        job_driver,
        configuration_overrides,
        tags,
    ):
        """Create and store a job run, returning the ``FakeJob``.

        Raises:
            ResourceNotFoundException: if the virtual cluster is unknown or
                ``release_label`` does not match the expected
                ``emr-X.Y.0-(latest|YYYYMMDD)`` shape.
        """
        if virtual_cluster_id not in self.virtual_clusters:
            raise ResourceNotFoundException(
                f"Virtual cluster {virtual_cluster_id} doesn't exist."
            )

        if not RELEASE_LABEL_PATTERN.match(release_label):
            raise ResourceNotFoundException(f"Release {release_label} doesn't exist.")

        job = FakeJob(
            name=name,
            virtual_cluster_id=virtual_cluster_id,
            client_token=client_token,
            execution_role_arn=execution_role_arn,
            release_label=release_label,
            job_driver=job_driver,
            configuration_overrides=configuration_overrides,
            tags=tags,
            region_name=self.region_name,
            aws_partition=self.partition,
        )

        self.jobs[job.id] = job
        self.job_count += 1
        return job

    def _get_job(self, id, virtual_cluster_id):
        """Validate ``id`` and return the job owned by ``virtual_cluster_id``.

        Raises:
            ValidationException: if ``id`` is not exactly 19 alphanumeric
                characters.
            ResourceNotFoundException: if no such job exists, or it belongs
                to a different virtual cluster.
        """
        if not JOB_ID_PATTERN.fullmatch(id):
            raise ValidationException("Invalid job run short id")

        job = self.jobs.get(id)
        # A job under another virtual cluster is reported as non-existent.
        if job is None or job.virtual_cluster_id != virtual_cluster_id:
            raise ResourceNotFoundException(f"Job run {id} doesn't exist.")

        return job

    def cancel_job_run(self, id, virtual_cluster_id):
        """Cancel a job run and return the updated ``FakeJob``.

        Raises:
            ValidationException: if ``id`` is malformed or the job is already
                in one of ``NON_CANCELLABLE_JOB_STATES``.
            ResourceNotFoundException: if the job does not exist under
                ``virtual_cluster_id``.
        """
        job = self._get_job(id, virtual_cluster_id)

        if job.state in NON_CANCELLABLE_JOB_STATES:
            raise ValidationException(f"Job run {id} is not in a cancellable state")

        job.state = "CANCELLED"
        # The finish timestamp is fixed at one minute past midnight today.
        job.finished_at = iso_8601_datetime_without_milliseconds(
            datetime.today().replace(hour=0, minute=1, second=0, microsecond=0)
        )
        job.state_details = "JobRun CANCELLED successfully."

        return job

    def list_job_runs(
        self,
        virtual_cluster_id,
        created_before,
        created_after,
        name,
        states,
        max_results,
        next_token,
    ):
        """List job runs (as dicts) of one virtual cluster.

        Every truthy filter is applied. A string ``name`` must equal the job
        name exactly; a non-string ``name`` is treated as a collection of
        accepted names. The filtered list is handed to ``paginated_list``
        with the sort key ``"id"`` and its result is returned unchanged.
        """
        jobs = [
            job.to_dict()
            for job in self.jobs.values()
            if job.virtual_cluster_id == virtual_cluster_id
        ]

        if created_after:
            jobs = [job for job in jobs if job["createdAt"] >= created_after]

        if created_before:
            jobs = [job for job in jobs if job["createdAt"] <= created_before]

        if states:
            jobs = [job for job in jobs if job["state"] in states]

        if name:
            if isinstance(name, str):
                # ``in`` on a string is a substring test, which would let a
                # job called "test" match the filter "test_job".
                jobs = [job for job in jobs if job["name"] == name]
            else:
                jobs = [job for job in jobs if job["name"] in name]

        sort_key = "id"
        return paginated_list(jobs, sort_key, max_results, next_token)

    def describe_job_run(self, id, virtual_cluster_id):
        """Return the dict form of a job run.

        Raises:
            ValidationException: if ``id`` is malformed.
            ResourceNotFoundException: if the job does not exist under
                ``virtual_cluster_id``.
        """
        return self._get_job(id, virtual_cluster_id).to_dict()


# One backend per region, for each partition in turn.
emrcontainers_backends = {}
_session = Session()
for _partition_name in ("aws", "aws-us-gov", "aws-cn"):
    for available_region in _session.get_available_regions(
        "emr-containers", partition_name=_partition_name
    ):
        emrcontainers_backends[available_region] = EMRContainersBackend(
            available_region
        )