1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
7    from urllib.parse import urlencode, urlparse, urlunparse
8except ImportError:
9    from urllib import urlencode
10    from urlparse import urlparse, urlunparse
12import time
13from json import loads
14from base64 import b64encode
15import requests
16from requests import RequestException
17from requests.utils import to_native_string
18from msrest.http_logger import log_request, log_response
20from knack.util import CLIError
21from knack.prompting import prompt, prompt_pass, NoTTYException
22from knack.log import get_logger
24from azure.cli.core.util import should_disable_connection_verify
25from azure.cli.core.cloud import CloudSuffixNotSetException
26from azure.cli.core._profile import _AZ_LOGIN_MESSAGE
28from ._client_factory import cf_acr_registries
29from ._constants import get_managed_sku
30from ._utils import get_registry_by_name, ResourceNotFound
# Module-level logger, named after this module.
logger = get_logger(__name__)

# Placeholder username returned together with AAD tokens; get_authorization_header
# switches to Bearer auth when the username equals this value.
EMPTY_GUID = '00000000-0000-0000-0000-000000000000'
# HTTP methods accepted by request_data_from_registry.
ALLOWED_HTTP_METHOD = ['get', 'patch', 'put', 'delete']
# Permission strings accepted by _get_aad_token when a repository or artifact repository is given.
ACCESS_TOKEN_PERMISSION = ['pull', 'push', 'delete', 'push,pull', 'delete,pull']

# Prefixes of the warnings logged by _get_credentials when a credential source fails.
AAD_TOKEN_BASE_ERROR_MESSAGE = "Unable to get AAD authorization tokens with message"
ADMIN_USER_BASE_ERROR_MESSAGE = "Unable to get admin user credentials with message"
def _get_aad_token_after_challenge(cli_ctx,
                                   token_params,
                                   login_server,
                                   only_refresh_token,
                                   repository,
                                   artifact_repository,
                                   permission,
                                   is_diagnostics_context):
    """Exchange the current profile's AAD token for registry tokens once the challenge has been parsed.
    :param dict token_params: Parameters parsed from the challenge; 'realm' and 'service' are read
    :param str login_server: The registry login server, used as the token service and in error messages
    :param bool only_refresh_token: Whether to return the refresh token without requesting an access token
    :param str repository: Repository for which the access token is requested
    :param str artifact_repository: Artifact repository for which the access token is requested
    :param str permission: The permission placed in the access token scope
    :param bool is_diagnostics_context: Whether to return the formatted error instead of raising CLIError
    :return: The refresh token or the access token; in a diagnostics context, the formatted error on failure
    """
    realm = urlparse(token_params['realm'])

    def _oauth2_url(endpoint):
        # Only the scheme and host of the realm are reused; its path, query and fragment are discarded.
        return urlunparse((realm[0], realm[1], endpoint, '', '', ''))

    from azure.cli.core._profile import Profile
    creds, _, tenant = Profile(cli_ctx=cli_ctx).get_raw_token()

    form_headers = {'Content-Type': 'application/x-www-form-urlencoded'}

    # Step 1: trade the AAD access token for a registry refresh token.
    exchange_form = {
        'grant_type': 'access_token',
        'service': token_params['service'],
        'tenant': tenant,
        'access_token': creds[1]
    }
    exchange_response = requests.post(_oauth2_url('/oauth2/exchange'),
                                      urlencode(exchange_form),
                                      headers=form_headers,
                                      verify=(not should_disable_connection_verify()))

    if exchange_response.status_code != 200:
        from ._errors import CONNECTIVITY_REFRESH_TOKEN_ERROR
        refresh_error = CONNECTIVITY_REFRESH_TOKEN_ERROR.format_error_message(
            login_server, exchange_response.status_code)
        if is_diagnostics_context:
            return refresh_error
        raise CLIError(refresh_error.get_error_message())

    refresh_token = loads(exchange_response.content.decode("utf-8"))["refresh_token"]
    if only_refresh_token:
        return refresh_token

    # Step 2: trade the refresh token for an access token limited to the requested scope.
    if repository:
        scope = 'repository:{}:{}'.format(repository, permission)
    elif artifact_repository:
        scope = 'artifact-repository:{}:{}'.format(artifact_repository, permission)
    else:
        # catalog only has * as permission, even for a read operation
        scope = 'registry:catalog:*'

    token_form = {
        'grant_type': 'refresh_token',
        'service': login_server,
        'scope': scope,
        'refresh_token': refresh_token
    }
    token_response = requests.post(_oauth2_url('/oauth2/token'),
                                   urlencode(token_form),
                                   headers=form_headers,
                                   verify=(not should_disable_connection_verify()))

    if token_response.status_code != 200:
        from ._errors import CONNECTIVITY_ACCESS_TOKEN_ERROR
        access_error = CONNECTIVITY_ACCESS_TOKEN_ERROR.format_error_message(
            login_server, token_response.status_code)
        if is_diagnostics_context:
            return access_error
        raise CLIError(access_error.get_error_message())

    return loads(token_response.content.decode("utf-8"))["access_token"]
110def _get_aad_token(cli_ctx,
111                   login_server,
112                   only_refresh_token,
113                   repository=None,
114                   artifact_repository=None,
115                   permission=None,
116                   is_diagnostics_context=False):
117    """Obtains refresh and access tokens for an AAD-enabled registry.
118    :param str login_server: The registry login server URL to log in to
119    :param bool only_refresh_token: Whether to ask for only refresh token, or for both refresh and access tokens
120    :param str repository: Repository for which the access token is requested
121    :param str artifact_repository: Artifact repository for which the access token is requested
122    :param str permission: The requested permission on the repository, '*' or 'pull'
123    """
124    if repository and artifact_repository:
125        raise ValueError("Only one of repository and artifact_repository can be provided.")
127    if (repository or artifact_repository) and permission not in ACCESS_TOKEN_PERMISSION:
128        raise ValueError(
129            "Permission is required for a repository or artifact_repository. Allowed access token permission: {}"
130            .format(ACCESS_TOKEN_PERMISSION))
132    login_server = login_server.rstrip('/')
134    challenge = requests.get('https://' + login_server + '/v2/', verify=(not should_disable_connection_verify()))
135    if challenge.status_code not in [401] or 'WWW-Authenticate' not in challenge.headers:
136        from ._errors import CONNECTIVITY_CHALLENGE_ERROR
137        if is_diagnostics_context:
138            return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server)
139        raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message())
141    authenticate = challenge.headers['WWW-Authenticate']
143    tokens = authenticate.split(' ', 2)
144    if len(tokens) < 2 or tokens[0].lower() != 'bearer':
145        from ._errors import CONNECTIVITY_AAD_LOGIN_ERROR
146        if is_diagnostics_context:
147            return CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server)
148        raise CLIError(CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server).get_error_message())
150    token_params = {y[0]: y[1].strip('"') for y in
151                    (x.strip().split('=', 2) for x in tokens[1].split(','))}
152    if 'realm' not in token_params or 'service' not in token_params:
153        from ._errors import CONNECTIVITY_AAD_LOGIN_ERROR
154        if is_diagnostics_context:
155            return CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server)
156        raise CLIError(CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server).get_error_message())
158    return _get_aad_token_after_challenge(cli_ctx,
159                                          token_params,
160                                          login_server,
161                                          only_refresh_token,
162                                          repository,
163                                          artifact_repository,
164                                          permission,
165                                          is_diagnostics_context)
def _get_credentials(cmd,  # pylint: disable=too-many-statements
                     registry_name,
                     tenant_suffix,
                     username,
                     password,
                     only_refresh_token,
                     repository=None,
                     artifact_repository=None,
                     permission=None):
    """Try to get AAD authorization tokens or admin user credentials.
    :param str registry_name: The name of container registry
    :param str tenant_suffix: The registry login server tenant suffix
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :param bool only_refresh_token: Whether to ask for only refresh token, or for both refresh and access tokens
    :param str repository: Repository for which the access token is requested
    :param str artifact_repository: Artifact repository for which the access token is requested
    :param str permission: The requested permission, one of ACCESS_TOKEN_PERMISSION
    :return: A tuple of (login server, username, password); for AAD tokens the username is EMPTY_GUID
        and the password is the token
    :raises CLIError: If the arguments are inconsistent, the login server cannot be determined or reached,
        or no credentials can be obtained in non-interactive mode
    """
    # Raise an error if password is specified but username isn't
    if not username and password:
        raise CLIError('Please also specify username if password is specified.')

    cli_ctx = cmd.cli_ctx
    resource_not_found, registry = None, None
    try:
        registry, resource_group_name = get_registry_by_name(cli_ctx, registry_name)
        login_server = registry.login_server
        if tenant_suffix:
            logger.warning(
                "Obtained registry login server '%s' from service. The specified suffix '%s' is ignored.",
                login_server, tenant_suffix)
    except (ResourceNotFound, CLIError) as e:
        resource_not_found = str(e)
        logger.debug("Could not get registry from service. Exception: %s", resource_not_found)
        # Only a missing registry or a missing login is tolerated; any other failure is re-raised.
        if not isinstance(e, ResourceNotFound) and _AZ_LOGIN_MESSAGE not in resource_not_found:
            raise
        # Try to use the pre-defined login server suffix to construct login server from registry name.
        login_server_suffix = get_login_server_suffix(cli_ctx)
        if not login_server_suffix:
            raise
        login_server = '{}{}{}'.format(
            registry_name, '-{}'.format(tenant_suffix) if tenant_suffix else '', login_server_suffix).lower()

    # Validate the login server is reachable
    url = 'https://' + login_server + '/v2/'
    try:
        challenge = requests.get(url, verify=(not should_disable_connection_verify()))
    except RequestException as e:
        logger.debug("Could not connect to registry login server. Exception: %s", str(e))
        if resource_not_found:
            logger.warning("%s\nUsing '%s' as the default registry login server.", resource_not_found, login_server)
        raise CLIError("Could not connect to the registry login server '{}'. ".format(login_server) +
                       "Please verify that the registry exists and " +
                       "the URL '{}' is reachable from your environment.".format(url))
    if challenge.status_code in [403]:
        raise CLIError("Looks like you don't have access to registry '{}'. "
                       "Are firewalls and virtual networks enabled?".format(login_server))

    # 1. if username was specified, verify that password was also specified
    if username:
        if not password:
            try:
                password = prompt_pass(msg='Password: ')
            except NoTTYException:
                raise CLIError('Please specify both username and password in non-interactive mode.')

        return login_server, username, password

    # 2. if we don't yet have credentials, attempt to get a refresh token
    if not registry or registry.sku.name in get_managed_sku(cmd):
        try:
            return login_server, EMPTY_GUID, _get_aad_token(
                cli_ctx, login_server, only_refresh_token, repository, artifact_repository, permission)
        except CLIError as e:
            logger.warning("%s: %s", AAD_TOKEN_BASE_ERROR_MESSAGE, str(e))

    # 3. if we still don't have credentials, attempt to get the admin credentials (if enabled)
    if registry:
        if registry.admin_user_enabled:
            try:
                cred = cf_acr_registries(cli_ctx).list_credentials(resource_group_name, registry_name)
                return login_server, cred.username, cred.passwords[0].value
            except CLIError as e:
                logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, str(e))
        else:
            logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, "Admin user is disabled.")
    else:
        logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, resource_not_found)

    # 4. if we still don't have credentials, prompt the user.
    # This block always returns or raises, so nothing may follow it.
    try:
        username = prompt('Username: ')
        password = prompt_pass(msg='Password: ')
        return login_server, username, password
    except NoTTYException:
        raise CLIError(
            'Unable to authenticate using AAD or admin login credentials. ' +
            'Please specify both username and password in non-interactive mode.')
def get_login_credentials(cmd,
                          registry_name,
                          tenant_suffix=None,
                          username=None,
                          password=None):
    """Get the credentials used to log into a registry, asking only for a refresh token in the AAD case.
    :param str registry_name: The name of container registry
    :param str tenant_suffix: The registry login server tenant suffix
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :return: Whatever _get_credentials returns for these arguments
    """
    # Login never needs a scoped access token, hence only_refresh_token=True.
    return _get_credentials(
        cmd, registry_name, tenant_suffix, username, password,
        only_refresh_token=True)
def get_access_credentials(cmd,
                           registry_name,
                           tenant_suffix=None,
                           username=None,
                           password=None,
                           repository=None,
                           artifact_repository=None,
                           permission=None):
    """Get the credentials used to access registry data, asking for an access token in the AAD case.
    :param str registry_name: The name of container registry
    :param str tenant_suffix: The registry login server tenant suffix
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :param str repository: Repository for which the access token is requested
    :param str artifact_repository: Artifact repository for which the access token is requested
    :param str permission: The requested permission on the repository
    :return: Whatever _get_credentials returns for these arguments
    """
    # Data access needs a scoped access token, hence only_refresh_token=False plus the scope arguments.
    return _get_credentials(
        cmd, registry_name, tenant_suffix, username, password,
        only_refresh_token=False,
        repository=repository,
        artifact_repository=artifact_repository,
        permission=permission)
def log_registry_response(response):
    """Log both sides of a registry API call through the msrest HTTP logger.
    :param Response response: The response object; its `request` attribute is logged as the request
    """
    request = response.request
    log_request(None, request)
    # The response is wrapped in RegistryResponse before being handed to the logger.
    wrapped_response = RegistryResponse(request, response)
    log_response(None, request, wrapped_response)
def get_login_server_suffix(cli_ctx):
    """Return the ACR login server suffix of the current cloud, or None if the cloud does not define one."""
    try:
        suffix = cli_ctx.cloud.suffixes.acr_login_server_endpoint
    except CloudSuffixNotSetException as e:
        logger.debug("Could not get login server endpoint suffix. Exception: %s", str(e))
        # Ignore the error if the suffix is not set, the caller should then try to get login server from server.
        suffix = None
    return suffix
def _get_basic_auth_str(username, password):
    """Build the value of a Basic Authorization header from a username and password."""
    # '%s' formatting is kept on purpose: it stringifies both values the same way on every Python version.
    credentials = '%s:%s' % (username, password)
    encoded = b64encode(credentials.encode('latin1')).strip()
    return 'Basic ' + to_native_string(encoded)
340def _get_bearer_auth_str(token):
341    return 'Bearer ' + token
def get_authorization_header(username, password):
    """Get the Authorization header: Bearer auth if username is EMPTY_GUID, or Basic auth otherwise
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry; sent as the bearer
        token when username is EMPTY_GUID
    :return: A dict with the single key 'Authorization'
    """
    if username == EMPTY_GUID:
        return {'Authorization': _get_bearer_auth_str(password)}
    return {'Authorization': _get_basic_auth_str(username, password)}
def request_data_from_registry(http_method,
                               login_server,
                               path,
                               username,
                               password,
                               result_index=None,
                               json_payload=None,
                               file_payload=None,
                               params=None,
                               retry_times=3,
                               retry_interval=5):
    """Send a request to the registry and return the parsed result, retrying on unexpected failures.
    :param str http_method: One of ALLOWED_HTTP_METHOD
    :param str login_server: The registry login server; the URL is 'https://' + login_server + path
    :param str path: The request path appended to the login server
    :param str username: The username passed to get_authorization_header
    :param str password: The password or token passed to get_authorization_header
    :param result_index: If truthy, the key used to select a part of the JSON response
    :param json_payload: JSON body for 'patch' and 'put' requests
    :param str file_payload: Path of a file whose binary content is sent as the body
    :param dict params: Query parameters of the request
    :param int retry_times: The number of attempts before giving up
    :param retry_interval: Seconds to wait between two attempts
    :return: A tuple of (result, next link); the next link is the 'link' response header of a 200
        response and None in all other cases
    :raises ValueError: If the method and payload arguments are inconsistent
    :raises RegistryException: For 401, 404, 405 and 409 responses (not retried)
    :raises CLIError: With the last error message once all attempts have failed
    """
    if http_method not in ALLOWED_HTTP_METHOD:
        raise ValueError("Allowed http method: {}".format(ALLOWED_HTTP_METHOD))

    if json_payload and file_payload:
        raise ValueError("One of json_payload and file_payload can be specified.")

    if http_method in ['get', 'delete'] and (json_payload or file_payload):
        raise ValueError("Empty payload is required for http method: {}".format(http_method))

    if http_method in ['patch', 'put'] and not (json_payload or file_payload):
        raise ValueError("Non-empty payload is required for http method: {}".format(http_method))

    url = 'https://{}{}'.format(login_server, path)
    headers = get_authorization_header(username, password)

    # Status codes reported to the caller right away as a RegistryException, with their fallback messages.
    known_failures = {
        401: 'Authentication required.',
        404: 'The requested data does not exist.',
        405: 'This operation is not supported.',
        409: 'Failed to request data due to a conflict.',
    }
    generic_failure = 'Could not {} the requested data.'.format(http_method)

    # Bound before the loop so that the final raise is well defined even when no attempt is made.
    error_message = generic_failure
    for attempt in range(retry_times):
        try:
            if file_payload:
                with open(file_payload, 'rb') as data_payload:
                    response = requests.request(
                        method=http_method,
                        url=url,
                        headers=headers,
                        params=params,
                        data=data_payload,
                        verify=(not should_disable_connection_verify())
                    )
            else:
                response = requests.request(
                    method=http_method,
                    url=url,
                    headers=headers,
                    params=params,
                    json=json_payload,
                    verify=(not should_disable_connection_verify())
                )

            log_registry_response(response)

            status_code = response.status_code
            if status_code == 200:
                result = response.json()[result_index] if result_index else response.json()
                return result, response.headers.get('link')
            if status_code in (201, 202):
                result = None
                try:
                    result = response.json()[result_index] if result_index else response.json()
                except ValueError as e:
                    logger.debug('Response is empty or is not a valid json. Exception: %s', str(e))
                return result, None
            if status_code == 204:
                return None, None
            if status_code in known_failures:
                raise RegistryException(
                    parse_error_message(known_failures[status_code], response),
                    status_code)
            raise Exception(parse_error_message(generic_failure, response))
        except CLIError:
            # Includes RegistryException: these failures are final and must not be retried.
            raise
        except Exception as e:  # pylint: disable=broad-except
            error_message = str(e)
            logger.debug('Retrying %s with exception %s', attempt + 1, error_message)
            # Waiting is only useful when another attempt follows.
            if attempt < retry_times - 1:
                time.sleep(retry_interval)

    raise CLIError(error_message)
def parse_error_message(error_message, response):
    """Build the error message for a failed registry response.

    The message of the first entry of the response's JSON 'errors' list replaces `error_message`
    when it is present and non-empty. A trailing period is ensured, and the value of the
    'x-ms-correlation-request-id' response header is appended when that header is available.
    :param str error_message: The fallback message
    :param response: An object exposing `text` and `headers`
    :return: The final error message
    """
    import json
    try:
        server_message = json.loads(response.text)['errors'][0]['message']
        if server_message:
            error_message = 'Error: {}'.format(server_message)
    except (ValueError, KeyError, TypeError, IndexError):
        # The body is not JSON or does not have the expected shape: keep the fallback message.
        pass

    if not error_message.endswith('.'):
        error_message = '{}.'.format(error_message)

    try:
        correlation_id = response.headers['x-ms-correlation-request-id']
    except (KeyError, TypeError, AttributeError):
        return error_message
    return '{} Correlation ID: {}.'.format(error_message, correlation_id)
class RegistryException(CLIError):
    """CLIError that also carries the HTTP status code of the failed registry request."""

    def __init__(self, message, status_code):
        # HTTP status code of the response that caused this error.
        self.status_code = status_code
        super(RegistryException, self).__init__(message)
class RegistryResponse(object):  # pylint: disable=too-few-public-methods
    """Snapshot of a response's fields, with the body text exposed through a `text()` method."""

    def __init__(self, request, internal_response):
        self.request = request
        self.internal_response = internal_response
        # Copy the fields of the wrapped response onto this object.
        self.status_code = internal_response.status_code
        self.headers = internal_response.headers
        self.encoding = internal_response.encoding
        self.reason = internal_response.reason
        self.content = internal_response.content

    def text(self):
        """Return the content decoded with the response encoding, or with UTF-8 when none is set."""
        codec = self.encoding or "utf-8"
        return self.content.decode(codec)