1# -*- coding: utf-8 -*- #
2# Copyright 2020 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Utilities for operating on different endpoints."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21import contextlib
22
23from googlecloudsdk.api_lib.util import apis
24from googlecloudsdk.core import log
25from googlecloudsdk.core import properties
26from six.moves.urllib import parse
27
# API version and name passed to apis.GetEffectiveApiEndpoint when resolving
# the ML endpoint in this module.
ML_API_VERSION = 'v1'
ML_API_NAME = 'ml'
30
31
def DeriveMLRegionalEndpoint(endpoint, region):
  """Returns `endpoint` with '<region>-' prepended to its network location.

  Args:
    endpoint: str, URL of the endpoint to regionalize.
    region: str, region name used as the prefix of the network location.

  Returns:
    str, the URL reassembled from the parsed components of `endpoint`, with
    only the network location changed.
  """
  parsed = parse.urlparse(endpoint)
  regional_netloc = '{}-{}'.format(region, parsed.netloc)
  # The parse result is a named tuple; swap just the netloc field and rebuild.
  return parse.urlunparse(parsed._replace(netloc=regional_netloc))
36
37
@contextlib.contextmanager
def MlEndpointOverrides(region=None):
  """Context manager to override the AI Platform endpoints for a while.

  Prints the effective endpoint to the status log. When `region` is set to
  something other than 'global', the `ml` API endpoint override property is
  set to the regional endpoint for the duration of the context. On exit the
  property is always set back to the value it had on entry, even if the body
  raised.

  Args:
    region: str, region of the AI Platform stack. If None, empty or 'global',
      the endpoint override property is not changed on entry.

  Yields:
    None.
  """
  used_endpoint = GetEffectiveMlEndpoint(region)
  # Remember the current override so it can be restored on exit.
  old_endpoint = properties.VALUES.api_endpoint_overrides.ml.Get()
  try:
    log.status.Print('Using endpoint [{}]'.format(used_endpoint))
    if region and region != 'global':
      properties.VALUES.api_endpoint_overrides.ml.Set(used_endpoint)
    yield
  finally:
    # Restore unconditionally; the return value of Set() is not needed.
    properties.VALUES.api_endpoint_overrides.ml.Set(old_endpoint)
57
58
def GetEffectiveMlEndpoint(region):
  """Returns regional ML Endpoint, or global if region not set.

  Args:
    region: str, region to derive the endpoint for. A falsy value or 'global'
      selects the endpoint reported by apis.GetEffectiveApiEndpoint as is.

  Returns:
    The effective ML API endpoint, regionalized via DeriveMLRegionalEndpoint
    when a non-global region is given.
  """
  base_endpoint = apis.GetEffectiveApiEndpoint(ML_API_NAME, ML_API_VERSION)
  # No region, or the explicit 'global' region: use the endpoint unchanged.
  if not region or region == 'global':
    return base_endpoint
  return DeriveMLRegionalEndpoint(base_endpoint, region)
65