1"""
2Connection library for AWS
3
4.. versionadded:: 2015.5.0
5
6This is a base library used by a number of AWS services.
7
8:depends: requests
9"""
10
11import binascii
12import hashlib
13import hmac
14import logging
15import random
16import re
17import time
18import urllib.parse
19import xml.etree.ElementTree as ET
20from datetime import datetime
21
22import salt.config
23import salt.utils.hashutils
24import salt.utils.xmlutil as xml
25
# ``requests`` is optional at import time. HAS_REQUESTS is not read in the
# code shown here; presumably other modules consult it to detect whether
# ``requests`` is installed before calling into this library.
try:
    import requests

    HAS_REQUESTS = True  # pylint: disable=W0612
except ImportError:
    HAS_REQUESTS = False  # pylint: disable=W0612

# pylint: enable=import-error,redefined-builtin,no-name-in-module

log = logging.getLogger(__name__)
# Region used when neither the configuration nor instance metadata supply one.
DEFAULT_LOCATION = "us-east-1"
# Default ``Version`` query parameter for signed requests when the provider
# configuration does not set an API version.
DEFAULT_AWS_API_VERSION = "2016-11-15"
# AWS error codes (from the ``Errors/Error/Code`` element of an error
# response) for which query() retries the request with exponential backoff.
AWS_RETRY_CODES = [
    "RequestLimitExceeded",
    "InsufficientInstanceCapacity",
    "InternalError",
    "Unavailable",
    "InsufficientAddressCapacity",
    "InsufficientReservedInstanceCapacity",
]
# Timeout in seconds passed to ``requests`` for instance-metadata lookups,
# which must fail fast when not running on EC2.
AWS_METADATA_TIMEOUT = 3.05

# Upper bound on the number of attempts query() makes for retryable errors.
AWS_MAX_RETRIES = 7

# Sentinel value for a provider ``id`` or ``key`` meaning "fetch the instance
# role credentials from the metadata service instead".
IROLE_CODE = "use-instance-role-credentials"
# Cached instance-role credentials, filled in by creds(). ``__Expiration__``
# holds the expiration timestamp string from the metadata service; an empty
# string means nothing is cached yet.
__AccessKeyId__ = ""
__SecretAccessKey__ = ""
__Token__ = ""
__Expiration__ = ""
# Cached region from get_region_from_metadata(); "" means not looked up yet,
# and the literal "do-not-get-from-metadata" marks a previous failed lookup.
__Location__ = ""
# Assumed-role credentials keyed by role ARN, filled in by assumed_creds().
__AssumeCache__ = {}
57
58
def sleep_exponential_backoff(attempts):
    """
    Sleep for a random, exponentially growing delay to throttle requests.

    The delay is drawn uniformly between 1 and ``2 ** attempts`` seconds, so
    its upper bound doubles with every failed attempt. This is the approach
    AWS recommends for "API Rate Exceeded" failures:
    https://docs.aws.amazon.com/AWSEC2/latest/APIReference/query-api-troubleshooting.html
    https://docs.aws.amazon.com/general/latest/gr/api-retries.html

    Without it, salt-cloud with "--parallel" creating 50 or more instances
    with a fixed 2 second delay sees a failure rate above 30%, and the
    salt-api with an asynchronous client (runner_async) sees a failure rate
    above 10%.

    :param attempts: number of attempts made so far; sets the delay ceiling.
    """
    ceiling = 2 ** attempts
    delay = random.uniform(1, ceiling)
    time.sleep(delay)
72
73
def creds(provider):
    """
    Return the credentials for AWS signing.  This could be just the id and key
    specified in the provider configuration, or if the id or key is set to the
    literal string 'use-instance-role-credentials' creds will pull the instance
    role credentials from the meta data, cache them, and provide them instead.

    When ``provider`` also carries a ``role_arn``, the credentials obtained
    above are used to assume that role (see ``assumed_creds``) and the
    assumed-role credentials are returned. This applies both to freshly
    fetched and to cached instance-role credentials.

    :param provider: provider configuration dict with at least ``id`` and
        ``key``; ``role_arn`` is optional.
    :return: tuple ``(access_key_id, secret_access_key, token)``; ``token`` is
        ``""`` when no session token applies. If either instance-metadata
        request fails, the configured ``id``/``key`` are returned unchanged
        with an empty token (and no role is assumed).
    """
    # Declare globals
    global __AccessKeyId__, __SecretAccessKey__, __Token__, __Expiration__

    # if id or key is 'use-instance-role-credentials', pull them from meta-data
    # if needed
    if provider["id"] == IROLE_CODE or provider["key"] == IROLE_CODE:
        # Cached credentials are reusable until they expire. Both timestamps
        # use the "%Y-%m-%dT%H:%M:%SZ" layout, so a string comparison orders
        # them chronologically.
        timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")
        if __Expiration__ == "" or timestamp >= __Expiration__:
            # No cached credentials, or they are expired: fetch new ones.
            # Connections to instance meta-data must fail fast and never be
            # proxied. RequestException also covers read timeouts, which are
            # neither HTTPError nor ConnectionError.
            try:
                result = requests.get(
                    "http://169.254.169.254/latest/meta-data/iam/security-credentials/",
                    proxies={"http": ""},
                    timeout=AWS_METADATA_TIMEOUT,
                )
                result.raise_for_status()
                role = result.text

                result = requests.get(
                    "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}".format(
                        role
                    ),
                    proxies={"http": ""},
                    timeout=AWS_METADATA_TIMEOUT,
                )
                result.raise_for_status()
            except requests.exceptions.RequestException as exc:
                log.warning(
                    "Failed to get AWS instance role credentials from metadata: %s",
                    exc,
                )
                return provider["id"], provider["key"], ""

            data = result.json()
            __AccessKeyId__ = data["AccessKeyId"]
            __SecretAccessKey__ = data["SecretAccessKey"]
            __Token__ = data["Token"]
            __Expiration__ = data["Expiration"]

        # Fall through (rather than returning early) so that role_arn below
        # is honoured for cached credentials too.
        ret_credentials = __AccessKeyId__, __SecretAccessKey__, __Token__
    else:
        ret_credentials = provider["id"], provider["key"], ""

    if provider.get("role_arn") is not None:
        # Sign the AssumeRole call with the base credentials: drop role_arn
        # from the copy so the nested creds() call does not recurse.
        provider_shadow = provider.copy()
        provider_shadow.pop("role_arn", None)
        log.info("Assuming the role: %s", provider.get("role_arn"))
        ret_credentials = assumed_creds(
            provider_shadow, role_arn=provider.get("role_arn"), location="us-east-1"
        )

    return ret_credentials
141
142
def sig2(method, endpoint, params, provider, aws_api_version):
    """
    Sign a query against AWS services using Signature Version 2 Signing
    Process. This is documented at:

    http://docs.aws.amazon.com/general/latest/gr/signature-version-2.html

    :param method: HTTP verb, e.g. ``"GET"``.
    :param endpoint: host name of the service endpoint.
    :param params: query parameters to sign; the dict is copied, not mutated.
    :param provider: provider configuration passed to ``creds``.
    :param aws_api_version: value for the ``Version`` parameter.
    :return: a new dict holding ``params`` plus ``AWSAccessKeyId``,
        ``SignatureVersion``, ``SignatureMethod``, ``Timestamp``, ``Version``,
        ``SecurityToken`` (only when a session token is in use) and the
        base64 ``Signature`` as a ``str``.
    """
    timestamp = datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ")

    # Retrieve access credentials from meta-data, or use provided
    access_key_id, secret_access_key, token = creds(provider)

    params_with_headers = params.copy()
    params_with_headers["AWSAccessKeyId"] = access_key_id
    params_with_headers["SignatureVersion"] = "2"
    params_with_headers["SignatureMethod"] = "HmacSHA256"
    params_with_headers["Timestamp"] = timestamp
    params_with_headers["Version"] = aws_api_version
    # The security token is an ordinary query parameter, so it has to be in
    # place before signing; AWS signs every parameter except Signature.
    if token != "":
        params_with_headers["SecurityToken"] = token

    # Canonical query string: parameters sorted by name, RFC 3986 encoded.
    # urlencode() writes spaces as "+" (a literal "+" becomes "%2B"), but
    # SigV2 requires "%20" for spaces.
    querystring = urllib.parse.urlencode(
        sorted(params_with_headers.items())
    ).replace("+", "%20")

    # String to sign: verb, lower-cased host, request URI, query string.
    canonical = "{}\n{}\n/\n{}".format(method, endpoint.lower(), querystring)

    # hmac works on bytes, so encode the key and message explicitly.
    hashed = hmac.new(
        secret_access_key.encode("utf-8"),
        canonical.encode("utf-8"),
        hashlib.sha256,
    )
    params_with_headers["Signature"] = (
        binascii.b2a_base64(hashed.digest()).strip().decode("ascii")
    )

    return params_with_headers
181
182
def assumed_creds(prov_dict, role_arn, location=None):
    """
    Return temporary credentials for ``role_arn`` obtained via STS AssumeRole.

    Results are cached per role ARN in the module-level ``__AssumeCache__``.
    Before the lookup, every cached entry expiring within 120 seconds is
    evicted, so a nearly expired session is never handed out.

    :param prov_dict: provider configuration used to sign the STS request
        (passed to ``sig4``).
    :param role_arn: ARN of the role to assume.
    :param location: region passed through to ``sig4`` for signing.
    :return: tuple ``(access_key_id, secret_access_key, session_token)``.
    :raises requests.exceptions.HTTPError: if STS answers with an error status.
    """
    valid_session_name_re = re.compile("[^a-z0-9A-Z+=,.@-]")

    # Current time in epoch seconds. time.time() is already UTC-based; running
    # a UTC struct_time through time.mktime() would shift it by the local
    # UTC offset. Expiration is assumed to be epoch seconds, as the
    # subtraction below requires a number.
    now = time.time()

    # Iterate over a snapshot because entries are deleted inside the loop.
    for cached_arn, cached in list(__AssumeCache__.items()):
        if (cached["Expiration"] - now) <= 120:
            del __AssumeCache__[cached_arn]

    if role_arn in __AssumeCache__:
        c = __AssumeCache__[role_arn]
        return c["AccessKeyId"], c["SecretAccessKey"], c["SessionToken"]

    version = "2011-06-15"
    # RoleSessionName only allows a restricted character set and at most 64
    # characters; derive it from the minion id.
    session_name = valid_session_name_re.sub(
        "", salt.config.get_id({"root_dir": None})[0]
    )[0:63]

    headers, requesturl = sig4(
        "GET",
        "sts.amazonaws.com",
        params={
            "Version": version,
            "Action": "AssumeRole",
            "RoleSessionName": session_name,
            "RoleArn": role_arn,
            "Policy": (
                '{"Version":"2012-10-17","Statement":[{"Sid":"Stmt1",'
                ' "Effect":"Allow","Action":"*","Resource":"*"}]}'
            ),
            "DurationSeconds": "3600",
        },
        aws_api_version=version,
        data="",
        uri="/",
        prov_dict=prov_dict,
        product="sts",
        location=location,
        requesturl="https://sts.amazonaws.com/",
    )
    headers["Accept"] = "application/json"
    result = requests.request("GET", requesturl, headers=headers, data="", verify=True)

    if result.status_code >= 400:
        log.info("AssumeRole response: %s", result.content)
    result.raise_for_status()
    resp = result.json()

    data = resp["AssumeRoleResponse"]["AssumeRoleResult"]["Credentials"]
    __AssumeCache__[role_arn] = data
    return data["AccessKeyId"], data["SecretAccessKey"], data["SessionToken"]
235
236
def sig4(
    method,
    endpoint,
    params,
    prov_dict,
    aws_api_version=DEFAULT_AWS_API_VERSION,
    location=None,
    product="ec2",
    uri="/",
    requesturl=None,
    data="",
    headers=None,
    role_arn=None,
    payload_hash=None,
):
    """
    Sign a query against AWS services using Signature Version 4 Signing
    Process. This is documented at:

    http://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html
    http://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
    http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html

    :param method: HTTP verb placed in the canonical request.
    :param endpoint: host name; sent and signed as the ``host`` header.
    :param params: query parameters; copied, never mutated.
    :param prov_dict: provider configuration handed to ``creds`` or
        ``assumed_creds``.
    :param aws_api_version: added as the ``Version`` query parameter for every
        product except ``s3`` and ``ssm``.
    :param location: region for the credential scope. When ``None`` it is
        taken from ``get_region_from_metadata`` and then ``DEFAULT_LOCATION``.
    :param product: service name for the credential scope.
    :param uri: canonical URI, used verbatim. NOTE(review): it is not
        URI-encoded here; callers presumably pass an already-encoded path.
    :param requesturl: base URL; ``?`` plus the canonical query string is
        appended. NOTE(review): the ``None`` default would yield
        ``"None?..."``, so callers are expected to always pass it.
    :param data: request body hashed for ``x-amz-content-sha256`` when
        ``payload_hash`` is not supplied.
    :param headers: optional extra headers (dict); copied and signed.
    :param role_arn: when set, credentials come from ``assumed_creds`` for
        this role instead of ``creds``.
    :param payload_hash: precomputed payload hash; takes precedence over
        ``data``.
    :return: tuple ``(headers, requesturl)``; ``headers`` includes the
        ``Authorization`` header and ``requesturl`` carries the query string.
    """
    timenow = datetime.utcnow()

    # Retrieve access credentials from meta-data, or use provided
    if role_arn is None:
        access_key_id, secret_access_key, token = creds(prov_dict)
    else:
        access_key_id, secret_access_key, token = assumed_creds(
            prov_dict, role_arn, location=location
        )

    # The region is part of the credential scope, so it must be resolved.
    if location is None:
        location = get_region_from_metadata()
    if location is None:
        location = DEFAULT_LOCATION

    params_with_headers = params.copy()
    if product not in ("s3", "ssm"):
        params_with_headers["Version"] = aws_api_version
    # Canonical query string: parameters sorted by name. urlencode() writes
    # spaces as "+" (a literal "+" is already "%2B"), but SigV4 requires
    # "%20", hence the replace.
    keys = sorted(params_with_headers.keys())
    values = list(map(params_with_headers.get, keys))
    querystring = urllib.parse.urlencode(list(zip(keys, values))).replace("+", "%20")

    amzdate = timenow.strftime("%Y%m%dT%H%M%SZ")
    datestamp = timenow.strftime("%Y%m%d")
    new_headers = {}
    if isinstance(headers, dict):
        new_headers = headers.copy()

    # Create payload hash (hash of the request body content). For GET
    # requests, the payload is an empty string ('').
    if not payload_hash:
        payload_hash = salt.utils.hashutils.sha256_digest(data)

    new_headers["X-Amz-date"] = amzdate
    new_headers["host"] = endpoint
    new_headers["x-amz-content-sha256"] = payload_hash
    a_canonical_headers = []
    a_signed_headers = []

    # Temporary credentials must send (and sign) their session token.
    if token != "":
        new_headers["X-Amz-security-token"] = token

    # Canonical headers: lower-cased names sorted case-insensitively, values
    # trimmed; every header present is also listed as signed.
    for header in sorted(new_headers.keys(), key=str.lower):
        lower_header = header.lower()
        a_canonical_headers.append(
            "{}:{}".format(lower_header, new_headers[header].strip())
        )
        a_signed_headers.append(lower_header)
    # The trailing "\n" plus the "\n" join below yields the blank line that
    # SigV4 expects between the canonical headers and the signed header list.
    canonical_headers = "\n".join(a_canonical_headers) + "\n"
    signed_headers = ";".join(a_signed_headers)

    algorithm = "AWS4-HMAC-SHA256"

    # Combine elements to create the canonical request
    canonical_request = "\n".join(
        (method, uri, querystring, canonical_headers, signed_headers, payload_hash)
    )

    # Create the string to sign
    credential_scope = "/".join((datestamp, location, product, "aws4_request"))
    string_to_sign = "\n".join(
        (
            algorithm,
            amzdate,
            credential_scope,
            salt.utils.hashutils.sha256_digest(canonical_request),
        )
    )

    # Derive the signing key with _sig_key (defined in this module).
    signing_key = _sig_key(secret_access_key, datestamp, location, product)

    # Sign the string_to_sign using the signing_key
    signature = hmac.new(
        signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
    ).hexdigest()

    # Add signing information to the request
    authorization_header = "{} Credential={}/{}, SignedHeaders={}, Signature={}".format(
        algorithm,
        access_key_id,
        credential_scope,
        signed_headers,
        signature,
    )

    new_headers["Authorization"] = authorization_header

    # The URL must carry exactly the query string that was signed.
    requesturl = "{}?{}".format(requesturl, querystring)
    return new_headers, requesturl
350
351
352def _sign(key, msg):
353    """
354    Key derivation functions. See:
355
356    http://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-python
357    """
358    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()
359
360
361def _sig_key(key, date_stamp, regionName, serviceName):
362    """
363    Get a signature key. See:
364
365    http://docs.aws.amazon.com/general/latest/gr/signature-v4-examples.html#signature-v4-examples-python
366    """
367    kDate = _sign(("AWS4" + key).encode("utf-8"), date_stamp)
368    if regionName:
369        kRegion = _sign(kDate, regionName)
370        kService = _sign(kRegion, serviceName)
371    else:
372        kService = _sign(kDate, serviceName)
373    kSigning = _sign(kService, "aws4_request")
374    return kSigning
375
376
def query(
    params=None,
    setname=None,
    requesturl=None,
    location=None,
    return_url=False,
    return_root=False,
    opts=None,
    provider=None,
    endpoint=None,
    product="ec2",
    sigver="2",
):
    """
    Perform a query against AWS services using Signature Version 2 Signing
    Process. This is documented at:

    http://docs.aws.amazon.com/general/latest/gr/signature-version-2.html

    Regions and endpoints are documented at:

    http://docs.aws.amazon.com/general/latest/gr/rande.html

    Default ``product`` is ``ec2``. Valid ``product`` names are:

    .. code-block:: yaml

        - autoscaling (Auto Scaling)
        - cloudformation (CloudFormation)
        - ec2 (Elastic Compute Cloud)
        - elasticache (ElastiCache)
        - elasticbeanstalk (Elastic BeanStalk)
        - elasticloadbalancing (Elastic Load Balancing)
        - elasticmapreduce (Elastic MapReduce)
        - iam (Identity and Access Management)
        - importexport (Import/Export)
        - monitoring (CloudWatch)
        - rds (Relational Database Service)
        - simpledb (SimpleDB)
        - sns (Simple Notification Service)
        - sqs (Simple Queue Service)

    Pass ``sigver="4"`` to sign with Signature Version 4 instead.

    Responses failing with an error code listed in ``AWS_RETRY_CODES`` are
    retried with exponential backoff, making at most ``AWS_MAX_RETRIES``
    requests in total.

    :return: a list of dicts, one per child of the selected response element:
        the second child of the response root by default, the root itself
        when ``return_root`` is true, or the root child whose local tag name
        equals ``setname``. On an HTTP error, ``{"error": data}`` is returned
        instead. With ``return_url=True`` the result is returned as
        ``(result, requesturl)``.
    """
    if params is None:
        params = {}

    if opts is None:
        opts = {}

    function = opts.get("function", (None, product))
    providers = opts.get("providers", {})

    if provider is None:
        prov_dict = providers.get(function[1], {}).get(product, {})
        if prov_dict:
            # NOTE(review): this looks up ``providers`` by the first key of
            # prov_dict, defaulting to the product name; the intent is unclear
            # from here — confirm against callers.
            driver = next(iter(prov_dict))
            provider = providers.get(driver, product)
    else:
        prov_dict = providers.get(provider, {}).get(product, {})

    service_url = prov_dict.get("service_url", "amazonaws.com")

    if not location:
        location = get_location(opts, prov_dict)

    if endpoint is None:
        if not requesturl:
            endpoint = prov_dict.get(
                "endpoint", "{}.{}.{}".format(product, location, service_url)
            )

            requesturl = "https://{}/".format(endpoint)
        else:
            endpoint = urllib.parse.urlparse(requesturl).netloc
            if endpoint == "":
                endpoint_err = (
                    "Could not find a valid endpoint in the "
                    "requesturl: {}. Looking for something "
                    "like https://some.aws.endpoint/?args".format(requesturl)
                )
                log.error(endpoint_err)
                if return_url is True:
                    return {"error": endpoint_err}, requesturl
                return {"error": endpoint_err}

    log.debug("Using AWS endpoint: %s", endpoint)
    method = "GET"

    aws_api_version = prov_dict.get(
        "aws_api_version",
        prov_dict.get("{}_api_version".format(product), DEFAULT_AWS_API_VERSION),
    )

    # Fallback to ec2's id & key if none is found, for this component
    if not prov_dict.get("id", None):
        prov_dict["id"] = providers.get(provider, {}).get("ec2", {}).get("id", {})
        prov_dict["key"] = providers.get(provider, {}).get("ec2", {}).get("key", {})

    if sigver == "4":
        headers, requesturl = sig4(
            method,
            endpoint,
            params,
            prov_dict,
            aws_api_version,
            location,
            product,
            requesturl=requesturl,
        )
        params_with_headers = {}
    else:
        params_with_headers = sig2(method, endpoint, params, prov_dict, aws_api_version)
        headers = {}

    # Every iteration either succeeds (break), retries with ``attempts``
    # strictly increasing up to AWS_MAX_RETRIES (continue), or returns.
    attempts = 0
    while True:
        log.debug("AWS Request: %s", requesturl)
        log.trace("AWS Request Parameters: %s", params_with_headers)
        try:
            result = requests.get(
                requesturl, headers=headers, params=params_with_headers
            )
            log.debug("AWS Response Status Code: %s", result.status_code)
            log.trace("AWS Response Text: %s", result.text)
            result.raise_for_status()
            break
        except requests.exceptions.HTTPError as exc:
            # Capture everything needed while ``exc`` is still bound: Python 3
            # unbinds the ``as`` target when the except clause ends.
            status_code = exc.response.status_code
            try:
                data = xml.to_dict(ET.fromstring(exc.response.content))
            except ET.ParseError:
                # Not every error body is XML (e.g. an HTML page from an
                # intermediate proxy); report the raw text instead.
                data = {"Message": exc.response.text}
            err_code = data.get("Errors", {}).get("Error", {}).get("Code", "")
            attempts += 1

            # check to see if we should retry the query
            if err_code in AWS_RETRY_CODES and attempts < AWS_MAX_RETRIES:
                log.error(
                    "AWS Response Status Code and Error: [%s %s] %s; "
                    "Attempts remaining: %s",
                    status_code,
                    exc,
                    data,
                    AWS_MAX_RETRIES - attempts,
                )
                sleep_exponential_backoff(attempts)
                continue

            log.error(
                "AWS Response Status Code and Error: [%s %s] %s",
                status_code,
                exc,
                data,
            )
            if return_url is True:
                return {"error": data}, requesturl
            return {"error": data}

    root = ET.fromstring(result.text)
    # By default the payload is the second child of the response root
    # (presumably the first is the request id in EC2-style responses).
    items = root if return_root is True else root[1]

    if setname:
        for item in root:
            # Tags may be namespaced as "{namespace}name"; compare the local
            # name only. The last matching child wins.
            if item.tag.rsplit("}", 1)[-1] == setname:
                items = item

    ret = [xml.to_dict(item) for item in items]

    if return_url is True:
        return ret, requesturl

    return ret
560
561
def get_region_from_metadata():
    """
    Look up the region from the instance identity document and cache it.

    The result is kept in the module-level ``__Location__``. After a failed
    request the sentinel ``"do-not-get-from-metadata"`` is stored so later
    calls return ``None`` immediately; a response that cannot be decoded is
    not cached, so the next call tries again.

    .. versionadded:: 2015.5.6

    :return: the region name, or ``None`` when it cannot be determined.
    """
    global __Location__

    if __Location__ == "do-not-get-from-metadata":
        log.debug(
            "Previously failed to get AWS region from metadata. Not trying again."
        )
        return None

    if __Location__ != "":
        # Region already discovered by an earlier call.
        return __Location__

    try:
        # Connections to instance meta-data must fail fast and never be proxied
        response = requests.get(
            "http://169.254.169.254/latest/dynamic/instance-identity/document",
            proxies={"http": ""},
            timeout=AWS_METADATA_TIMEOUT,
        )
    except requests.exceptions.RequestException:
        log.warning("Failed to get AWS region from instance metadata.", exc_info=True)
        # Remember the failure so the lookup is not repeated.
        __Location__ = "do-not-get-from-metadata"
        return None

    try:
        __Location__ = response.json()["region"]
    except (ValueError, KeyError):
        log.warning("Failed to decode JSON from instance metadata.")
        return None
    return __Location__
602
603
def get_location(opts=None, provider=None):
    """
    Return the region to use, taking the first value that is not ``None`` from:
        opts['location']
        provider['location']
        get_region_from_metadata()
        DEFAULT_LOCATION
    """
    if opts is None:
        opts = {}

    region = opts.get("location")
    if region is not None:
        return region

    if provider is not None:
        region = provider.get("location")
        if region is not None:
            return region

    region = get_region_from_metadata()
    return DEFAULT_LOCATION if region is None else region
622