# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

import azure.cli.core._debug as _debug
from azure.cli.core.extension import EXTENSIONS_MOD_PREFIX
from azure.cli.core.profiles._shared import get_client_class, SDKProfile
from azure.cli.core.profiles import ResourceType, CustomResourceType, get_api_version, get_sdk
from azure.cli.core.util import get_az_user_agent, is_track2

from knack.log import get_logger
from knack.util import CLIError

logger = get_logger(__name__)


def _is_vendored_sdk_path(path_comps):
    """Return True when the fifth dotted component of a module path is 'vendored_sdks'.

    :param path_comps: the module path already split on '.'
    """
    return len(path_comps) >= 5 and path_comps[4] == 'vendored_sdks'


def resolve_client_arg_name(operation, kwargs):
    """Work out the name of the argument through which an operation receives its client.

    :param operation: string of the form '<module path>#<operation path>'
    :param kwargs: keyword arguments of the command; an explicit 'client_arg_name' entry wins
    :return: 'self' or 'client', or the value of kwargs['client_arg_name'] when present
    :raises CLIError: if ``operation`` is not a ``str``
    :raises ValueError: if the module path starts with neither 'azure' nor the extension prefix
        (a string without '#' also fails with ValueError, raised by the tuple unpacking)
    """
    if not isinstance(operation, str):
        raise CLIError("operation should be type 'str'. Got '{}'".format(type(operation)))
    if 'client_arg_name' in kwargs:
        logger.info("Keyword 'client_arg_name' is deprecated and should be removed.")
        return kwargs['client_arg_name']
    path, op_path = operation.split('#', 1)

    path_comps = path.split('.')
    if path_comps[0] == 'azure':
        if path_comps[1] != 'cli' or _is_vendored_sdk_path(path_comps):
            # Public SDK: azure.mgmt.resource... (mgmt-plane) or azure.storage.blob... (data-plane)
            # Vendored SDK: azure.cli.command_modules.keyvault.vendored_sdks...
            client_arg_name = 'self'
        else:
            # CLI custom method: azure.cli.command_modules.resource...
            client_arg_name = 'client'
    elif path_comps[0].startswith(EXTENSIONS_MOD_PREFIX):
        # for CLI extensions
        # SDK method: the operation takes the form '<class name>.<method_name>'
        # custom method: the operation takes the form '<method_name>'
        op_comps = op_path.split('.')
        client_arg_name = 'self' if len(op_comps) > 1 else 'client'
    else:
        raise ValueError('Unrecognized operation: {}'.format(operation))
    return client_arg_name


def get_mgmt_service_client(cli_ctx, client_or_resource_type, subscription_id=None, api_version=None,
                            aux_subscriptions=None, aux_tenants=None, **kwargs):
    """Create a management client from either a client class or a resource type.

    :param cli_ctx: the CLI context; ``cli_ctx.data['subscription_id']`` is used when no
        ``subscription_id`` is passed
    :param client_or_resource_type: a ``ResourceType``/``CustomResourceType`` (resolved to a
        versioned client class) or a client class used as is
    :param subscription_id: the current account's subscription
    :param api_version: API version to use; for resource types it defaults to the value returned
        by ``get_api_version``
    :param aux_subscriptions: mainly for cross tenant scenarios, say vnet peering.
    :param aux_tenants: forwarded unchanged to ``_get_mgmt_service_client``
    :param kwargs: extra keyword arguments forwarded to ``_get_mgmt_service_client``
    :return: the client instance (the subscription id returned alongside it is discarded)
    """
    if not subscription_id and 'subscription_id' in cli_ctx.data:
        subscription_id = cli_ctx.data['subscription_id']

    sdk_profile = None
    if isinstance(client_or_resource_type, (ResourceType, CustomResourceType)):
        # Get the versioned client
        client_type = get_client_class(client_or_resource_type)
        api_version = api_version or get_api_version(cli_ctx, client_or_resource_type, as_sdk_profile=True)
        if isinstance(api_version, SDKProfile):
            # An SDKProfile is passed to the client as 'profile' instead of 'api_version'.
            sdk_profile = api_version.profile
            api_version = None
    else:
        # Get the non-versioned client
        client_type = client_or_resource_type
    client, _ = _get_mgmt_service_client(cli_ctx, client_type, subscription_id=subscription_id,
                                         api_version=api_version, sdk_profile=sdk_profile,
                                         aux_subscriptions=aux_subscriptions,
                                         aux_tenants=aux_tenants,
                                         **kwargs)
    return client


def get_subscription_service_client(cli_ctx):
    """Create the subscriptions client, which is not bound to a subscription.

    :return: the ``(client, subscription_id)`` tuple produced by ``_get_mgmt_service_client``
    """
    return _get_mgmt_service_client(cli_ctx, get_client_class(ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS),
                                    subscription_bound=False,
                                    api_version=get_api_version(cli_ctx, ResourceType.MGMT_RESOURCE_SUBSCRIPTIONS))


def configure_common_settings(cli_ctx, client):
    """Apply the CLI's SSL, logging, User-Agent and custom-header settings to ``client``.

    Called by ``_get_mgmt_service_client`` for clients that are not Track 2. As a side effect the
    'CommandName' and (when 'safe_params' is set) 'ParameterSetName' entries are written into
    ``cli_ctx.data['headers']``.

    :param cli_ctx: the CLI context; reads 'command', 'completer_active', 'headers' and,
        optionally, 'command_extension_name' and 'safe_params' from ``cli_ctx.data``
    :param client: the SDK client to configure; must expose ``config`` and ``_client``
    """
    client = _debug.change_ssl_cert_verification(client)

    client.config.enable_http_logger = True

    client.config.add_user_agent(get_az_user_agent())

    try:
        command_ext_name = cli_ctx.data['command_extension_name']
        if command_ext_name:
            client.config.add_user_agent("CliExtension/{}".format(command_ext_name))
    except KeyError:
        # 'command_extension_name' is optional in cli_ctx.data
        pass

    # Prepare CommandName header
    command_name_suffix = ';completer-request' if cli_ctx.data['completer_active'] else ''
    cli_ctx.data['headers']['CommandName'] = "{}{}".format(cli_ctx.data['command'], command_name_suffix)

    # Prepare ParameterSetName header
    if cli_ctx.data.get('safe_params'):
        cli_ctx.data['headers']['ParameterSetName'] = ' '.join(cli_ctx.data['safe_params'])

    # Prepare x-ms-client-request-id header
    client.config.generate_client_request_id = 'x-ms-client-request-id' not in cli_ctx.data['headers']

    logger.debug("Adding custom headers to the client:")

    for header, value in cli_ctx.data['headers'].items():
        # msrest doesn't print custom headers in debug log, so CLI should do that
        logger.debug("    '%s': '%s'", header, value)
        # We are working with the autorest team to expose the add_header functionality of the generated client to avoid
        # having to access private members
        client._client.add_header(header, value)  # pylint: disable=protected-access


def _prepare_client_kwargs_track2(cli_ctx):
    """Prepare kwargs for Track 2 SDK client.

    Unlike ``configure_common_settings``, this works on a copy of ``cli_ctx.data['headers']`` and
    does not modify the context.

    :param cli_ctx: the CLI context; reads 'command', 'completer_active', 'headers' and,
        optionally, 'command_extension_name' and 'safe_params' from ``cli_ctx.data``
    :return: dict with 'user_agent', 'headers', 'logging_policy', 'http_logging_policy', the SSL
        settings returned by ``_debug.change_ssl_cert_verification_track2`` and, when the header
        is present, 'request_id'
    """
    client_kwargs = {}

    # Prepare connection_verify to change SSL verification behavior, used by ConnectionConfiguration
    client_kwargs.update(_debug.change_ssl_cert_verification_track2())

    # Prepare User-Agent header, used by UserAgentPolicy
    client_kwargs['user_agent'] = get_az_user_agent()

    try:
        command_ext_name = cli_ctx.data['command_extension_name']
        if command_ext_name:
            # Product tokens in a User-Agent value are separated by whitespace, so the leading
            # space keeps 'CliExtension/<name>' from being fused onto the preceding token.
            client_kwargs['user_agent'] += " CliExtension/{}".format(command_ext_name)
    except KeyError:
        # 'command_extension_name' is optional in cli_ctx.data
        pass

    # Prepare custom headers, used by HeadersPolicy
    headers = dict(cli_ctx.data['headers'])

    # - Prepare CommandName header
    command_name_suffix = ';completer-request' if cli_ctx.data['completer_active'] else ''
    headers['CommandName'] = "{}{}".format(cli_ctx.data['command'], command_name_suffix)

    # - Prepare ParameterSetName header
    if cli_ctx.data.get('safe_params'):
        headers['ParameterSetName'] = ' '.join(cli_ctx.data['safe_params'])

    client_kwargs['headers'] = headers

    # Prepare x-ms-client-request-id header, used by RequestIdPolicy
    if 'x-ms-client-request-id' in cli_ctx.data['headers']:
        client_kwargs['request_id'] = cli_ctx.data['headers']['x-ms-client-request-id']

    # Replace NetworkTraceLoggingPolicy to redact 'Authorization' and 'x-ms-authorization-auxiliary' headers.
    #   NetworkTraceLoggingPolicy: log raw network trace, with all headers.
    from azure.cli.core.sdk.policies import SafeNetworkTraceLoggingPolicy
    client_kwargs['logging_policy'] = SafeNetworkTraceLoggingPolicy()

    # Disable ARMHttpLoggingPolicy.
    #   ARMHttpLoggingPolicy: Only log allowed information.
    from azure.core.pipeline.policies import SansIOHTTPPolicy
    client_kwargs['http_logging_policy'] = SansIOHTTPPolicy()

    return client_kwargs


def _prepare_mgmt_client_kwargs_track2(cli_ctx, cred):
    """Prepare kwargs for Track 2 SDK mgmt client.

    :param cli_ctx: the CLI context; ``cli_ctx.cloud.endpoints.active_directory_resource_id`` is
        converted to the credential scopes
    :param cred: the credential that will be given to the client; when it has a truthy
        ``_external_tenant_token_retriever`` its external tenant tokens are added as a header
    :return: the kwargs from ``_prepare_client_kwargs_track2`` plus 'credential_scopes' and,
        for cross-tenant credentials, the 'x-ms-authorization-auxiliary' header
    """
    client_kwargs = _prepare_client_kwargs_track2(cli_ctx)

    from azure.cli.core.util import resource_to_scopes
    # Track 2 SDK maintains `scopes` and passes `scopes` to get_token.
    scopes = resource_to_scopes(cli_ctx.cloud.endpoints.active_directory_resource_id)

    client_kwargs['credential_scopes'] = scopes

    # Track 2 currently lacks the ability to take external credentials.
    #   https://github.com/Azure/azure-sdk-for-python/issues/8313
    # As a temporary workaround, manually add external tokens to 'x-ms-authorization-auxiliary' header.
    #   https://docs.microsoft.com/en-us/azure/azure-resource-manager/management/authenticate-multi-tenant
    if getattr(cred, "_external_tenant_token_retriever", None):
        # Only the last element returned by get_all_tokens is needed here.
        *_, external_tenant_tokens = cred.get_all_tokens(*scopes)
        # Hard-code scheme to 'Bearer' as _BearerTokenCredentialPolicyBase._update_headers does.
        client_kwargs['headers']['x-ms-authorization-auxiliary'] = \
            ', '.join("Bearer {}".format(t[1]) for t in external_tenant_tokens)

    return client_kwargs


def _get_mgmt_service_client(cli_ctx,
                             client_type,
                             subscription_bound=True,
                             subscription_id=None,
                             api_version=None,
                             base_url_bound=True,
                             resource=None,
                             sdk_profile=None,
                             aux_subscriptions=None,
                             aux_tenants=None,
                             **kwargs):
    """Log in through the CLI profile and instantiate ``client_type``.

    :param cli_ctx: the CLI context
    :param client_type: the client class to instantiate
    :param subscription_bound: when True the subscription id is passed as the second positional
        argument of the client constructor
    :param subscription_id: subscription to log in with; the value returned by
        ``Profile.get_login_credentials`` replaces it
    :param api_version: passed to the client as 'api_version' when truthy
    :param base_url_bound: when True the cloud's resource manager endpoint is passed as 'base_url'
    :param resource: resource to get credentials for; defaults to the cloud's
        ``active_directory_resource_id``
    :param sdk_profile: passed to the client as 'profile' when truthy
    :param aux_subscriptions: forwarded to ``Profile.get_login_credentials``
    :param aux_tenants: forwarded to ``Profile.get_login_credentials``
    :param kwargs: extra keyword arguments for the client constructor; Track 2 kwargs prepared by
        the CLI are applied afterwards and therefore take precedence on key collisions
    :return: tuple ``(client, subscription_id)``
    """
    from azure.cli.core._profile import Profile
    logger.debug('Getting management service client client_type=%s', client_type.__name__)
    resource = resource or cli_ctx.cloud.endpoints.active_directory_resource_id
    profile = Profile(cli_ctx=cli_ctx)
    cred, subscription_id, _ = profile.get_login_credentials(subscription_id=subscription_id, resource=resource,
                                                             aux_subscriptions=aux_subscriptions,
                                                             aux_tenants=aux_tenants)

    client_kwargs = {}
    if base_url_bound:
        client_kwargs = {'base_url': cli_ctx.cloud.endpoints.resource_manager}
    if api_version:
        client_kwargs['api_version'] = api_version
    if sdk_profile:
        client_kwargs['profile'] = sdk_profile
    if kwargs:
        client_kwargs.update(kwargs)

    if is_track2(client_type):
        client_kwargs.update(_prepare_mgmt_client_kwargs_track2(cli_ctx, cred))

    if subscription_bound:
        client = client_type(cred, subscription_id, **client_kwargs)
    else:
        client = client_type(cred, **client_kwargs)

    # Track 2 clients were configured through constructor kwargs above; other clients are
    # configured after construction.
    if not is_track2(client):
        configure_common_settings(cli_ctx, client)

    return client, subscription_id


def get_data_service_client(cli_ctx, service_type, account_name, account_key, connection_string=None,
                            sas_token=None, socket_timeout=None, token_credential=None, endpoint_suffix=None,
                            location_mode=None):
    """Instantiate a data-plane service client and attach the CLI header callback to it.

    :param cli_ctx: the CLI context
    :param service_type: the client class to instantiate
    :param account_name: always passed to the constructor as 'account_name'
    :param account_key: always passed to the constructor as 'account_key'
    :param connection_string: always passed to the constructor as 'connection_string'
    :param sas_token: always passed to the constructor as 'sas_token'
    :param socket_timeout: passed to the constructor only when truthy
    :param token_credential: passed to the constructor only when truthy
    :param endpoint_suffix: passed to the constructor only when truthy
    :param location_mode: assigned to ``client.location_mode`` when truthy
    :return: the client, with ``request_callback`` set to the CLI header callback
    :raises ValueError: if construction fails with the storage SDK's "missing info" message
    :raises CLIError: if construction fails with any other ``ValueError``
    """
    logger.debug('Getting data service client service_type=%s', service_type.__name__)
    try:
        client_kwargs = {'account_name': account_name,
                         'account_key': account_key,
                         'connection_string': connection_string,
                         'sas_token': sas_token}
        if socket_timeout:
            client_kwargs['socket_timeout'] = socket_timeout
        if token_credential:
            client_kwargs['token_credential'] = token_credential
        if endpoint_suffix:
            client_kwargs['endpoint_suffix'] = endpoint_suffix
        client = service_type(**client_kwargs)
        if location_mode:
            client.location_mode = location_mode
    except ValueError as exc:
        _ERROR_STORAGE_MISSING_INFO = get_sdk(cli_ctx, ResourceType.DATA_STORAGE,
                                              'common._error#_ERROR_STORAGE_MISSING_INFO')
        # Chain the original error explicitly so the root cause stays attached to the re-raise.
        if _ERROR_STORAGE_MISSING_INFO in str(exc):
            raise ValueError(exc) from exc
        raise CLIError('Unable to obtain data client. Check your connection parameters.') from exc
    # TODO: enable Fiddler
    client.request_callback = _get_add_headers_callback(cli_ctx)
    return client


def get_subscription_id(cli_ctx):
    """Return the subscription id, caching it in ``cli_ctx.data['subscription_id']``.

    The profile is only queried when the context does not already hold a truthy value.
    """
    from azure.cli.core._profile import Profile
    if not cli_ctx.data.get('subscription_id'):
        cli_ctx.data['subscription_id'] = Profile(cli_ctx=cli_ctx).get_subscription_id()
    return cli_ctx.data['subscription_id']


def _get_add_headers_callback(cli_ctx):
    """Build a request callback that adds the CLI's User-Agent and custom headers.

    :return: a function taking a request; it appends the CLI user agent to the request's existing
        'User-Agent' header and merges ``cli_ctx.data['headers']`` into the request headers
    """

    def _add_headers(request):
        agents = [request.headers['User-Agent'], get_az_user_agent()]
        request.headers['User-Agent'] = ' '.join(agents)

        try:
            request.headers.update(cli_ctx.data['headers'])
        except KeyError:
            # No custom headers recorded in the context
            pass

    return _add_headers


# Public alias of the Track 2 kwargs helper.
prepare_client_kwargs_track2 = _prepare_client_kwargs_track2