1# Copyright 2013 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License"). You
4# may not use this file except in compliance with the License. A copy of
5# the License is located at
6#
7#     http://aws.amazon.com/apache2.0/
8#
9# or in the "license" file accompanying this file. This file is
10# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11# ANY KIND, either express or implied. See the License for the specific
12# language governing permissions and limitations under the License.
13"""
14Utility functions to make it easier to work with customizations.
15
16"""
17import copy
18import sys
19
20from botocore.exceptions import ClientError
21
22
def rename_argument(argument_table, existing_name, new_name):
    """Rename an argument in an argument table.

    The argument object stored under ``existing_name`` is moved to the key
    ``new_name`` and its ``name`` attribute is updated to match.

    :type argument_table: dict
    :param argument_table: The argument table to modify in place.

    :type existing_name: str
    :param existing_name: The current name of the argument.

    :type new_name: str
    :param new_name: The new name for the argument.

    :raises KeyError: If ``existing_name`` is not in ``argument_table``.
    """
    # Remove the old entry before inserting the new one. Inserting first and
    # deleting afterwards would drop the argument entirely when
    # ``new_name == existing_name``.
    current = argument_table.pop(existing_name)
    current.name = new_name
    argument_table[new_name] = current
28
29
30def _copy_argument(argument_table, current_name, copy_name):
31    current = argument_table[current_name]
32    copy_arg = copy.copy(current)
33    copy_arg.name = copy_name
34    argument_table[copy_name] = copy_arg
35    return copy_arg
36
37
def make_hidden_alias(argument_table, existing_name, alias_name):
    """Create a hidden alias for an existing argument.

    This will copy an existing argument object in an arg table,
    and add a new entry to the arg table with a different name.
    The new argument will also be undocumented.

    This is needed if you want to change an existing argument,
    but you still need the other one to work for backwards
    compatibility reasons.

    :type argument_table: dict
    :param argument_table: The argument table to modify in place.

    :type existing_name: str
    :param existing_name: The name of the argument to alias.

    :type alias_name: str
    :param alias_name: The name of the new, undocumented alias.
    """
    current = argument_table[existing_name]
    copy_arg = _copy_argument(argument_table, existing_name, alias_name)
    copy_arg._UNDOCUMENTED = True
    if current.required:
        # If the current argument is required, then we'll mark both as
        # not required (presumably because a value may now arrive under
        # either name), but flag _DOCUMENT_AS_REQUIRED so our doc gen
        # knows to still document the original argument as required.
        copy_arg.required = False
        current.required = False
        current._DOCUMENT_AS_REQUIRED = True
61
62
def rename_command(command_table, existing_name, new_name):
    """Rename a command in a command table.

    The command object stored under ``existing_name`` is moved to the key
    ``new_name`` and its ``name`` attribute is updated to match.

    :type command_table: dict
    :param command_table: The full command table for the CLI or a service.

    :type existing_name: str
    :param existing_name: The current name of the command.

    :type new_name: str
    :param new_name: The new name for the command.

    :raises KeyError: If ``existing_name`` is not in ``command_table``.
    """
    # Pop before re-inserting so that renaming a command to its current
    # name does not delete it from the table.
    current = command_table.pop(existing_name)
    current.name = new_name
    command_table[new_name] = current
68
69
def alias_command(command_table, existing_name, new_name):
    """Moves a command to a new name, keeping the old as a hidden alias.

    A shallow copy of the command is registered under ``new_name`` (with
    its ``name`` attribute set to ``new_name``). The original command
    object stays registered under ``existing_name`` but is marked
    undocumented.

    :type command_table: dict
    :param command_table: The full command table for the CLI or a service.

    :type existing_name: str
    :param existing_name: The current name of the command.

    :type new_name: str
    :param new_name: The new name for the command.
    """
    current = command_table[existing_name]
    _copy_argument(command_table, existing_name, new_name)
    # The copy was taken before this flag is set, so the flag is not
    # carried over to the entry under ``new_name``.
    current._UNDOCUMENTED = True
85
86
def make_hidden_command_alias(command_table, existing_name, alias_name):
    """Create a hidden alias for an existing command.

    This will copy an existing command object in a command table and add a new
    entry to the command table with a different name. The new command will
    be undocumented.

    This is needed if you want to change an existing command, but you still
    need the old name to work for backwards compatibility reasons.

    :type command_table: dict
    :param command_table: The full command table for the CLI or a service.

    :type existing_name: str
    :param existing_name: The current name of the command.

    :type alias_name: str
    :param alias_name: The name of the new, undocumented alias.
    """
    new = _copy_argument(command_table, existing_name, alias_name)
    new._UNDOCUMENTED = True
108
109
def validate_mutually_exclusive_handler(*groups):
    """Build a handler that checks ``groups`` for mutual exclusivity.

    The returned callable takes ``parsed_args`` plus arbitrary keyword
    arguments (which it ignores) and delegates to
    ``validate_mutually_exclusive`` with the captured ``groups``.
    """
    def _check(parsed_args, **kwargs):
        return validate_mutually_exclusive(parsed_args, *groups)

    return _check
114
115
def validate_mutually_exclusive(parsed_args, *groups):
    """Validate mutually exclusive groups in the parsed args.

    Every attribute of ``parsed_args`` whose value is not ``None`` and that
    belongs to one of ``groups`` must belong to the same group. Attributes
    that are not part of any group are ignored.

    :param parsed_args: An object whose attributes can be read with
        ``vars()``, such as an ``argparse.Namespace``.
    :param groups: Sequences of argument names. Names from different
        groups may not be specified together.
    :raises ValueError: If specified arguments come from more than one group.
    """
    # Map each argument name to the first group that contains it, so each
    # lookup below is a dict access instead of a scan over all groups.
    group_for_key = {}
    for group in groups:
        for arg in group:
            group_for_key.setdefault(arg, group)

    current_group = None
    for key, value in vars(parsed_args).items():
        if value is None:
            continue
        key_group = group_for_key.get(key)
        if key_group is None:
            # If the key is not part of a mutex group, we can move on.
            continue
        if current_group is None:
            current_group = key_group
        elif key_group != current_group:
            raise ValueError('The key "%s" cannot be specified when one '
                             'of the following keys are also specified: '
                             '%s' % (key, ', '.join(current_group)))
136
137
138def _get_group_for_key(key, groups):
139    for group in groups:
140        if key in group:
141            return group
142
143
def s3_bucket_exists(s3_client, bucket_name):
    """Return whether ``bucket_name`` appears to exist, using HeadBucket.

    Returns ``False`` only when HeadBucket fails with a "not found" error
    code (``404`` or ``NoSuchBucket``). Any other ``ClientError``, such as
    an access-denied error, is treated as "the bucket exists".

    :param s3_client: An S3 client exposing ``head_bucket``.
    :param bucket_name: The name of the bucket to check.
    """
    try:
        # See if the bucket exists by running a head bucket
        s3_client.head_bucket(Bucket=bucket_name)
    except ClientError as e:
        # Compare error codes as strings. Converting them with ``int()``
        # raised ValueError for non-numeric codes such as "NoSuchBucket".
        error_code = str(e.response.get('Error', {}).get('Code', ''))
        if error_code in ('404', 'NoSuchBucket'):
            return False
    return True
156
157
def create_client_from_parsed_globals(session, service_name, parsed_globals,
                                      overrides=None):
    """Create a service client, honoring the relevant parsed global options.

    The globals ``region``, ``endpoint_url`` and ``verify_ssl`` are passed
    to ``session.create_client`` as ``region_name``, ``endpoint_url`` and
    ``verify``, respectively, when present in ``parsed_globals``. Entries in
    ``overrides`` are applied last, so they take precedence over those
    translated values.
    """
    # (attribute on parsed_globals, keyword argument for create_client)
    translations = (
        ('region', 'region_name'),
        ('endpoint_url', 'endpoint_url'),
        ('verify_ssl', 'verify'),
    )
    client_kwargs = {
        client_key: getattr(parsed_globals, global_key)
        for global_key, client_key in translations
        if global_key in parsed_globals
    }
    if overrides:
        client_kwargs.update(overrides)
    return session.create_client(service_name, **client_kwargs)
176
177
def uni_print(statement, out_file=None):
    """
    Write the text ``statement`` to ``out_file``, usually stdout or stderr.

    The statement is first written unchanged. If that raises
    ``UnicodeEncodeError``, the statement is re-encoded with the stream's
    ``encoding`` (``ascii`` when the stream reports none), replacing any
    characters that cannot be encoded. The result is then written instead.
    The stream is flushed afterwards in both cases.

    :param statement: The text to write.
    :param out_file: A text file-like object; defaults to ``sys.stdout``
        (looked up at call time).
    """
    if out_file is None:
        out_file = sys.stdout
    try:
        # We assume that out_file is a text writer type that accepts
        # str/unicode instead of bytes.
        out_file.write(statement)
    except UnicodeEncodeError:
        # Some file like objects like cStringIO will
        # try to decode as ascii on python2.
        #
        # This can also fail if our encoding associated
        # with the text writer cannot encode the unicode
        # ``statement`` we've been given.  This commonly
        # happens on windows where we have some S3 key
        # previously encoded with utf-8 that can't be
        # encoded using whatever codepage the user has
        # configured in their console.
        #
        # At this point we've already failed to do what's
        # been requested.  We now try to make a best effort
        # attempt at printing the statement to the outfile.
        # We're using 'ascii' as the default because if the
        # stream doesn't give us any encoding information
        # we want to pick an encoding that has the highest
        # chance of printing successfully.
        new_encoding = getattr(out_file, 'encoding', 'ascii')
        # When the output of the aws command is being piped,
        # ``sys.stdout.encoding`` is ``None``.
        if new_encoding is None:
            new_encoding = 'ascii'
        new_statement = statement.encode(
            new_encoding, 'replace').decode(new_encoding)
        out_file.write(new_statement)
    out_file.flush()
219
220
def get_policy_arn_suffix(region):
    """Return the ARN partition string used in policy ARNs for ``region``.

    Regions beginning with ``cn-`` map to ``aws-cn``; regions beginning
    with ``us-gov`` map to ``aws-us-gov``; every other region maps to
    ``aws``. Matching is case-insensitive.
    """
    normalized = region.lower()
    for prefix, partition in (("cn-", "aws-cn"), ("us-gov", "aws-us-gov")):
        if normalized.startswith(prefix):
            return partition
    return "aws"
230