1import datetime
2import boto3
3from botocore.exceptions import ClientError
4import sure  # noqa # pylint: disable=unused-import
5
6from moto import mock_sagemaker
7from moto.sts.models import ACCOUNT_ID
8import pytest
9
# Region used for every SageMaker client created in this module.
TEST_REGION_NAME = "us-east-1"
# Placeholder IAM role ARN passed as ExecutionRoleArn when creating models;
# it is built from moto's ACCOUNT_ID.
FAKE_ROLE_ARN = f"arn:aws:iam::{ACCOUNT_ID}:role/FakeRole"
# Tags attached at creation time and expected back from list_tags.
GENERIC_TAGS_PARAM = [
    dict(Key="newkey1", Value="newval1"),
    dict(Key="newkey2", Value="newval2"),
]
16
17
@mock_sagemaker
def test_create_endpoint_config():
    """Creating an endpoint config is rejected until its model exists, then round-trips."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model = "MyModel"
    config_name = "MyEndpointConfig"
    variants = [
        dict(
            VariantName="MyProductionVariant",
            ModelName=model,
            InitialInstanceCount=1,
            InstanceType="ml.t2.medium",
        )
    ]
    # Both the create and the describe responses must carry an ARN of this shape.
    arn_pattern = f"^arn:aws:sagemaker:.*:.*:endpoint-config/{config_name}$"

    # "MyModel" has not been registered yet, so the config must be refused.
    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint_config(
            EndpointConfigName=config_name, ProductionVariants=variants
        )
    error_message = exc_info.value.response["Error"]["Message"]
    assert error_message.startswith("Could not find model")

    _create_model(client, model)
    created = client.create_endpoint_config(
        EndpointConfigName=config_name, ProductionVariants=variants
    )
    created["EndpointConfigArn"].should.match(arn_pattern)

    described = client.describe_endpoint_config(EndpointConfigName=config_name)
    described["EndpointConfigArn"].should.match(arn_pattern)
    assert described["EndpointConfigName"] == config_name
    described["ProductionVariants"].should.equal(variants)
54
55
@mock_sagemaker
def test_delete_endpoint_config():
    """Deleting an endpoint config removes it; describing or deleting it again fails."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    production_variants = [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": "ml.t2.medium",
        },
    ]
    # Both the create and the describe responses must carry an ARN of this shape.
    arn_pattern = r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(
        endpoint_config_name
    )

    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(arn_pattern)

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(arn_pattern)

    # The delete response body is not inspected, so it is not bound to a name.
    sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

    # After deletion, both describe and a repeat delete must report not-found.
    with pytest.raises(ClientError) as e:
        sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )

    with pytest.raises(ClientError) as e:
        sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )
97
98
@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    """An endpoint config with an unknown InstanceType is rejected with ValidationException."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model = "MyModel"
    _create_model(client, model)

    bad_instance_type = "InvalidInstanceType"
    variant = {
        "VariantName": "MyProductionVariant",
        "ModelName": model,
        "InitialInstanceCount": 1,
        "InstanceType": bad_instance_type,
    }

    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint_config(
            EndpointConfigName="MyEndpointConfig", ProductionVariants=[variant]
        )
    error = exc_info.value.response["Error"]
    assert error["Code"] == "ValidationException"
    # Substring check that stops at the opening "[": the enum list that
    # follows it in the message is deliberately not compared.
    expected_fragment = (
        f"Value '{bad_instance_type}' at 'instanceType' failed to satisfy "
        "constraint: Member must satisfy enum value set: ["
    )
    assert expected_fragment in error["Message"]
127
128
@mock_sagemaker
def test_create_endpoint():
    """An endpoint needs an existing config; once created it describes correctly and keeps its tags."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    endpoint = "MyEndpoint"
    arn_pattern = f"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint}$"

    # No config by this name exists, so creation must be refused.
    with pytest.raises(ClientError) as exc_info:
        client.create_endpoint(
            EndpointName=endpoint, EndpointConfigName="NonexistentEndpointConfig"
        )
    message = exc_info.value.response["Error"]["Message"]
    assert message.startswith("Could not find endpoint configuration")

    model = "MyModel"
    config_name = "MyEndpointConfig"
    _create_model(client, model)
    _create_endpoint_config(client, config_name, model)

    created = client.create_endpoint(
        EndpointName=endpoint,
        EndpointConfigName=config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    created["EndpointArn"].should.match(arn_pattern)

    described = client.describe_endpoint(EndpointName=endpoint)
    described["EndpointArn"].should.match(arn_pattern)
    assert described["EndpointName"] == endpoint
    assert described["EndpointConfigName"] == config_name
    assert described["EndpointStatus"] == "InService"
    for timestamp_key in ("CreationTime", "LastModifiedTime"):
        assert isinstance(described[timestamp_key], datetime.datetime)
    first_variant = described["ProductionVariants"][0]
    assert first_variant["VariantName"] == "MyProductionVariant"

    # Tags supplied at creation time must come back unchanged.
    tags = client.list_tags(ResourceArn=described["EndpointArn"])["Tags"]
    assert tags == GENERIC_TAGS_PARAM
170
171
@mock_sagemaker
def test_delete_endpoint():
    """A deleted endpoint can neither be described nor deleted a second time."""
    client = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model, config_name, endpoint = "MyModel", "MyEndpointConfig", "MyEndpoint"
    _create_model(client, model)
    _create_endpoint_config(client, config_name, model)
    _create_endpoint(client, endpoint, config_name)

    client.delete_endpoint(EndpointName=endpoint)

    # Describe first, then a repeat delete; each must raise a not-found error.
    for follow_up in (client.describe_endpoint, client.delete_endpoint):
        with pytest.raises(ClientError) as exc_info:
            follow_up(EndpointName=endpoint)
        message = exc_info.value.response["Error"]["Message"]
        assert message.startswith("Could not find endpoint")
193
194
def _create_model(boto_client, model_name):
    """Register ``model_name`` with a fixed container and assert the call returned HTTP 200.

    Uses the module-level FAKE_ROLE_ARN as the execution role.
    """
    container = {
        "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
        "ModelDataUrl": "s3://MyBucket/model.tar.gz",
    }
    response = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer=container,
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    status = response["ResponseMetadata"]["HTTPStatusCode"]
    assert status == 200
205
206
def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    """Create a one-variant endpoint config serving ``model_name`` and check the returned ARN.

    The single variant is named "MyProductionVariant" and requests one
    ml.t2.medium instance.
    """
    variant = dict(
        VariantName="MyProductionVariant",
        ModelName=model_name,
        InitialInstanceCount=1,
        InstanceType="ml.t2.medium",
    )
    created = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=[variant]
    )
    expected_arn = f"^arn:aws:sagemaker:.*:.*:endpoint-config/{endpoint_config_name}$"
    created["EndpointConfigArn"].should.match(expected_arn)
222
223
def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    """Create ``endpoint_name`` from an existing endpoint config and check the returned ARN."""
    arn_regex = f"^arn:aws:sagemaker:.*:.*:endpoint/{endpoint_name}$"
    boto_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )["EndpointArn"].should.match(arn_regex)
231