1# -*- coding: utf-8 -*- # 2# Copyright 2020 Google LLC. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15"""Utilities for operating on different endpoints.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import unicode_literals 20 21import contextlib 22 23from googlecloudsdk.api_lib.util import apis 24from googlecloudsdk.core import log 25from googlecloudsdk.core import properties 26from six.moves.urllib import parse 27 28ML_API_VERSION = 'v1' 29ML_API_NAME = 'ml' 30 31 32def DeriveMLRegionalEndpoint(endpoint, region): 33 scheme, netloc, path, params, query, fragment = parse.urlparse(endpoint) 34 netloc = '{}-{}'.format(region, netloc) 35 return parse.urlunparse((scheme, netloc, path, params, query, fragment)) 36 37 38@contextlib.contextmanager 39def MlEndpointOverrides(region=None): 40 """Context manager to override the AI Platform endpoints for a while. 41 42 Args: 43 region: str, region of the AI Platform stack. 44 45 Yields: 46 None. 47 """ 48 used_endpoint = GetEffectiveMlEndpoint(region) 49 old_endpoint = properties.VALUES.api_endpoint_overrides.ml.Get() 50 try: 51 log.status.Print('Using endpoint [{}]'.format(used_endpoint)) 52 if region and region != 'global': 53 properties.VALUES.api_endpoint_overrides.ml.Set(used_endpoint) 54 yield 55 finally: 56 old_endpoint = properties.VALUES.api_endpoint_overrides.ml.Set(old_endpoint) 57 58 59def GetEffectiveMlEndpoint(region): 60 """Returns regional ML Endpoint, or global if region not set.""" 61 endpoint = apis.GetEffectiveApiEndpoint(ML_API_NAME, ML_API_VERSION) 62 if region and region != 'global': 63 return DeriveMLRegionalEndpoint(endpoint, region) 64 return endpoint 65