1# Copyright 2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7# http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13"""Internal module to help with normalizing botocore client args.
14
This module (and all functions/classes within this module) should be
considered internal, and *not* a public API.
17
18"""
19import copy
20import logging
21import socket
22
23import botocore.exceptions
24import botocore.serialize
25import botocore.utils
26from botocore.signers import RequestSigner
27from botocore.config import Config
28from botocore.endpoint import EndpointCreator
29
30
logger = logging.getLogger(__name__)


# Accepted values for the regional endpoint settings checked in this module:
# the s3 'us_east_1_regional_endpoint' option and the
# 'sts_regional_endpoints' config variable.
VALID_REGIONAL_ENDPOINTS_CONFIG = ['legacy', 'regional']
# Regions for which an STS client is pointed at the global endpoint
# (sts.amazonaws.com, signed for us-east-1) when 'sts_regional_endpoints'
# is 'legacy' and no explicit endpoint URL was supplied.
LEGACY_GLOBAL_STS_REGIONS = [
    'ap-northeast-1', 'ap-south-1', 'ap-southeast-1', 'ap-southeast-2',
    'aws-global',
    'ca-central-1',
    'eu-central-1', 'eu-north-1', 'eu-west-1', 'eu-west-2', 'eu-west-3',
    'sa-east-1',
    'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2',
]
56
57
58class ClientArgsCreator(object):
59    def __init__(self, event_emitter, user_agent, response_parser_factory,
60                 loader, exceptions_factory, config_store):
61        self._event_emitter = event_emitter
62        self._user_agent = user_agent
63        self._response_parser_factory = response_parser_factory
64        self._loader = loader
65        self._exceptions_factory = exceptions_factory
66        self._config_store = config_store
67
68    def get_client_args(self, service_model, region_name, is_secure,
69                        endpoint_url, verify, credentials, scoped_config,
70                        client_config, endpoint_bridge):
71        final_args = self.compute_client_args(
72            service_model, client_config, endpoint_bridge, region_name,
73            endpoint_url, is_secure, scoped_config)
74
75        service_name = final_args['service_name'] # noqa
76        parameter_validation = final_args['parameter_validation']
77        endpoint_config = final_args['endpoint_config']
78        protocol = final_args['protocol']
79        config_kwargs = final_args['config_kwargs']
80        s3_config = final_args['s3_config']
81        partition = endpoint_config['metadata'].get('partition', None)
82        socket_options = final_args['socket_options']
83
84        signing_region = endpoint_config['signing_region']
85        endpoint_region_name = endpoint_config['region_name']
86
87        event_emitter = copy.copy(self._event_emitter)
88        signer = RequestSigner(
89            service_model.service_id, signing_region,
90            endpoint_config['signing_name'],
91            endpoint_config['signature_version'],
92            credentials, event_emitter
93        )
94
95        config_kwargs['s3'] = s3_config
96        new_config = Config(**config_kwargs)
97        endpoint_creator = EndpointCreator(event_emitter)
98
99        endpoint = endpoint_creator.create_endpoint(
100            service_model, region_name=endpoint_region_name,
101            endpoint_url=endpoint_config['endpoint_url'], verify=verify,
102            response_parser_factory=self._response_parser_factory,
103            max_pool_connections=new_config.max_pool_connections,
104            proxies=new_config.proxies,
105            timeout=(new_config.connect_timeout, new_config.read_timeout),
106            socket_options=socket_options,
107            client_cert=new_config.client_cert,
108            proxies_config=new_config.proxies_config)
109
110        serializer = botocore.serialize.create_serializer(
111            protocol, parameter_validation)
112        response_parser = botocore.parsers.create_parser(protocol)
113        return {
114            'serializer': serializer,
115            'endpoint': endpoint,
116            'response_parser': response_parser,
117            'event_emitter': event_emitter,
118            'request_signer': signer,
119            'service_model': service_model,
120            'loader': self._loader,
121            'client_config': new_config,
122            'partition': partition,
123            'exceptions_factory': self._exceptions_factory
124        }
125
126    def compute_client_args(self, service_model, client_config,
127                            endpoint_bridge, region_name, endpoint_url,
128                            is_secure, scoped_config):
129        service_name = service_model.endpoint_prefix
130        protocol = service_model.metadata['protocol']
131        parameter_validation = True
132        if client_config and not client_config.parameter_validation:
133            parameter_validation = False
134        elif scoped_config:
135            raw_value = scoped_config.get('parameter_validation')
136            if raw_value is not None:
137                parameter_validation = botocore.utils.ensure_boolean(raw_value)
138
139        # Override the user agent if specified in the client config.
140        user_agent = self._user_agent
141        if client_config is not None:
142            if client_config.user_agent is not None:
143                user_agent = client_config.user_agent
144            if client_config.user_agent_extra is not None:
145                user_agent += ' %s' % client_config.user_agent_extra
146
147        s3_config = self.compute_s3_config(client_config)
148        endpoint_config = self._compute_endpoint_config(
149            service_name=service_name,
150            region_name=region_name,
151            endpoint_url=endpoint_url,
152            is_secure=is_secure,
153            endpoint_bridge=endpoint_bridge,
154            s3_config=s3_config,
155        )
156        # Create a new client config to be passed to the client based
157        # on the final values. We do not want the user to be able
158        # to try to modify an existing client with a client config.
159        config_kwargs = dict(
160            region_name=endpoint_config['region_name'],
161            signature_version=endpoint_config['signature_version'],
162            user_agent=user_agent)
163        if client_config is not None:
164            config_kwargs.update(
165                connect_timeout=client_config.connect_timeout,
166                read_timeout=client_config.read_timeout,
167                max_pool_connections=client_config.max_pool_connections,
168                proxies=client_config.proxies,
169                proxies_config=client_config.proxies_config,
170                retries=client_config.retries,
171                client_cert=client_config.client_cert,
172                inject_host_prefix=client_config.inject_host_prefix,
173            )
174        self._compute_retry_config(config_kwargs)
175        s3_config = self.compute_s3_config(client_config)
176        return {
177            'service_name': service_name,
178            'parameter_validation': parameter_validation,
179            'user_agent': user_agent,
180            'endpoint_config': endpoint_config,
181            'protocol': protocol,
182            'config_kwargs': config_kwargs,
183            's3_config': s3_config,
184            'socket_options': self._compute_socket_options(scoped_config)
185        }
186
187    def compute_s3_config(self, client_config):
188        s3_configuration = self._config_store.get_config_variable('s3')
189
190        # Next specific client config values takes precedence over
191        # specific values in the scoped config.
192        if client_config is not None:
193            if client_config.s3 is not None:
194                if s3_configuration is None:
195                    s3_configuration = client_config.s3
196                else:
197                    # The current s3_configuration dictionary may be
198                    # from a source that only should be read from so
199                    # we want to be safe and just make a copy of it to modify
200                    # before it actually gets updated.
201                    s3_configuration = s3_configuration.copy()
202                    s3_configuration.update(client_config.s3)
203
204        return s3_configuration
205
206    def _compute_endpoint_config(self, service_name, region_name, endpoint_url,
207                                 is_secure, endpoint_bridge, s3_config):
208        resolve_endpoint_kwargs = {
209            'service_name': service_name,
210            'region_name': region_name,
211            'endpoint_url': endpoint_url,
212            'is_secure': is_secure,
213            'endpoint_bridge': endpoint_bridge,
214        }
215        if service_name == 's3':
216            return self._compute_s3_endpoint_config(
217                s3_config=s3_config, **resolve_endpoint_kwargs)
218        if service_name == 'sts':
219            return self._compute_sts_endpoint_config(**resolve_endpoint_kwargs)
220        return self._resolve_endpoint(**resolve_endpoint_kwargs)
221
222    def _compute_s3_endpoint_config(self, s3_config,
223                                    **resolve_endpoint_kwargs):
224        force_s3_global = self._should_force_s3_global(
225            resolve_endpoint_kwargs['region_name'], s3_config)
226        if force_s3_global:
227            resolve_endpoint_kwargs['region_name'] = None
228        endpoint_config = self._resolve_endpoint(**resolve_endpoint_kwargs)
229        self._set_region_if_custom_s3_endpoint(
230            endpoint_config, resolve_endpoint_kwargs['endpoint_bridge'])
231        # For backwards compatibility reasons, we want to make sure the
232        # client.meta.region_name will remain us-east-1 if we forced the
233        # endpoint to be the global region. Specifically, if this value
234        # changes to aws-global, it breaks logic where a user is checking
235        # for us-east-1 as the global endpoint such as in creating buckets.
236        if force_s3_global and endpoint_config['region_name'] == 'aws-global':
237            endpoint_config['region_name'] = 'us-east-1'
238        return endpoint_config
239
240    def _should_force_s3_global(self, region_name, s3_config):
241        s3_regional_config = 'legacy'
242        if s3_config and 'us_east_1_regional_endpoint' in s3_config:
243            s3_regional_config = s3_config['us_east_1_regional_endpoint']
244            self._validate_s3_regional_config(s3_regional_config)
245        return (
246            s3_regional_config == 'legacy' and
247            region_name in ['us-east-1', None]
248        )
249
250    def _validate_s3_regional_config(self, config_val):
251        if config_val not in VALID_REGIONAL_ENDPOINTS_CONFIG:
252            raise botocore.exceptions.\
253                InvalidS3UsEast1RegionalEndpointConfigError(
254                    s3_us_east_1_regional_endpoint_config=config_val)
255
256    def _set_region_if_custom_s3_endpoint(self, endpoint_config,
257                                          endpoint_bridge):
258        # If a user is providing a custom URL, the endpoint resolver will
259        # refuse to infer a signing region. If we want to default to s3v4,
260        # we have to account for this.
261        if endpoint_config['signing_region'] is None \
262                and endpoint_config['region_name'] is None:
263            endpoint = endpoint_bridge.resolve('s3')
264            endpoint_config['signing_region'] = endpoint['signing_region']
265            endpoint_config['region_name'] = endpoint['region_name']
266
267    def _compute_sts_endpoint_config(self, **resolve_endpoint_kwargs):
268        endpoint_config = self._resolve_endpoint(**resolve_endpoint_kwargs)
269        if self._should_set_global_sts_endpoint(
270                resolve_endpoint_kwargs['region_name'],
271                resolve_endpoint_kwargs['endpoint_url']):
272            self._set_global_sts_endpoint(
273                endpoint_config, resolve_endpoint_kwargs['is_secure'])
274        return endpoint_config
275
276    def _should_set_global_sts_endpoint(self, region_name, endpoint_url):
277        if endpoint_url:
278            return False
279        return (
280            self._get_sts_regional_endpoints_config() == 'legacy' and
281            region_name in LEGACY_GLOBAL_STS_REGIONS
282        )
283
284    def _get_sts_regional_endpoints_config(self):
285        sts_regional_endpoints_config = self._config_store.get_config_variable(
286            'sts_regional_endpoints')
287        if not sts_regional_endpoints_config:
288            sts_regional_endpoints_config = 'legacy'
289        if sts_regional_endpoints_config not in \
290                VALID_REGIONAL_ENDPOINTS_CONFIG:
291            raise botocore.exceptions.InvalidSTSRegionalEndpointsConfigError(
292                sts_regional_endpoints_config=sts_regional_endpoints_config)
293        return sts_regional_endpoints_config
294
295    def _set_global_sts_endpoint(self, endpoint_config, is_secure):
296        scheme = 'https' if is_secure else 'http'
297        endpoint_config['endpoint_url'] = '%s://sts.amazonaws.com' % scheme
298        endpoint_config['signing_region'] = 'us-east-1'
299
300    def _resolve_endpoint(self, service_name, region_name,
301                          endpoint_url, is_secure, endpoint_bridge):
302        return endpoint_bridge.resolve(
303            service_name, region_name, endpoint_url, is_secure)
304
305    def _compute_socket_options(self, scoped_config):
306        # This disables Nagle's algorithm and is the default socket options
307        # in urllib3.
308        socket_options = [(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)]
309        if scoped_config:
310            # Enables TCP Keepalive if specified in shared config file.
311            if self._ensure_boolean(scoped_config.get('tcp_keepalive', False)):
312                socket_options.append(
313                    (socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1))
314        return socket_options
315
316    def _compute_retry_config(self, config_kwargs):
317        self._compute_retry_max_attempts(config_kwargs)
318        self._compute_retry_mode(config_kwargs)
319
320    def _compute_retry_max_attempts(self, config_kwargs):
321        # There's a pre-existing max_attempts client config value that actually
322        # means max *retry* attempts.  There's also a `max_attempts` we pull
323        # from the config store that means *total attempts*, which includes the
324        # intitial request.  We can't change what `max_attempts` means in
325        # client config so we try to normalize everything to a new
326        # "total_max_attempts" variable.  We ensure that after this, the only
327        # configuration for "max attempts" is the 'total_max_attempts' key.
328        # An explicitly provided max_attempts in the client config
329        # overrides everything.
330        retries = config_kwargs.get('retries')
331        if retries is not None:
332            if 'total_max_attempts' in retries:
333                retries.pop('max_attempts', None)
334                return
335            if 'max_attempts' in retries:
336                value = retries.pop('max_attempts')
337                # client config max_attempts means total retries so we
338                # have to add one for 'total_max_attempts' to account
339                # for the initial request.
340                retries['total_max_attempts'] = value + 1
341                return
342        # Otherwise we'll check the config store which checks env vars,
343        # config files, etc.  There is no default value for max_attempts
344        # so if this returns None and we don't set a default value here.
345        max_attempts = self._config_store.get_config_variable('max_attempts')
346        if max_attempts is not None:
347            if retries is None:
348                retries = {}
349                config_kwargs['retries'] = retries
350            retries['total_max_attempts'] = max_attempts
351
352    def _compute_retry_mode(self, config_kwargs):
353        retries = config_kwargs.get('retries')
354        if retries is None:
355            retries = {}
356            config_kwargs['retries'] = retries
357        elif 'mode' in retries:
358            # If there's a retry mode explicitly set in the client config
359            # that overrides everything.
360            return
361        retry_mode = self._config_store.get_config_variable('retry_mode')
362        if retry_mode is None:
363            retry_mode = 'legacy'
364        retries['mode'] = retry_mode
365
366    def _ensure_boolean(self, val):
367        if isinstance(val, bool):
368            return val
369        else:
370            return val.lower() == 'true'
371