1# -*- coding: utf-8 -*- #
2# Copyright 2018 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Library for integrating Cloud Run with GKE."""
16
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import base64
import contextlib
import os
import socket
import ssl
import sys
import tempfile
import threading

from googlecloudsdk.api_lib.container import api_adapter
from googlecloudsdk.core import exceptions
from googlecloudsdk.core import properties
from googlecloudsdk.core.util import files

import six
35
36
class NoCaCertError(exceptions.Error):
  """Raised when a GKE cluster is missing certificate authority data."""
39
40
class _AddressPatches(object):
  """Singleton holding process-wide patches used to reach a GKE master by IP.

  While at least one MonkeypatchAddressChecking() context is active, either
  ssl.match_hostname (on Python 3, when the ssl module still provides it) or
  socket.getaddrinfo (otherwise) is replaced by a translating wrapper. The
  original function is restored when the last active context exits.
  """

  _instance = None

  @classmethod
  def Initialize(cls):
    """Creates the singleton instance; must be called only once."""
    assert not cls._instance
    cls._instance = cls()

  @classmethod
  def Get(cls):
    """Returns the singleton instance created by Initialize()."""
    assert cls._instance
    return cls._instance

  def __init__(self):
    # hostname -> ip and ip -> hostname. Both are None while no patch is
    # installed, and non-empty dicts while at least one context is active.
    self._host_to_ip = None
    self._ip_to_host = None
    # The functions replaced by the patch, restored when the last context
    # exits.
    self._old_getaddrinfo = None
    self._old_match_hostname = None
    self._lock = threading.Lock()
    # ssl.match_hostname was deprecated in Python 3.7 and removed in 3.12, so
    # it is only patched on Python 3 when it still exists. Otherwise (Python 2,
    # or Python 3.12+) socket.getaddrinfo is redirected instead, so callers
    # connect using the hostname and TLS verifies against that hostname.
    self._patch_match_hostname = (
        sys.version_info[0] >= 3 and hasattr(ssl, 'match_hostname'))

  @contextlib.contextmanager
  def MonkeypatchAddressChecking(self, hostname, ip):
    """Change address checking so the given ip answers to the hostname.

    Args:
      hostname: hostname the server's TLS certificate is issued for.
      ip: ip address that actually serves that hostname.

    Yields:
      The endpoint to connect to: `ip` when ssl.match_hostname is patched,
      otherwise `hostname` (which the patched getaddrinfo resolves to `ip`).

    Raises:
      ValueError: if `hostname` or `ip` is already patched by an active
        context.
    """
    with self._lock:
      if self._host_to_ip is None:
        # First active context: install the patch.
        self._host_to_ip = {}
        self._ip_to_host = {}
        self._old_match_hostname = getattr(ssl, 'match_hostname', None)
        self._old_getaddrinfo = socket.getaddrinfo
        if self._patch_match_hostname:
          ssl.match_hostname = self._MatchHostname
        else:
          socket.getaddrinfo = self._GetAddrInfo
      if hostname in self._host_to_ip:
        raise ValueError(
            'Cannot re-patch the same address: {}'.format(hostname))
      if ip in self._ip_to_host:
        raise ValueError(
            'Cannot re-patch the same address: {}'.format(ip))
      self._host_to_ip[hostname] = ip
      self._ip_to_host[ip] = hostname
    try:
      if self._patch_match_hostname:
        yield ip
      else:
        yield hostname
    finally:
      with self._lock:
        del self._host_to_ip[hostname]
        del self._ip_to_host[ip]
        if not self._host_to_ip:
          # Last active context: remove the patch.
          self._host_to_ip = None
          self._ip_to_host = None
          if self._patch_match_hostname:
            ssl.match_hostname = self._old_match_hostname
          else:
            socket.getaddrinfo = self._old_getaddrinfo

  def _GetAddrInfo(self, host, *args, **kwargs):
    """Like socket.getaddrinfo, but resolves patched hostnames to their ip."""
    with self._lock:
      assert self._host_to_ip is not None
      if host in self._host_to_ip:
        host = self._host_to_ip[host]
    return self._old_getaddrinfo(host, *args, **kwargs)

  def _MatchHostname(self, cert, hostname):
    """Like ssl.match_hostname, but checks patched ips as their hostname."""
    # The connection is made with `hostname` being the bare ip address; treat
    # it as the hostname it was registered with (e.g. `kubernetes.default`),
    # which is the name the master's certificate asserts.
    with self._lock:
      assert self._ip_to_host is not None
      if hostname in self._ip_to_host:
        hostname = self._ip_to_host[hostname]
    return self._old_match_hostname(cert, hostname)
119
# Create the process-wide singleton at import time so that
# MonkeypatchAddressChecking() below can fetch it with _AddressPatches.Get().
_AddressPatches.Initialize()
121
122
def MonkeypatchAddressChecking(hostname, ip):
  """Manipulate SSL address checking so we can talk to GKE.

  GKE exposes the Kubernetes master at an IP address, together with a CA
  certificate that signs the master's TLS certificate. That TLS certificate,
  however, is issued only for `kubernetes`, `kubernetes.default`,
  `kubernetes.default.svc`, or `kubernetes.default.svc.cluster.local`, not
  for the IP address.

  In Python 3, ssl.match_hostname is patched so that a connection made to the
  given IP address is verified as if it were made to `hostname` (e.g.
  `kubernetes.default`).

  In Python 2, httplib2 does its own hostname checking, so that approach is
  unavailable. Instead, getaddrinfo is patched to act like a "fake
  /etc/hosts": lookups that ask specifically for `hostname` are redirected to
  `ip`. Callers can then use `hostname` while still connecting to the ip
  address known to be the Kubernetes server, without editing /etc/hosts.

  This is safe because the CA certificate used for verification comes
  directly from the GKE API.

  Args:
    hostname: hostname to replace.
    ip: ip address to replace the hostname with.

  Returns:
    A context manager that patches an internal function for its duration and
    yields the endpoint to actually connect to.
  """
  patches = _AddressPatches.Get()
  return patches.MonkeypatchAddressChecking(hostname, ip)
152
153
@contextlib.contextmanager
def _DisableUserProjectQuota():
  """Temporarily switch to legacy quota, as the container surface requires.

  If billing/quota_project has not been explicitly set, it is set to
  properties.VALUES.billing.LEGACY for the duration of the block and reset to
  None afterwards. This acts at a distance on the http libraries so that they
  do not send the X-Goog-User-Project header.

  Yields:
    Control to a block executed while the user project quota is disabled.
  """
  quota_project = properties.VALUES.billing.quota_project
  overridden = False
  try:
    if not quota_project.IsExplicitlySet():
      # Mark before setting so the finally clause clears the value even if
      # Set() itself fails partway.
      overridden = True
      quota_project.Set(properties.VALUES.billing.LEGACY)
    yield
  finally:
    if overridden:
      quota_project.Set(None)
174
175
@contextlib.contextmanager
def ClusterConnectionInfo(cluster_ref):
  """Get the info we need to use to connect to a GKE cluster.

  Arguments:
    cluster_ref: reference to the cluster to connect to.
  Yields:
    A tuple of (endpoint, ca_certs), where endpoint is the ip address of the
    GKE master, and ca_certs is the absolute path of a temporary file holding
    the decoded cluster CA certificate. The file exists only for the duration
    of the context and is deleted when the context exits.
  Raises:
    NoCaCertError: if the cluster is missing certificate authority data.
  """
  with _DisableUserProjectQuota():
    adapter = api_adapter.NewAPIAdapter('v1')
    cluster = adapter.GetCluster(cluster_ref)
  auth = cluster.masterAuth
  if not (auth and auth.clusterCaCertificate):
    # This should not happen unless the cluster is in an unusual error
    # state.
    raise NoCaCertError('Cluster is missing certificate authority data.')
  # Decode before creating the temp file so malformed data cannot leave an
  # orphaned file behind.
  ca_certs = base64.b64decode(auth.clusterCaCertificate)
  fd, filename = tempfile.mkstemp()
  os.close(fd)
  try:
    # The write is inside the try so the file is removed even if it fails.
    files.WriteBinaryFileContents(filename, ca_certs, private=True)
    yield cluster.endpoint, filename
  finally:
    os.remove(filename)
208
209
210