1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import azure.cli.core._debug as _debug
7from azure.cli.core.extension import EXTENSIONS_MOD_PREFIX
8from azure.cli.core.profiles._shared import get_client_class, SDKProfile
9from azure.cli.core.profiles import ResourceType, CustomResourceType, get_api_version, get_sdk
10from azure.cli.core.util import get_az_user_agent, is_track2
11
12from knack.log import get_logger
13from knack.util import CLIError
14
# Module-level knack logger shared by the client-factory helpers in this file.
logger = get_logger(__name__)
16
17
18def _is_vendored_sdk_path(path_comps):
19    return len(path_comps) >= 5 and path_comps[4] == 'vendored_sdks'
20
21
def resolve_client_arg_name(operation, kwargs):
    """Return the parameter name through which ``operation`` receives its client.

    ``operation`` has the form '<dotted.module.path>#<op path>'. SDK methods
    (public or vendored) take the client as ``self``; CLI custom functions take
    it as ``client``.

    :param operation: operation template string; must be a ``str``.
    :param kwargs: command keyword arguments; a legacy ``client_arg_name`` entry
        short-circuits the lookup and is returned as-is.
    :return: 'self', 'client', or the explicit ``client_arg_name`` value.
    :raises CLIError: if ``operation`` is not a ``str``.
    :raises ValueError: if the module path is neither under ``azure`` nor an
        extension module.
    """
    if not isinstance(operation, str):
        raise CLIError("operation should be type 'str'. Got '{}'".format(type(operation)))
    if 'client_arg_name' in kwargs:
        logger.info("Keyword 'client_arg_name' is deprecated and should be removed.")
        return kwargs['client_arg_name']

    module_path, op_path = operation.split('#', 1)
    parts = module_path.split('.')
    root = parts[0]

    if root == 'azure':
        # Public SDK (azure.mgmt.resource..., azure.storage.blob...) or a vendored SDK
        # (azure.cli.command_modules.<mod>.vendored_sdks...) -> bound SDK method.
        # Anything else under azure.cli is a CLI custom function.
        sdk_method = parts[1] != 'cli' or _is_vendored_sdk_path(parts)
        return 'self' if sdk_method else 'client'

    if root.startswith(EXTENSIONS_MOD_PREFIX):
        # CLI extensions: '<class name>.<method_name>' is an SDK method,
        # a bare '<method_name>' is a custom function.
        return 'self' if '.' in op_path else 'client'

    raise ValueError('Unrecognized operation: {}'.format(operation))
48
49
def get_mgmt_service_client(cli_ctx, client_or_resource_type, subscription_id=None, api_version=None,
                            aux_subscriptions=None, aux_tenants=None, **kwargs):
    """Create a management-plane SDK client.

    :param cli_ctx: CLI context; ``cli_ctx.data['subscription_id']`` is used when
        ``subscription_id`` is not supplied.
    :param client_or_resource_type: a ``ResourceType``/``CustomResourceType``, which is
        resolved to a versioned client class, or a client class used as-is.
    :param subscription_id: the current account's subscription.
    :param api_version: explicit API version. When omitted for a resource type, the
        active profile's version (or SDK profile) is used.
    :param aux_subscriptions: mainly for cross tenant scenarios, say vnet peering.
    :param aux_tenants: auxiliary tenants, forwarded together with ``aux_subscriptions``.
    :param kwargs: extra keyword arguments forwarded to the client constructor.
    :return: the constructed client.
    """
    if not subscription_id and 'subscription_id' in cli_ctx.data:
        subscription_id = cli_ctx.data['subscription_id']

    sdk_profile = None
    if not isinstance(client_or_resource_type, (ResourceType, CustomResourceType)):
        # A non-versioned client class was supplied directly.
        client_type = client_or_resource_type
    else:
        client_type = get_client_class(client_or_resource_type)
        resolved = api_version or get_api_version(cli_ctx, client_or_resource_type, as_sdk_profile=True)
        if isinstance(resolved, SDKProfile):
            # Multi-API profile: pass it as an SDK profile rather than a single api_version.
            sdk_profile, api_version = resolved.profile, None
        else:
            api_version = resolved

    client, _ = _get_mgmt_service_client(
        cli_ctx, client_type,
        subscription_id=subscription_id,
        api_version=api_version,
        sdk_profile=sdk_profile,
        aux_subscriptions=aux_subscriptions,
        aux_tenants=aux_tenants,
        **kwargs)
    return client
76
77
def get_subscription_service_client(cli_ctx):
    """Create a subscription-agnostic client for ``MGMT_RESOURCE_SUBSCRIPTIONS``.

    :return: ``(client, subscription_id)`` tuple as returned by ``_get_mgmt_service_client``.
    """
    resource_type = ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS
    client_class = get_client_class(resource_type)
    version = get_api_version(cli_ctx, resource_type)
    return _get_mgmt_service_client(cli_ctx, client_class, subscription_bound=False, api_version=version)
82
83
def configure_common_settings(cli_ctx, client):
    """Apply CLI-wide settings to a Track 1 (``client.config``-based) SDK client.

    The function does the following:

    - Applies the SSL-verification override.
    - Enables HTTP logging.
    - Adds the CLI user agent, plus the extension's user agent when one is recorded.
    - Attaches every entry of ``cli_ctx.data['headers']`` to the client.

    Side effect: writes ``CommandName`` and, when safe params exist,
    ``ParameterSetName`` into ``cli_ctx.data['headers']``.
    """
    client = _debug.change_ssl_cert_verification(client)
    config = client.config

    config.enable_http_logger = True
    config.add_user_agent(get_az_user_agent())

    try:
        extension_name = cli_ctx.data['command_extension_name']
        if extension_name:
            config.add_user_agent("CliExtension/{}".format(extension_name))
    except KeyError:
        pass

    data = cli_ctx.data

    # CommandName header (suffixed when the request comes from tab completion)
    suffix = ';completer-request' if data['completer_active'] else ''
    command_name = "{}{}".format(data['command'], suffix)
    custom_headers = data['headers']
    custom_headers['CommandName'] = command_name

    # ParameterSetName header
    safe_params = data.get('safe_params')
    if safe_params:
        custom_headers['ParameterSetName'] = ' '.join(safe_params)

    # Only let the SDK generate x-ms-client-request-id when the CLI hasn't supplied one.
    config.generate_client_request_id = 'x-ms-client-request-id' not in custom_headers

    logger.debug("Adding custom headers to the client:")
    for name, value in custom_headers.items():
        # msrest doesn't print custom headers in debug log, so CLI should do that
        logger.debug("    '%s': '%s'", name, value)
        # We are working with the autorest team to expose the add_header functionality of the generated client to avoid
        # having to access private members
        client._client.add_header(name, value)  # pylint: disable=protected-access
117
118
def _prepare_client_kwargs_track2(cli_ctx):
    """Prepare kwargs for Track 2 SDK client.

    Builds the pipeline options shared by Track 2 clients:

    - SSL-verification override.
    - User-Agent.
    - Custom headers, including ``CommandName`` and ``ParameterSetName``.
    - An explicit request id when the CLI already has one.
    - Replacement logging policies.

    ``cli_ctx.data['headers']`` is copied, not mutated.

    :param cli_ctx: CLI context whose ``data`` dict supplies headers and command metadata.
    :return: dict of keyword arguments for a Track 2 client constructor.
    """
    client_kwargs = {}

    # Prepare connection_verify to change SSL verification behavior, used by ConnectionConfiguration
    client_kwargs.update(_debug.change_ssl_cert_verification_track2())

    # Prepare User-Agent header, used by UserAgentPolicy
    user_agent = get_az_user_agent()
    command_ext_name = cli_ctx.data.get('command_extension_name')
    if command_ext_name:
        # Product tokens in User-Agent must be whitespace-separated; without the leading
        # space the extension token fuses onto the previous one (e.g. "...)CliExtension/foo").
        user_agent += " CliExtension/{}".format(command_ext_name)
    client_kwargs['user_agent'] = user_agent

    # Prepare custom headers, used by HeadersPolicy
    headers = dict(cli_ctx.data['headers'])

    # - Prepare CommandName header
    command_name_suffix = ';completer-request' if cli_ctx.data['completer_active'] else ''
    headers['CommandName'] = "{}{}".format(cli_ctx.data['command'], command_name_suffix)

    # - Prepare ParameterSetName header
    if cli_ctx.data.get('safe_params'):
        headers['ParameterSetName'] = ' '.join(cli_ctx.data['safe_params'])

    client_kwargs['headers'] = headers

    # Prepare x-ms-client-request-id header, used by RequestIdPolicy
    if 'x-ms-client-request-id' in cli_ctx.data['headers']:
        client_kwargs['request_id'] = cli_ctx.data['headers']['x-ms-client-request-id']

    # Replace NetworkTraceLoggingPolicy to redact 'Authorization' and 'x-ms-authorization-auxiliary' headers.
    #   NetworkTraceLoggingPolicy: log raw network trace, with all headers.
    from azure.cli.core.sdk.policies import SafeNetworkTraceLoggingPolicy
    client_kwargs['logging_policy'] = SafeNetworkTraceLoggingPolicy()

    # Disable ARMHttpLoggingPolicy.
    #   ARMHttpLoggingPolicy: Only log allowed information.
    from azure.core.pipeline.policies import SansIOHTTPPolicy
    client_kwargs['http_logging_policy'] = SansIOHTTPPolicy()

    return client_kwargs
164
165
def _prepare_mgmt_client_kwargs_track2(cli_ctx, cred):
    """Prepare kwargs for Track 2 SDK mgmt client.

    Starts from the generic Track 2 kwargs and adds ``credential_scopes`` derived from
    the cloud's Active Directory resource id. When ``cred`` has an external tenant
    token retriever, it also adds an 'x-ms-authorization-auxiliary' header built from
    the external tenant tokens.

    :param cli_ctx: CLI context.
    :param cred: credential object returned by ``Profile.get_login_credentials``.
    :return: dict of keyword arguments for a Track 2 mgmt client constructor.
    """
    mgmt_kwargs = _prepare_client_kwargs_track2(cli_ctx)

    from azure.cli.core.util import resource_to_scopes
    # Track 2 SDK maintains `scopes` and passes `scopes` to get_token.
    scopes = resource_to_scopes(cli_ctx.cloud.endpoints.active_directory_resource_id)
    mgmt_kwargs['credential_scopes'] = scopes

    # Track 2 currently lacks the ability to take external credentials.
    #   https://github.com/Azure/azure-sdk-for-python/issues/8313
    # As a temporary workaround, manually add external tokens to 'x-ms-authorization-auxiliary' header.
    #   https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
    if not getattr(cred, "_external_tenant_token_retriever", None):
        return mgmt_kwargs

    *_, aux_tokens = cred.get_all_tokens(*scopes)
    # Hard-code scheme to 'Bearer' as _BearerTokenCredentialPolicyBase._update_headers does.
    bearer_values = ["Bearer {}".format(token[1]) for token in aux_tokens]
    mgmt_kwargs['headers']['x-ms-authorization-auxiliary'] = ', '.join(bearer_values)
    return mgmt_kwargs
187
188
def _get_mgmt_service_client(cli_ctx,
                             client_type,
                             subscription_bound=True,
                             subscription_id=None,
                             api_version=None,
                             base_url_bound=True,
                             resource=None,
                             sdk_profile=None,
                             aux_subscriptions=None,
                             aux_tenants=None,
                             **kwargs):
    """Instantiate ``client_type`` with CLI login credentials and common settings.

    :param cli_ctx: CLI context.
    :param client_type: SDK client class to construct.
    :param subscription_bound: pass the subscription id as the second positional argument.
    :param subscription_id: subscription to log in against; the resolved id is returned.
    :param api_version: forwarded as ``api_version`` when truthy.
    :param base_url_bound: pass the cloud's resource manager endpoint as ``base_url``.
    :param resource: token resource; defaults to the cloud's Active Directory resource id.
    :param sdk_profile: forwarded as ``profile`` when truthy.
    :param aux_subscriptions: forwarded to ``Profile.get_login_credentials``.
    :param aux_tenants: forwarded to ``Profile.get_login_credentials``.
    :param kwargs: extra constructor kwargs. For Track 2 clients the CLI-prepared kwargs
        are applied afterwards and win on key collisions.
    :return: ``(client, subscription_id)``.
    """
    from azure.cli.core._profile import Profile
    logger.debug('Getting management service client client_type=%s', client_type.__name__)

    login_resource = resource or cli_ctx.cloud.endpoints.active_directory_resource_id
    cred, subscription_id, _ = Profile(cli_ctx=cli_ctx).get_login_credentials(
        subscription_id=subscription_id,
        resource=login_resource,
        aux_subscriptions=aux_subscriptions,
        aux_tenants=aux_tenants)

    client_kwargs = {'base_url': cli_ctx.cloud.endpoints.resource_manager} if base_url_bound else {}
    for key, value in (('api_version', api_version), ('profile', sdk_profile)):
        if value:
            client_kwargs[key] = value
    client_kwargs.update(kwargs)

    if is_track2(client_type):
        client_kwargs.update(_prepare_mgmt_client_kwargs_track2(cli_ctx, cred))

    positional = (cred, subscription_id) if subscription_bound else (cred,)
    client = client_type(*positional, **client_kwargs)

    # Track 1 clients are configured after construction; Track 2 got its settings via kwargs.
    if not is_track2(client):
        configure_common_settings(cli_ctx, client)

    return client, subscription_id
230
231
def get_data_service_client(cli_ctx, service_type, account_name, account_key, connection_string=None,
                            sas_token=None, socket_timeout=None, token_credential=None, endpoint_suffix=None,
                            location_mode=None):
    """Create a data-plane service client and attach the CLI header callback.

    :param cli_ctx: CLI context.
    :param service_type: client class to construct.
    :param account_name: forwarded as ``account_name``.
    :param account_key: forwarded as ``account_key``.
    :param connection_string: forwarded as ``connection_string``.
    :param sas_token: forwarded as ``sas_token``.
    :param socket_timeout: forwarded only when truthy.
    :param token_credential: forwarded only when truthy.
    :param endpoint_suffix: forwarded only when truthy.
    :param location_mode: assigned to ``client.location_mode`` when truthy.
    :return: the client, with ``request_callback`` set to add CLI user agent/headers.
    :raises ValueError: when the SDK's ValueError carries the storage "missing info" message.
    :raises CLIError: for any other ValueError raised while building the client.
    """
    logger.debug('Getting data service client service_type=%s', service_type.__name__)
    try:
        client_kwargs = {'account_name': account_name,
                         'account_key': account_key,
                         'connection_string': connection_string,
                         'sas_token': sas_token}
        if socket_timeout:
            client_kwargs['socket_timeout'] = socket_timeout
        if token_credential:
            client_kwargs['token_credential'] = token_credential
        if endpoint_suffix:
            client_kwargs['endpoint_suffix'] = endpoint_suffix
        client = service_type(**client_kwargs)
        if location_mode:
            client.location_mode = location_mode
    except ValueError as exc:
        missing_info_msg = get_sdk(cli_ctx, ResourceType.DATA_STORAGE,
                                   'common._error#_ERROR_STORAGE_MISSING_INFO')
        # Guard against a falsy lookup result (e.g. None if the attribute cannot be resolved):
        # `None in str(exc)` would raise TypeError and hide the original error.
        if missing_info_msg and missing_info_msg in str(exc):
            raise ValueError(exc) from exc
        raise CLIError('Unable to obtain data client. Check your connection parameters.') from exc
    # TODO: enable Fiddler
    client.request_callback = _get_add_headers_callback(cli_ctx)
    return client
259
260
def get_subscription_id(cli_ctx):
    """Return the active subscription id, caching it in ``cli_ctx.data``.

    ``Profile`` is consulted only when ``cli_ctx.data`` holds no truthy
    'subscription_id' value. The looked-up id is stored back for later calls.
    """
    from azure.cli.core._profile import Profile
    subscription = cli_ctx.data.get('subscription_id')
    if not subscription:
        subscription = Profile(cli_ctx=cli_ctx).get_subscription_id()
        cli_ctx.data['subscription_id'] = subscription
    return subscription
266
267
def _get_add_headers_callback(cli_ctx):
    """Build the ``request_callback`` installed by ``get_data_service_client``.

    The returned callable appends the CLI user agent to the request's existing
    'User-Agent' header. It then merges ``cli_ctx.data['headers']`` into the request
    headers, skipping the merge when that entry is absent.
    """
    def _add_headers(request):
        req_headers = request.headers
        req_headers['User-Agent'] = ' '.join((req_headers['User-Agent'], get_az_user_agent()))

        if 'headers' in cli_ctx.data:
            req_headers.update(cli_ctx.data['headers'])

    return _add_headers
280
281
# Public (non-underscored) alias of _prepare_client_kwargs_track2, presumably so callers
# outside this module can build Track 2 client kwargs without importing a private name.
prepare_client_kwargs_track2 = _prepare_client_kwargs_track2
283