1# -*- coding: utf-8 -*- # 2# Copyright 2020 Google LLC. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15"""Utilities for operating on endpoints for different regions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import unicode_literals 20 21import contextlib 22 23from googlecloudsdk.api_lib.util import apis 24from googlecloudsdk.command_lib.ai import constants 25from googlecloudsdk.core import log 26from googlecloudsdk.core import properties 27from six.moves.urllib import parse 28 29 30def DeriveAiplatformRegionalEndpoint(endpoint, region, is_prediction=False): 31 """Adds region as a prefix of the base url.""" 32 scheme, netloc, path, params, query, fragment = parse.urlparse(endpoint) 33 if netloc.startswith('aiplatform'): 34 if is_prediction: 35 netloc = '{}-prediction-{}'.format(region, netloc) 36 else: 37 netloc = '{}-{}'.format(region, netloc) 38 return parse.urlunparse((scheme, netloc, path, params, query, fragment)) 39 40 41@contextlib.contextmanager 42def AiplatformEndpointOverrides(version, region, is_prediction=False): 43 """Context manager to override the AI Platform endpoints for a while. 44 45 Raises an error if 46 region is not set. 47 48 Args: 49 version: str, implies the version that the endpoint will use. 50 region: str, region of the AI Platform stack. 51 is_prediction: bool, it's for prediction endpoint or not. 52 53 Yields: 54 None 55 """ 56 used_endpoint = GetEffectiveEndpoint(version=version, region=region, 57 is_prediction=is_prediction) 58 log.status.Print('Using endpoint [{}]'.format(used_endpoint)) 59 properties.VALUES.api_endpoint_overrides.aiplatform.Set(used_endpoint) 60 yield 61 62 63def GetEffectiveEndpoint(version, region, is_prediction=False): 64 """Returns regional AI Platform endpoint, or raise an error if the region not set.""" 65 endpoint = apis.GetEffectiveApiEndpoint( 66 constants.AI_PLATFORM_API_NAME, 67 constants.AI_PLATFORM_API_VERSION[version]) 68 return DeriveAiplatformRegionalEndpoint( 69 endpoint, region, is_prediction=is_prediction) 70