1import datetime
2from boto3 import Session
3
4from collections import OrderedDict
5from moto.core import BaseBackend, BaseModel, CloudFormationModel
6from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
7
8
class PipelineObject(BaseModel):
    """One object of a pipeline definition: its id, display name and field list."""

    def __init__(self, object_id, name, fields):
        # Stored verbatim; ``fields`` is whatever the definition supplied.
        self.object_id, self.name, self.fields = object_id, name, fields

    def to_json(self):
        """Return the object as a response dict with ``fields``, ``id`` and ``name``."""
        payload = dict(fields=self.fields, id=self.object_id, name=self.name)
        return payload
17
18
class Pipeline(CloudFormationModel):
    """In-memory model of a single Data Pipeline and its definition objects."""

    def __init__(self, name, unique_id, **kwargs):
        """Create a pipeline in the ``PENDING`` state.

        :param name: pipeline name.
        :param unique_id: caller-supplied uniqueness token.
        :param kwargs: optional ``description`` (default ``""``) and
            ``tags`` (default a fresh empty list).
        """
        self.name = name
        self.unique_id = unique_id
        self.description = kwargs.get("description", "")
        self.pipeline_id = get_random_pipeline_id()
        self.creation_time = datetime.datetime.utcnow()
        self.objects = []  # list of PipelineObject, set by set_pipeline_objects
        self.status = "PENDING"
        self.tags = kwargs.get("tags", [])

    @property
    def physical_resource_id(self):
        """The pipeline id doubles as the CloudFormation physical resource id."""
        return self.pipeline_id

    def to_meta_json(self):
        """Return the short ``{"id", "name"}`` form used when listing pipelines."""
        return {"id": self.pipeline_id, "name": self.name}

    def to_json(self):
        """Return the full pipeline description dict, including its field list."""
        return {
            "description": self.description,
            "fields": [
                {"key": "@pipelineState", "stringValue": self.status},
                {"key": "description", "stringValue": self.description},
                {"key": "name", "stringValue": self.name},
                {
                    "key": "@creationTime",
                    # NOTE(review): the time part uses '-' separators rather
                    # than ISO ':' ones; confirm against the real API before
                    # changing, as existing callers may rely on this format.
                    "stringValue": datetime.datetime.strftime(
                        self.creation_time, "%Y-%m-%dT%H-%M-%S"
                    ),
                },
                {"key": "@id", "stringValue": self.pipeline_id},
                {"key": "@sphere", "stringValue": "PIPELINE"},
                {"key": "@version", "stringValue": "1"},
                # Hard-coded placeholder account/user id.
                {"key": "@userId", "stringValue": "924374875933"},
                {"key": "@accountId", "stringValue": "924374875933"},
                {"key": "uniqueId", "stringValue": self.unique_id},
            ],
            "name": self.name,
            "pipelineId": self.pipeline_id,
            "tags": self.tags,
        }

    def set_pipeline_objects(self, pipeline_objects):
        """Replace the pipeline definition with ``pipeline_objects``.

        Each entry must provide ``id``, ``name`` and ``fields`` after its
        keys are normalised by ``remove_capitalization_of_dict_keys``.
        """
        self.objects = [
            PipelineObject(
                pipeline_object["id"],
                pipeline_object["name"],
                pipeline_object["fields"],
            )
            for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
        ]

    def activate(self):
        """Mark the pipeline as ``SCHEDULED``."""
        self.status = "SCHEDULED"

    @staticmethod
    def cloudformation_name_type():
        return "Name"

    @staticmethod
    def cloudformation_type():
        # https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-datapipeline-pipeline.html
        return "AWS::DataPipeline::Pipeline"

    @classmethod
    def create_from_cloudformation_json(
        cls, resource_name, cloudformation_json, region_name, **kwargs
    ):
        """Create (and optionally activate) a pipeline from a CloudFormation resource.

        ``PipelineObjects``, ``Description`` and ``Activate`` are optional
        properties; missing ones are treated as empty / false instead of
        raising ``KeyError``.
        """
        datapipeline_backend = datapipeline_backends[region_name]
        properties = cloudformation_json["Properties"]

        cloudformation_unique_id = "cf-" + resource_name
        pipeline = datapipeline_backend.create_pipeline(
            resource_name,
            cloudformation_unique_id,
            description=properties.get("Description", ""),
        )

        pipeline_objects = properties.get("PipelineObjects")
        if pipeline_objects:
            datapipeline_backend.put_pipeline_definition(
                pipeline.pipeline_id, pipeline_objects
            )

        activate = properties.get("Activate", False)
        # Templates often carry booleans as strings; "false" must not activate.
        if isinstance(activate, str):
            activate = activate.strip().lower() == "true"
        if activate:
            pipeline.activate()
        return pipeline
102
103
class DataPipelineBackend(BaseBackend):
    """In-memory store of pipelines, keyed by pipeline id in creation order."""

    def __init__(self, region_name=None):
        # region_name is optional so both DataPipelineBackend() and
        # DataPipelineBackend(region) construct a backend.
        self.region_name = region_name
        self.pipelines = OrderedDict()

    def create_pipeline(self, name, unique_id, **kwargs):
        """Create, register and return a new Pipeline (kwargs go to Pipeline)."""
        pipeline = Pipeline(name, unique_id, **kwargs)
        self.pipelines[pipeline.pipeline_id] = pipeline
        return pipeline

    def list_pipelines(self):
        """Return a view of all registered pipelines in creation order."""
        return self.pipelines.values()

    def describe_pipelines(self, pipeline_ids):
        """Return the pipelines whose ids are in ``pipeline_ids``, in creation order.

        Unknown ids are silently ignored.
        """
        wanted = set(pipeline_ids)  # O(1) membership instead of a list scan
        return [
            pipeline
            for pipeline in self.pipelines.values()
            if pipeline.pipeline_id in wanted
        ]

    def get_pipeline(self, pipeline_id):
        """Return the pipeline with ``pipeline_id``; raises KeyError if unknown."""
        return self.pipelines[pipeline_id]

    def delete_pipeline(self, pipeline_id):
        """Remove the pipeline if present; unknown ids are a no-op."""
        self.pipelines.pop(pipeline_id, None)

    def put_pipeline_definition(self, pipeline_id, pipeline_objects):
        """Replace the definition objects of an existing pipeline."""
        pipeline = self.get_pipeline(pipeline_id)
        pipeline.set_pipeline_objects(pipeline_objects)

    def get_pipeline_definition(self, pipeline_id):
        """Return the list of PipelineObject currently defined on the pipeline."""
        pipeline = self.get_pipeline(pipeline_id)
        return pipeline.objects

    def describe_objects(self, object_ids, pipeline_id):
        """Return the pipeline's objects whose ids are in ``object_ids``, in definition order."""
        pipeline = self.get_pipeline(pipeline_id)
        wanted = set(object_ids)
        return [
            pipeline_object
            for pipeline_object in pipeline.objects
            if pipeline_object.object_id in wanted
        ]

    def activate_pipeline(self, pipeline_id):
        """Activate an existing pipeline (status becomes SCHEDULED)."""
        pipeline = self.get_pipeline(pipeline_id)
        pipeline.activate()
150
151
# One backend per region, across every partition in which the service is
# available. The constructor is called the same way for every partition; the
# aws-cn loop previously passed ``region`` to a no-argument constructor.
datapipeline_backends = {}
_session = Session()
for _partition in ("aws", "aws-us-gov", "aws-cn"):
    for region in _session.get_available_regions(
        "datapipeline", partition_name=_partition
    ):
        datapipeline_backends[region] = DataPipelineBackend()
161