import datetime
import boto3
from botocore.exceptions import ClientError
import sure  # noqa # pylint: disable=unused-import

from moto import mock_sagemaker
from moto.sts.models import ACCOUNT_ID
import pytest

TEST_REGION_NAME = "us-east-1"
FAKE_ROLE_ARN = "arn:aws:iam::{}:role/FakeRole".format(ACCOUNT_ID)
GENERIC_TAGS_PARAM = [
    {"Key": "newkey1", "Value": "newval1"},
    {"Key": "newkey2", "Value": "newval2"},
]


@mock_sagemaker
def test_create_endpoint_config():
    """Endpoint config creation needs an existing model and is describable."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    production_variants = _production_variants(model_name)

    endpoint_config_name = "MyEndpointConfig"
    # The referenced model has not been created yet, so this must fail.
    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Message"].startswith("Could not find model")

    _create_model(sagemaker, model_name)
    resp = sagemaker.create_endpoint_config(
        EndpointConfigName=endpoint_config_name, ProductionVariants=production_variants
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["ProductionVariants"].should.equal(production_variants)


@mock_sagemaker
def test_delete_endpoint_config():
    """A deleted endpoint config can be neither described nor deleted again."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    # The helper creates the config and checks the returned ARN.
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    resp = sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )

    sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    with pytest.raises(ClientError) as e:
        sagemaker.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )

    # Deleting a second time must report the config as missing.
    with pytest.raises(ClientError) as e:
        sagemaker.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )


@mock_sagemaker
def test_create_endpoint_invalid_instance_type():
    """An unknown instance type is rejected with a ValidationException."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    instance_type = "InvalidInstanceType"
    production_variants = _production_variants(model_name, instance_type)

    endpoint_config_name = "MyEndpointConfig"
    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint_config(
            EndpointConfigName=endpoint_config_name,
            ProductionVariants=production_variants,
        )
    assert e.value.response["Error"]["Code"] == "ValidationException"
    # Only the message prefix is checked; the enum listing that follows the
    # opening bracket is not pinned by this test.
    expected_message = "Value '{}' at 'instanceType' failed to satisfy constraint: Member must satisfy enum value set: [".format(
        instance_type
    )
    assert expected_message in e.value.response["Error"]["Message"]


@mock_sagemaker
def test_create_endpoint():
    """Endpoint creation needs an existing config; details and tags round-trip."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    endpoint_name = "MyEndpoint"
    with pytest.raises(ClientError) as e:
        sagemaker.create_endpoint(
            EndpointName=endpoint_name, EndpointConfigName="NonexistentEndpointConfig"
        )
    assert e.value.response["Error"]["Message"].startswith(
        "Could not find endpoint configuration"
    )

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    resp = sagemaker.create_endpoint(
        EndpointName=endpoint_name,
        EndpointConfigName=endpoint_config_name,
        Tags=GENERIC_TAGS_PARAM,
    )
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )

    resp = sagemaker.describe_endpoint(EndpointName=endpoint_name)
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )
    resp["EndpointName"].should.equal(endpoint_name)
    resp["EndpointConfigName"].should.equal(endpoint_config_name)
    resp["EndpointStatus"].should.equal("InService")
    assert isinstance(resp["CreationTime"], datetime.datetime)
    assert isinstance(resp["LastModifiedTime"], datetime.datetime)
    resp["ProductionVariants"][0]["VariantName"].should.equal("MyProductionVariant")

    resp = sagemaker.list_tags(ResourceArn=resp["EndpointArn"])
    assert resp["Tags"] == GENERIC_TAGS_PARAM


@mock_sagemaker
def test_delete_endpoint():
    """A deleted endpoint can be neither described nor deleted again."""
    sagemaker = boto3.client("sagemaker", region_name=TEST_REGION_NAME)

    model_name = "MyModel"
    _create_model(sagemaker, model_name)

    endpoint_config_name = "MyEndpointConfig"
    _create_endpoint_config(sagemaker, endpoint_config_name, model_name)

    endpoint_name = "MyEndpoint"
    _create_endpoint(sagemaker, endpoint_name, endpoint_config_name)

    sagemaker.delete_endpoint(EndpointName=endpoint_name)
    with pytest.raises(ClientError) as e:
        sagemaker.describe_endpoint(EndpointName=endpoint_name)
    assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")

    with pytest.raises(ClientError) as e:
        sagemaker.delete_endpoint(EndpointName=endpoint_name)
    assert e.value.response["Error"]["Message"].startswith("Could not find endpoint")


def _production_variants(model_name, instance_type="ml.t2.medium"):
    """Return the single-variant ProductionVariants list shared by these tests.

    A new list is built on every call so callers never share mutable state.
    """
    return [
        {
            "VariantName": "MyProductionVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        },
    ]


def _create_model(boto_client, model_name):
    """Create a model named ``model_name`` and assert the call returned HTTP 200."""
    resp = boto_client.create_model(
        ModelName=model_name,
        PrimaryContainer={
            "Image": "382416733822.dkr.ecr.us-east-1.amazonaws.com/factorization-machines:1",
            "ModelDataUrl": "s3://MyBucket/model.tar.gz",
        },
        ExecutionRoleArn=FAKE_ROLE_ARN,
    )
    assert resp["ResponseMetadata"]["HTTPStatusCode"] == 200


def _create_endpoint_config(boto_client, endpoint_config_name, model_name):
    """Create an endpoint config for ``model_name`` and check the returned ARN."""
    resp = boto_client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=_production_variants(model_name),
    )
    resp["EndpointConfigArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint-config/{}$".format(endpoint_config_name)
    )


def _create_endpoint(boto_client, endpoint_name, endpoint_config_name):
    """Create an endpoint from an existing config and check the returned ARN."""
    resp = boto_client.create_endpoint(
        EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
    )
    resp["EndpointArn"].should.match(
        r"^arn:aws:sagemaker:.*:.*:endpoint/{}$".format(endpoint_name)
    )