1# -*- coding: utf-8 -*- #
2# Copyright 2020 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Utilities for operating on endpoints for different regions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21import contextlib
22
23from googlecloudsdk.api_lib.util import apis
24from googlecloudsdk.command_lib.ai import constants
25from googlecloudsdk.core import log
26from googlecloudsdk.core import properties
27from six.moves.urllib import parse
28
29
def DeriveAiplatformRegionalEndpoint(endpoint, region, is_prediction=False):
  """Prefixes the host of an AI Platform base url with the given region.

  Only hosts that start with 'aiplatform' are rewritten; any other endpoint is
  parsed and re-assembled with its host left as it was.

  Args:
    endpoint: str, the base url to derive the regional endpoint from.
    region: str, the region to put in front of the host.
    is_prediction: bool, whether to insert '-prediction-' between the region
      and the host instead of a plain '-'.

  Returns:
    str, the re-assembled url.
  """
  parts = parse.urlparse(endpoint)
  host = parts.netloc
  if host.startswith('aiplatform'):
    template = '{}-prediction-{}' if is_prediction else '{}-{}'
    host = template.format(region, host)
  # Always rebuild through urlunparse so the result is normalized the same way
  # whether or not the host was rewritten.
  return parse.urlunparse(
      (parts.scheme, host, parts.path, parts.params, parts.query,
       parts.fragment))
39
40
@contextlib.contextmanager
def AiplatformEndpointOverrides(version, region, is_prediction=False):
  """Context manager to temporarily override the AI Platform endpoint.

  While the managed block runs, the
  properties.VALUES.api_endpoint_overrides.aiplatform property is set to the
  regional endpoint computed by GetEffectiveEndpoint. The value the property
  had on entry is put back on exit, whether the block finishes normally or
  raises.

  This function does not validate `region` itself; it is passed straight
  through to GetEffectiveEndpoint.

  Args:
    version: str, implies the version that the endpoint will use.
    region: str, region of the AI Platform stack.
    is_prediction: bool, it's for prediction endpoint or not.

  Yields:
    None
  """
  used_endpoint = GetEffectiveEndpoint(version=version, region=region,
                                       is_prediction=is_prediction)
  endpoint_property = properties.VALUES.api_endpoint_overrides.aiplatform
  # Remember the current override so it can be restored afterwards. Without
  # this the regional endpoint would stay in effect after the `with` block.
  # NOTE(review): a later call for a different region would then presumably
  # derive its endpoint from the stale override -- confirm against
  # apis.GetEffectiveApiEndpoint.
  old_endpoint = endpoint_property.Get()
  try:
    log.status.Print('Using endpoint [{}]'.format(used_endpoint))
    endpoint_property.Set(used_endpoint)
    yield
  finally:
    # Runs on normal exit and when the managed block raises.
    endpoint_property.Set(old_endpoint)
61
62
def GetEffectiveEndpoint(version, region, is_prediction=False):
  """Returns the regional AI Platform endpoint for the given version.

  The effective base endpoint is looked up for the AI Platform API and then
  handed to DeriveAiplatformRegionalEndpoint to add the region prefix.

  Args:
    version: key into constants.AI_PLATFORM_API_VERSION selecting the API
      version whose endpoint is looked up.
    region: str, region passed on to DeriveAiplatformRegionalEndpoint.
    is_prediction: bool, passed on to DeriveAiplatformRegionalEndpoint.

  Returns:
    The url produced by DeriveAiplatformRegionalEndpoint.
  """
  api_version = constants.AI_PLATFORM_API_VERSION[version]
  base_endpoint = apis.GetEffectiveApiEndpoint(
      constants.AI_PLATFORM_API_NAME, api_version)
  return DeriveAiplatformRegionalEndpoint(
      base_endpoint, region, is_prediction=is_prediction)
70