1# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13import boto.vendored.regions.regions as _regions
14
15
class _CompatEndpointResolver(_regions.EndpointResolver):
    """Endpoint resolver which handles boto2 compatibility concerns.

    The base resolver works in terms of endpoint prefixes, while boto2
    identifies services by its own names. This subclass translates boto2
    service names to endpoint prefixes before delegating to the base
    resolver, and translates endpoint prefixes back to boto2 names when
    listing services.

    This is NOT intended for external use whatsoever.
    """

    _DEFAULT_SERVICE_RENAMES = {
        # The botocore resolver is based on endpoint prefix.
        # These don't always sync up to the name that boto2 uses.
        # A mapping can be provided that handles the mapping between
        # "service names" and endpoint prefixes.
        'awslambda': 'lambda',
        'cloudwatch': 'monitoring',
        'ses': 'email',
        'ec2containerservice': 'ecs',
        'configservice': 'config',
    }

    def __init__(self, endpoint_data, service_rename_map=None):
        """
        :type endpoint_data: dict
        :param endpoint_data: Regions and endpoints data in the same format
            as is used by botocore / boto3.

        :type service_rename_map: dict
        :param service_rename_map: A mapping of boto2 service name to
            endpoint prefix. Defaults to ``_DEFAULT_SERVICE_RENAMES``. The
            mapping is copied, so later changes to the passed-in dict do not
            affect this resolver.
        """
        super(_CompatEndpointResolver, self).__init__(endpoint_data)
        if service_rename_map is None:
            service_rename_map = self._DEFAULT_SERVICE_RENAMES
        # Mapping of boto2 service name to endpoint prefix. Copied so that
        # neither the shared class-level default nor the caller's dict is
        # aliased, and the forward and reverse maps below stay in sync.
        self._endpoint_prefix_map = dict(service_rename_map)
        # Mapping of endpoint prefix to boto2 service name.
        self._service_name_map = {
            v: k for k, v in self._endpoint_prefix_map.items()}

    def get_available_endpoints(self, service_name, partition_name='aws',
                                allow_non_regional=False):
        """List the endpoint names of a service within one partition.

        :type service_name: str
        :param service_name: The boto2 name of the service. It is translated
            to an endpoint prefix before delegating to the base resolver.

        :type partition_name: str
        :param partition_name: The partition to look in.

        :type allow_non_regional: bool
        :param allow_non_regional: Forwarded unchanged to the base resolver.

        :returns: Whatever the base resolver's ``get_available_endpoints``
            returns for the translated endpoint prefix.
        """
        endpoint_prefix = self._endpoint_prefix(service_name)
        return super(_CompatEndpointResolver, self).get_available_endpoints(
            endpoint_prefix, partition_name, allow_non_regional)

    def get_all_available_regions(self, service_name):
        """Retrieve every region across partitions for a service.

        :type service_name: str
        :param service_name: The boto2 name of the service.

        :rtype: list of str
        :returns: Region names across all partitions, in no particular order.
        """
        regions = set()

        for partition_name in self.get_available_partitions():
            if self._is_global_service(service_name, partition_name):
                # Global services are available in every region in the
                # partition in which they are considered global.
                partition = self._get_partition_data(partition_name)
                regions.update(partition['regions'].keys())
            else:
                # Pass the boto2 name, not the endpoint prefix:
                # get_available_endpoints performs the translation itself,
                # and translating twice goes wrong whenever an endpoint
                # prefix also appears as a key in the rename map.
                regions.update(
                    self.get_available_endpoints(service_name, partition_name))

        return list(regions)

    def construct_endpoint(self, service_name, region_name=None):
        """Resolve endpoint data for a service in a region.

        :type service_name: str
        :param service_name: The boto2 name of the service. It is translated
            to an endpoint prefix before delegating to the base resolver.

        :type region_name: str
        :param region_name: The region to resolve, forwarded unchanged.

        :returns: Whatever the base resolver's ``construct_endpoint`` returns
            for the translated endpoint prefix.
        """
        endpoint_prefix = self._endpoint_prefix(service_name)
        return super(_CompatEndpointResolver, self).construct_endpoint(
            endpoint_prefix, region_name)

    def get_available_services(self):
        """Get a list of all the available services in the endpoints file(s).

        Services from every partition are collected, de-duplicated and then
        translated from endpoint prefix to boto2 service name.

        :rtype: list of str
        :returns: boto2 service names, in no particular order.
        """
        services = set()

        for partition in self._endpoint_data['partitions']:
            services.update(partition['services'].keys())

        return [self._service_name(s) for s in services]

    def _is_global_service(self, service_name, partition_name='aws'):
        """Determines whether a service uses a global endpoint.

        A service counts as global in a partition when its service data in
        that partition has a ``partitionEndpoint`` key.

        In theory a service can be 'global' in one partition but regional in
        another. In practice, each service is all global or all regional.

        :raises ValueError: If ``partition_name`` is not in the endpoint data.
        """
        endpoint_prefix = self._endpoint_prefix(service_name)
        partition = self._get_partition_data(partition_name)
        service = partition['services'].get(endpoint_prefix, {})
        return 'partitionEndpoint' in service

    def _get_partition_data(self, partition_name):
        """Get partition information for a particular partition.

        This should NOT be used to get service endpoint data because it only
        loads from the new endpoint format. It should only be used for
        partition metadata and partition specific service metadata.

        :type partition_name: str
        :param partition_name: The name of the partition to search for.

        :returns: Partition info from the new endpoints format.
        :rtype: dict

        :raises ValueError: If no partition with that name exists.
        """
        for partition in self._endpoint_data['partitions']:
            if partition['partition'] == partition_name:
                return partition
        raise ValueError(
            "Could not find partition data for: %s" % partition_name)

    def _endpoint_prefix(self, service_name):
        """Given a boto2 service name, get the endpoint prefix.

        Names without an entry in the rename map are returned unchanged.
        """
        return self._endpoint_prefix_map.get(service_name, service_name)

    def _service_name(self, endpoint_prefix):
        """Given an endpoint prefix, get the boto2 service name.

        Prefixes without an entry in the reverse map are returned unchanged.
        """
        return self._service_name_map.get(endpoint_prefix, endpoint_prefix)
131
132
class BotoEndpointResolver(object):
    """Resolve AWS service hostnames using boto2 service names.

    A thin facade over ``_CompatEndpointResolver``; every public method
    delegates to it. This is NOT intended for external use.
    """

    def __init__(self, endpoint_data, service_rename_map=None):
        """
        :type endpoint_data: dict
        :param endpoint_data: Regions and endpoints data, laid out the same
            way botocore / boto3 lay it out.

        :type service_rename_map: dict
        :param service_rename_map: Optional mapping from boto2 service name
            to endpoint prefix, handed straight to the compat resolver.
        """
        compat_resolver = _CompatEndpointResolver(
            endpoint_data, service_rename_map)
        self._resolver = compat_resolver

    def resolve_hostname(self, service_name, region_name):
        """Look up the hostname of a service in a given region.

        The endpoint's ``sslCommonName`` is preferred when present; otherwise
        its ``hostname`` is used.

        :type service_name: str
        :param service_name: The boto2 name of the service.

        :type region_name: str
        :param region_name: The region whose endpoint is wanted.

        :return: The hostname, or ``None`` when no endpoint could be
            constructed for the service/region pair.
        """
        resolved = self._resolver.construct_endpoint(service_name, region_name)
        if resolved is not None:
            # 'hostname' is looked up first, so it must exist even when an
            # 'sslCommonName' overrides it.
            plain_hostname = resolved['hostname']
            return resolved.get('sslCommonName', plain_hostname)
        return None

    def get_all_available_regions(self, service_name):
        """List every region in which a service is offered.

        :type service_name: str
        :param service_name: The boto2 name of the service.

        :rtype: list of str
        :return: The region names reported by the underlying resolver.
        """
        compat_resolver = self._resolver
        return compat_resolver.get_all_available_regions(service_name)

    def get_available_services(self):
        """List every service present in the endpoint data.

        :rtype: list of str
        :return: The services explicitly contained within the endpoint data
            supplied at construction time.
        """
        compat_resolver = self._resolver
        return compat_resolver.get_available_services()
187
188
class StaticEndpointBuilder(object):
    """Builds a static mapping of endpoints in the legacy format."""

    def __init__(self, resolver):
        """
        :type resolver: BotoEndpointResolver
        :param resolver: An endpoint resolver. Only its
            ``get_available_services``, ``get_all_available_regions`` and
            ``resolve_hostname`` methods are used.
        """
        self._resolver = resolver

    def build_static_endpoints(self, service_names=None):
        """Build a set of static endpoints in the legacy boto2 format.

        :param service_names: The names of the services to build. They must
            use the names that boto2 uses, not boto3, e.g.
            "ec2containerservice" and not "ecs". If no service names are
            provided, all available services will be built.

        :return: A dict consisting of::

                {"service": {"region": "full.host.name"}}

            Services for which no endpoints could be built are omitted. If
            "cloudsearch" is present, a "cloudsearchdomain" entry holding a
            copy of its endpoints is added as well (replacing any entry
            already built for "cloudsearchdomain").
        """
        if service_names is None:
            service_names = self._resolver.get_available_services()

        static_endpoints = {}
        for name in service_names:
            endpoints_for_service = self._build_endpoints_for_service(name)
            # It's possible that when we try to build endpoints for a
            # service we get an empty dict. In that case we don't bother
            # adding it to the final set of static endpoints.
            if endpoints_for_service:
                static_endpoints[name] = endpoints_for_service
        self._handle_special_cases(static_endpoints)
        return static_endpoints

    def _build_endpoints_for_service(self, service_name):
        """Map each available region of a service to its hostname.

        Given a service name such as 'ec2', build a dict of
        'region' -> hostname, where the hostname is whatever the resolver's
        ``resolve_hostname`` returns for that region.
        """
        resolver = self._resolver
        return {
            region_name: resolver.resolve_hostname(service_name, region_name)
            for region_name in resolver.get_all_available_regions(service_name)
        }

    def _handle_special_cases(self, static_endpoints):
        """Apply service-specific adjustments to ``static_endpoints`` in place.

        cloudsearchdomain endpoints use the exact same set of endpoints as
        cloudsearch.
        """
        if 'cloudsearch' in static_endpoints:
            # Store a copy rather than the same dict object, so mutating one
            # service's mapping does not silently change the other's.
            static_endpoints['cloudsearchdomain'] = dict(
                static_endpoints['cloudsearch'])
240