1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6try:
7    from urllib.parse import urlencode, urlparse, urlunparse
8except ImportError:
9    from urllib import urlencode
10    from urlparse import urlparse, urlunparse
11
12import time
13from json import loads
14from base64 import b64encode
15import requests
16from requests import RequestException
17from requests.utils import to_native_string
18from msrest.http_logger import log_request, log_response
19
20from knack.util import CLIError
21from knack.prompting import prompt, prompt_pass, NoTTYException
22from knack.log import get_logger
23
24from azure.cli.core.util import should_disable_connection_verify
25from azure.cli.core.cloud import CloudSuffixNotSetException
26from azure.cli.core._profile import _AZ_LOGIN_MESSAGE
27
28from ._client_factory import cf_acr_registries
29from ._constants import get_managed_sku
30from ._utils import get_registry_by_name, ResourceNotFound
31
32
logger = get_logger(__name__)


# Placeholder username returned together with AAD-obtained registry tokens;
# get_authorization_header switches to Bearer auth when it sees this value.
EMPTY_GUID = '00000000-0000-0000-0000-000000000000'
# HTTP methods accepted by request_data_from_registry.
ALLOWED_HTTP_METHOD = ['get', 'patch', 'put', 'delete']
# Permissions accepted by _get_aad_token when a repository or artifact repository is requested.
ACCESS_TOKEN_PERMISSION = ['pull', 'push', 'delete', 'push,pull', 'delete,pull']

# Prefixes of the warnings _get_credentials logs when a credential source fails.
AAD_TOKEN_BASE_ERROR_MESSAGE = "Unable to get AAD authorization tokens with message"
ADMIN_USER_BASE_ERROR_MESSAGE = "Unable to get admin user credentials with message"
42
43
def _get_aad_token_after_challenge(cli_ctx,
                                   token_params,
                                   login_server,
                                   only_refresh_token,
                                   repository,
                                   artifact_repository,
                                   permission,
                                   is_diagnostics_context):
    """Redeem the signed-in identity's token for registry tokens at the challenge realm.

    Two form-encoded POSTs are made against the scheme and host of ``token_params['realm']``:
    ``/oauth2/exchange`` trades the profile's raw access token for a registry refresh token, and
    ``/oauth2/token`` trades that refresh token for an access token with the requested scope.

    :param cli_ctx: The CLI context used to build the ``Profile``
    :param dict token_params: Bearer challenge parameters; 'realm' and 'service' are read
    :param str login_server: Registry login server, sent as 'service' in the second request
    :param bool only_refresh_token: Stop after the exchange and return the refresh token
    :param str repository: Repository to scope the access token to (takes precedence)
    :param str artifact_repository: Artifact repository to scope the access token to
    :param str permission: Permission placed in the repository / artifact-repository scope
    :param bool is_diagnostics_context: Return error objects instead of raising CLIError
    :return: The refresh or access token string, or an error object in a diagnostics context
    :raises CLIError: If either POST does not answer with HTTP 200 outside a diagnostics context
    """
    realm = urlparse(token_params['realm'])
    exchange_url = urlunparse((realm[0], realm[1], '/oauth2/exchange', '', '', ''))

    from azure.cli.core._profile import Profile
    creds, _, tenant = Profile(cli_ctx=cli_ctx).get_raw_token()

    form_headers = {'Content-Type': 'application/x-www-form-urlencoded'}

    def _post_form(target_url, form_fields):
        # Certificate verification is skipped only when should_disable_connection_verify() says so.
        return requests.post(target_url, urlencode(form_fields), headers=form_headers,
                             verify=(not should_disable_connection_verify()))

    exchange_response = _post_form(exchange_url, {
        'grant_type': 'access_token',
        'service': token_params['service'],
        'tenant': tenant,
        'access_token': creds[1]
    })
    if exchange_response.status_code != 200:
        from ._errors import CONNECTIVITY_REFRESH_TOKEN_ERROR
        refresh_error = CONNECTIVITY_REFRESH_TOKEN_ERROR.format_error_message(
            login_server, exchange_response.status_code)
        if is_diagnostics_context:
            return refresh_error
        raise CLIError(refresh_error.get_error_message())

    refresh_token = loads(exchange_response.content.decode("utf-8"))["refresh_token"]
    if only_refresh_token:
        return refresh_token

    # The catalog is only ever granted with '*', even for a read operation.
    if repository:
        scope = 'repository:{}:{}'.format(repository, permission)
    elif artifact_repository:
        scope = 'artifact-repository:{}:{}'.format(artifact_repository, permission)
    else:
        scope = 'registry:catalog:*'

    token_url = urlunparse((realm[0], realm[1], '/oauth2/token', '', '', ''))
    token_response = _post_form(token_url, {
        'grant_type': 'refresh_token',
        'service': login_server,
        'scope': scope,
        'refresh_token': refresh_token
    })
    if token_response.status_code != 200:
        from ._errors import CONNECTIVITY_ACCESS_TOKEN_ERROR
        access_error = CONNECTIVITY_ACCESS_TOKEN_ERROR.format_error_message(
            login_server, token_response.status_code)
        if is_diagnostics_context:
            return access_error
        raise CLIError(access_error.get_error_message())

    return loads(token_response.content.decode("utf-8"))["access_token"]
108
109
110def _get_aad_token(cli_ctx,
111                   login_server,
112                   only_refresh_token,
113                   repository=None,
114                   artifact_repository=None,
115                   permission=None,
116                   is_diagnostics_context=False):
117    """Obtains refresh and access tokens for an AAD-enabled registry.
118    :param str login_server: The registry login server URL to log in to
119    :param bool only_refresh_token: Whether to ask for only refresh token, or for both refresh and access tokens
120    :param str repository: Repository for which the access token is requested
121    :param str artifact_repository: Artifact repository for which the access token is requested
122    :param str permission: The requested permission on the repository, '*' or 'pull'
123    """
124    if repository and artifact_repository:
125        raise ValueError("Only one of repository and artifact_repository can be provided.")
126
127    if (repository or artifact_repository) and permission not in ACCESS_TOKEN_PERMISSION:
128        raise ValueError(
129            "Permission is required for a repository or artifact_repository. Allowed access token permission: {}"
130            .format(ACCESS_TOKEN_PERMISSION))
131
132    login_server = login_server.rstrip('/')
133
134    challenge = requests.get('https://' + login_server + '/v2/', verify=(not should_disable_connection_verify()))
135    if challenge.status_code not in [401] or 'WWW-Authenticate' not in challenge.headers:
136        from ._errors import CONNECTIVITY_CHALLENGE_ERROR
137        if is_diagnostics_context:
138            return CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server)
139        raise CLIError(CONNECTIVITY_CHALLENGE_ERROR.format_error_message(login_server).get_error_message())
140
141    authenticate = challenge.headers['WWW-Authenticate']
142
143    tokens = authenticate.split(' ', 2)
144    if len(tokens) < 2 or tokens[0].lower() != 'bearer':
145        from ._errors import CONNECTIVITY_AAD_LOGIN_ERROR
146        if is_diagnostics_context:
147            return CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server)
148        raise CLIError(CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server).get_error_message())
149
150    token_params = {y[0]: y[1].strip('"') for y in
151                    (x.strip().split('=', 2) for x in tokens[1].split(','))}
152    if 'realm' not in token_params or 'service' not in token_params:
153        from ._errors import CONNECTIVITY_AAD_LOGIN_ERROR
154        if is_diagnostics_context:
155            return CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server)
156        raise CLIError(CONNECTIVITY_AAD_LOGIN_ERROR.format_error_message(login_server).get_error_message())
157
158    return _get_aad_token_after_challenge(cli_ctx,
159                                          token_params,
160                                          login_server,
161                                          only_refresh_token,
162                                          repository,
163                                          artifact_repository,
164                                          permission,
165                                          is_diagnostics_context)
166
167
def _get_credentials(cmd,  # pylint: disable=too-many-statements
                     registry_name,
                     tenant_suffix,
                     username,
                     password,
                     only_refresh_token,
                     repository=None,
                     artifact_repository=None,
                     permission=None):
    """Try to get AAD authorization tokens or admin user credentials.

    Credentials are resolved in this order:
    1. the explicit username/password, prompting for a missing password;
    2. an AAD token, when the registry could not be looked up or has a managed SKU;
    3. the admin user credentials, when the admin user is enabled;
    4. an interactive username/password prompt.

    :param cmd: The CLI command being executed; its ``cli_ctx`` is used for service calls
    :param str registry_name: The name of container registry
    :param str tenant_suffix: The registry login server tenant suffix. Ignored with a warning when
        the login server is obtained from the service.
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :param bool only_refresh_token: Whether to ask for only refresh token, or for both refresh and access tokens
    :param str repository: Repository for which the access token is requested
    :param str artifact_repository: Artifact repository for which the access token is requested
    :param str permission: The requested permission on the repository, one of ACCESS_TOKEN_PERMISSION
    :return: A ``(login_server, username, password)`` tuple. When AAD is used, username is
        EMPTY_GUID and password is the token returned by _get_aad_token.
    :raises CLIError: If the login server is unreachable or no credentials can be obtained
    """
    # Raise an error if password is specified but username isn't
    if not username and password:
        raise CLIError('Please also specify username if password is specified.')

    cli_ctx = cmd.cli_ctx
    resource_not_found, registry = None, None
    try:
        registry, resource_group_name = get_registry_by_name(cli_ctx, registry_name)
        login_server = registry.login_server
        if tenant_suffix:
            logger.warning(
                "Obtained registry login server '%s' from service. The specified suffix '%s' is ignored.",
                login_server, tenant_suffix)
    except (ResourceNotFound, CLIError) as e:
        resource_not_found = str(e)
        logger.debug("Could not get registry from service. Exception: %s", resource_not_found)
        # Only "not found" and "not logged in" failures fall back to a constructed login server.
        if not isinstance(e, ResourceNotFound) and _AZ_LOGIN_MESSAGE not in resource_not_found:
            raise
        # Try to use the pre-defined login server suffix to construct login server from registry name.
        login_server_suffix = get_login_server_suffix(cli_ctx)
        if not login_server_suffix:
            raise
        login_server = '{}{}{}'.format(
            registry_name, '-{}'.format(tenant_suffix) if tenant_suffix else '', login_server_suffix).lower()

    # Validate the login server is reachable
    url = 'https://' + login_server + '/v2/'
    try:
        challenge = requests.get(url, verify=(not should_disable_connection_verify()))
        if challenge.status_code in [403]:
            raise CLIError("Looks like you don't have access to registry '{}'. "
                           "Are firewalls and virtual networks enabled?".format(login_server))
    except RequestException as e:
        logger.debug("Could not connect to registry login server. Exception: %s", str(e))
        if resource_not_found:
            logger.warning("%s\nUsing '%s' as the default registry login server.", resource_not_found, login_server)
        raise CLIError("Could not connect to the registry login server '{}'. ".format(login_server) +
                       "Please verify that the registry exists and " +
                       "the URL '{}' is reachable from your environment.".format(url))

    # 1. if username was specified, verify that password was also specified
    if username:
        if not password:
            try:
                password = prompt_pass(msg='Password: ')
            except NoTTYException:
                raise CLIError('Please specify both username and password in non-interactive mode.')

        return login_server, username, password

    # 2. if we don't yet have credentials, attempt to get a refresh token
    if not registry or registry.sku.name in get_managed_sku(cmd):
        try:
            return login_server, EMPTY_GUID, _get_aad_token(
                cli_ctx, login_server, only_refresh_token, repository, artifact_repository, permission)
        except CLIError as e:
            logger.warning("%s: %s", AAD_TOKEN_BASE_ERROR_MESSAGE, str(e))

    # 3. if we still don't have credentials, attempt to get the admin credentials (if enabled)
    if registry:
        if registry.admin_user_enabled:
            try:
                cred = cf_acr_registries(cli_ctx).list_credentials(resource_group_name, registry_name)
                return login_server, cred.username, cred.passwords[0].value
            except CLIError as e:
                logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, str(e))
        else:
            logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, "Admin user is disabled.")
    else:
        logger.warning("%s: %s", ADMIN_USER_BASE_ERROR_MESSAGE, resource_not_found)

    # 4. if we still don't have credentials, prompt the user. This step either returns the
    # prompted credentials or raises, so nothing follows it.
    try:
        username = prompt('Username: ')
        password = prompt_pass(msg='Password: ')
        return login_server, username, password
    except NoTTYException:
        raise CLIError(
            'Unable to authenticate using AAD or admin login credentials. ' +
            'Please specify both username and password in non-interactive mode.')
269
270
def get_login_credentials(cmd,
                          registry_name,
                          tenant_suffix=None,
                          username=None,
                          password=None):
    """Obtain credentials for logging in to a container registry.

    When AAD is used, only a refresh token is requested.

    :param str registry_name: The name of the container registry
    :param str tenant_suffix: Optional login server tenant suffix
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :return: A ``(login_server, username, password)`` tuple
    """
    credentials = _get_credentials(cmd, registry_name, tenant_suffix, username, password,
                                   only_refresh_token=True)
    return credentials
287
288
def get_access_credentials(cmd,
                           registry_name,
                           tenant_suffix=None,
                           username=None,
                           password=None,
                           repository=None,
                           artifact_repository=None,
                           permission=None):
    """Obtain credentials for calling registry APIs.

    When AAD is used, an access token scoped to the repository, the artifact repository,
    or the catalog is requested.

    :param str registry_name: The name of the container registry
    :param str tenant_suffix: Optional login server tenant suffix
    :param str username: The username used to log into the container registry
    :param str password: The password used to log into the container registry
    :param str repository: Repository for which the access token is requested
    :param str artifact_repository: Artifact repository for which the access token is requested
    :param str permission: The requested permission on the repository
    :return: A ``(login_server, username, password)`` tuple
    """
    credentials = _get_credentials(cmd, registry_name, tenant_suffix, username, password,
                                   only_refresh_token=False,
                                   repository=repository,
                                   artifact_repository=artifact_repository,
                                   permission=permission)
    return credentials
314
315
def log_registry_response(response):
    """Log a registry API call's request and response through msrest's http_logger helpers.

    :param Response response: The response; its ``request`` attribute is logged as well
    """
    sent_request = response.request
    log_request(None, sent_request)
    wrapped_response = RegistryResponse(sent_request, response)
    log_response(None, sent_request, wrapped_response)
322
323
def get_login_server_suffix(cli_ctx):
    """Return the ACR login server suffix of the active cloud, or None when it is not configured."""
    try:
        suffix = cli_ctx.cloud.suffixes.acr_login_server_endpoint
    except CloudSuffixNotSetException as err:
        logger.debug("Could not get login server endpoint suffix. Exception: %s", str(err))
        # A missing suffix is not fatal: the caller then falls back to the login server from the service.
        suffix = None
    return suffix
332
333
def _get_basic_auth_str(username, password):
    """Build an HTTP Basic ``Authorization`` value from latin-1 encoded ``username:password``."""
    encoded_pair = b64encode('{0!s}:{1!s}'.format(username, password).encode('latin1'))
    return 'Basic ' + to_native_string(encoded_pair.strip())
338
339
340def _get_bearer_auth_str(token):
341    return 'Bearer ' + token
342
343
def get_authorization_header(username, password):
    """Build the ``Authorization`` header for a registry request.

    Bearer auth (with *password* as the token) is used when *username* is EMPTY_GUID;
    otherwise Basic auth with *username* and *password* is used.

    :param str username: The username used to log into the container registry, or EMPTY_GUID
    :param str password: The password, or the token when username is EMPTY_GUID
    :return: A single-entry dict holding the ``Authorization`` header
    """
    is_token = username == EMPTY_GUID
    auth = _get_bearer_auth_str(password) if is_token else _get_basic_auth_str(username, password)
    return {'Authorization': auth}
354
355
def request_data_from_registry(http_method,
                               login_server,
                               path,
                               username,
                               password,
                               result_index=None,
                               json_payload=None,
                               file_payload=None,
                               params=None,
                               retry_times=3,
                               retry_interval=5):
    """Send an HTTP request to a registry API, retrying failures that are not definitive.

    :param str http_method: One of ALLOWED_HTTP_METHOD
    :param str login_server: The registry login server; the URL is https://<login_server><path>
    :param str path: The request path appended to the login server
    :param str username: Username, or EMPTY_GUID when password is a token
    :param str password: Password or token
    :param result_index: If truthy, key used to pick one field out of the JSON response body
    :param json_payload: JSON body for 'patch'/'put' (mutually exclusive with file_payload)
    :param str file_payload: Path of a file whose bytes are sent as the body for 'patch'/'put'
    :param dict params: Query string parameters
    :param int retry_times: Maximum number of attempts; must be at least 1
    :param retry_interval: Seconds to wait between attempts
    :return: A ``(result, next_link)`` tuple. next_link is the 'link' header of a 200 response,
        otherwise None.
    :raises ValueError: If the arguments are inconsistent
    :raises RegistryException: For 401, 404, 405 and 409 responses (these are not retried)
    :raises CLIError: When every attempt failed, with the last error message
    """
    # Validated first: with no attempts the retry loop would never produce an error to report.
    if retry_times < 1:
        raise ValueError("retry_times must be at least 1.")

    if http_method not in ALLOWED_HTTP_METHOD:
        raise ValueError("Allowed http method: {}".format(ALLOWED_HTTP_METHOD))

    if json_payload and file_payload:
        raise ValueError("One of json_payload and file_payload can be specified.")

    if http_method in ['get', 'delete'] and (json_payload or file_payload):
        raise ValueError("Empty payload is required for http method: {}".format(http_method))

    if http_method in ['patch', 'put'] and not (json_payload or file_payload):
        raise ValueError("Non-empty payload is required for http method: {}".format(http_method))

    url = 'https://{}{}'.format(login_server, path)
    headers = get_authorization_header(username, password)

    # Statuses that are definitive answers from the registry: fail fast instead of retrying.
    non_retryable_errors = {
        401: 'Authentication required.',
        404: 'The requested data does not exist.',
        405: 'This operation is not supported.',
        409: 'Failed to request data due to a conflict.',
    }

    error_message = None
    for attempt in range(retry_times):
        try:
            if file_payload:
                # Reopened on every attempt so a retry sends the file from the beginning.
                with open(file_payload, 'rb') as data_payload:
                    response = requests.request(
                        method=http_method,
                        url=url,
                        headers=headers,
                        params=params,
                        data=data_payload,
                        verify=(not should_disable_connection_verify())
                    )
            else:
                response = requests.request(
                    method=http_method,
                    url=url,
                    headers=headers,
                    params=params,
                    json=json_payload,
                    verify=(not should_disable_connection_verify())
                )

            log_registry_response(response)

            status_code = response.status_code
            if status_code == 200:
                result = response.json()[result_index] if result_index else response.json()
                return result, response.headers.get('link')
            if status_code in (201, 202):
                result = None
                try:
                    result = response.json()[result_index] if result_index else response.json()
                except ValueError as e:
                    logger.debug('Response is empty or is not a valid json. Exception: %s', str(e))
                return result, None
            if status_code == 204:
                return None, None
            if status_code in non_retryable_errors:
                raise RegistryException(
                    parse_error_message(non_retryable_errors[status_code], response),
                    status_code)
            raise Exception(parse_error_message('Could not {} the requested data.'.format(http_method), response))
        except CLIError:
            raise
        except Exception as e:  # pylint: disable=broad-except
            error_message = str(e)
            logger.debug('Retrying %s with exception %s', attempt + 1, error_message)
            # Only wait when another attempt follows; sleeping before the final raise just delays the error.
            if attempt + 1 < retry_times:
                time.sleep(retry_interval)

    raise CLIError(error_message)
446
447
def parse_error_message(error_message, response):
    """Build a user-facing error message from a registry error response.

    The first server-side message under ``errors[0].message`` in the JSON body, if present and
    non-empty, replaces *error_message* as ``'Error: <message>'``. The result always ends with a
    period, and the ``x-ms-correlation-request-id`` header is appended when available.

    :param str error_message: Fallback message used when the body carries no server message
    :param response: The HTTP response (``text`` and ``headers`` are read)
    :return: The final error message
    :rtype: str
    """
    import json
    try:
        server_message = json.loads(response.text)['errors'][0]['message']
    except (ValueError, KeyError, TypeError, IndexError):
        server_message = None

    message = 'Error: {}'.format(server_message) if server_message else error_message
    if not message.endswith('.'):
        message += '.'

    try:
        correlation_id = response.headers['x-ms-correlation-request-id']
    except (KeyError, TypeError, AttributeError):
        return message
    return '{} Correlation ID: {}.'.format(message, correlation_id)
464
465
class RegistryException(CLIError):
    """CLIError raised for a registry HTTP error, carrying the HTTP status code."""

    def __init__(self, message, status_code):
        # HTTP status code of the failed registry response.
        self.status_code = status_code
        super(RegistryException, self).__init__(message)
470
471
class RegistryResponse(object):  # pylint: disable=too-few-public-methods
    """Wraps a ``requests`` response so it can be passed to msrest's ``log_response``.

    Status, headers, encoding, reason and body are copied from the wrapped response at
    construction time.
    """

    def __init__(self, request, internal_response):
        self.request = request
        self.internal_response = internal_response
        self.status_code = internal_response.status_code
        self.headers = internal_response.headers
        self.encoding = internal_response.encoding
        self.reason = internal_response.reason
        self.content = internal_response.content

    def text(self, encoding=None):
        """Return the body decoded as text.

        The optional *encoding* mirrors msrest's ``HTTPClientResponse.text(encoding=None)``
        (NOTE(review): confirm against the msrest version in use).

        :param str encoding: Encoding to use; defaults to the response encoding, then UTF-8
        :return: The decoded body, or '' when there is no body
        :rtype: str
        """
        if self.content is None:
            return ''
        return self.content.decode(encoding or self.encoding or "utf-8")
484