1# -*- coding: utf-8 -*- #
2# Copyright 2018 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Network endpoint group api client."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21from googlecloudsdk.api_lib.compute import utils as api_utils
22from googlecloudsdk.api_lib.compute.operations import poller
23from googlecloudsdk.api_lib.util import waiter
24from googlecloudsdk.command_lib.util.apis import arg_utils
25
26
class NetworkEndpointGroupsClient(object):
  """Client for network endpoint groups service in the GCE API.

  Wraps the zonal network endpoint group service of the compute apitools
  client, plus the global and regional services when the client exposes them.
  """

  def __init__(self, client, messages, resources):
    """Binds the API client, message module and resource registry.

    Args:
      client: compute client wrapper; this class reads its `apitools_client`
        attribute and calls its `MakeRequests` method.
      messages: module holding the compute API message classes.
      resources: resource registry used to parse names and self links into
        references.
    """
    self.client = client
    self.messages = messages
    self.resources = resources
    apis = client.apitools_client
    self._zonal_service = apis.networkEndpointGroups
    # The global and regional services are only bound when the apitools
    # client exposes them; otherwise the attribute is left unset.
    optional_services = (
        ('_global_service', 'globalNetworkEndpointGroups'),
        ('_region_service', 'regionNetworkEndpointGroups'),
    )
    for attr_name, service_name in optional_services:
      if hasattr(apis, service_name):
        setattr(self, attr_name, getattr(apis, service_name))

  def Create(self,
             neg_ref,
             network_endpoint_type,
             default_port=None,
             network=None,
             subnet=None,
             cloud_run_service=None,
             cloud_run_tag=None,
             cloud_run_url_mask=None,
             app_engine_app=False,
             app_engine_service=None,
             app_engine_version=None,
             app_engine_url_mask=None,
             cloud_function_name=None,
             cloud_function_url_mask=None):
    """Creates a network endpoint group.

    The insert request is chosen from the reference:
    - a reference with a `zone` attribute goes to the zonal service;
    - otherwise, one with a `region` attribute goes to the regional service;
    - anything else goes to the global service.

    Args:
      neg_ref: resource reference of the group to create.
      network_endpoint_type: endpoint type choice, converted with
        arg_utils.ChoiceToEnum into
        NetworkEndpointGroup.NetworkEndpointTypeValueValuesEnum.
      default_port: default port for the group's endpoints.
      network: network name; resolved to a self link only for zonal
        references.
      subnet: subnetwork name; resolved only for zonal references, in the
        region derived from the reference's zone.
      cloud_run_service: Cloud Run service name.
      cloud_run_tag: Cloud Run tag.
      cloud_run_url_mask: Cloud Run URL mask.
      app_engine_app: when true, an App Engine message is built even if no
        other App Engine argument is set.
      app_engine_service: App Engine service name.
      app_engine_version: App Engine version.
      app_engine_url_mask: App Engine URL mask.
      cloud_function_name: Cloud Function name.
      cloud_function_url_mask: Cloud Function URL mask.

    The Cloud Run, App Engine and Cloud Function messages are placed on the
    group only when the reference has a `region` attribute.

    Returns:
      The first element of the list returned by client.MakeRequests.
    """
    is_zonal = hasattr(neg_ref, 'zone')
    is_regional = hasattr(neg_ref, 'region')

    network_uri = None
    if is_zonal and network:
      network_uri = self.resources.Parse(
          network, {'project': neg_ref.project},
          collection='compute.networks').SelfLink()

    subnet_uri = None
    if is_zonal and subnet:
      subnet_params = {
          'project': neg_ref.project,
          'region': api_utils.ZoneNameToRegionName(neg_ref.zone),
      }
      subnet_uri = self.resources.Parse(
          subnet, subnet_params,
          collection='compute.subnetworks').SelfLink()

    # Serverless sub-messages are built whenever a related argument is set,
    # but are only attached to the group for regional references below.
    cloud_run = None
    if any((cloud_run_service, cloud_run_tag, cloud_run_url_mask)):
      cloud_run = self.messages.NetworkEndpointGroupCloudRun(
          service=cloud_run_service,
          tag=cloud_run_tag,
          urlMask=cloud_run_url_mask)

    app_engine = None
    if any((app_engine_app, app_engine_service, app_engine_version,
            app_engine_url_mask)):
      app_engine = self.messages.NetworkEndpointGroupAppEngine(
          service=app_engine_service,
          version=app_engine_version,
          urlMask=app_engine_url_mask)

    cloud_function = None
    if any((cloud_function_name, cloud_function_url_mask)):
      cloud_function = self.messages.NetworkEndpointGroupCloudFunction(
          function=cloud_function_name, urlMask=cloud_function_url_mask)

    neg_message_class = self.messages.NetworkEndpointGroup
    type_enum = neg_message_class.NetworkEndpointTypeValueValuesEnum

    neg_fields = {
        'name': neg_ref.Name(),
        'networkEndpointType': arg_utils.ChoiceToEnum(
            network_endpoint_type, type_enum),
        'defaultPort': default_port,
        'network': network_uri,
        'subnetwork': subnet_uri,
    }
    # TODO(b/137663401): remove the check below after all Serverless flags go
    # to GA.
    if is_regional:
      neg_fields.update(
          cloudRun=cloud_run,
          appEngine=app_engine,
          cloudFunction=cloud_function)
    network_endpoint_group = neg_message_class(**neg_fields)

    if is_zonal:
      request = self.messages.ComputeNetworkEndpointGroupsInsertRequest(
          networkEndpointGroup=network_endpoint_group,
          project=neg_ref.project,
          zone=neg_ref.zone)
      service = self._zonal_service
    elif is_regional:
      request = self.messages.ComputeRegionNetworkEndpointGroupsInsertRequest(
          networkEndpointGroup=network_endpoint_group,
          project=neg_ref.project,
          region=neg_ref.region)
      service = self._region_service
    else:
      request = self.messages.ComputeGlobalNetworkEndpointGroupsInsertRequest(
          networkEndpointGroup=network_endpoint_group,
          project=neg_ref.project)
      service = self._global_service
    return self.client.MakeRequests([(service, 'Insert', request)])[0]

  def _AttachZonalEndpoints(self, neg_ref, endpoints):
    """Attaches network endpoints to a zonal network endpoint group.

    Returns:
      The result of the zonal service's AttachNetworkEndpoints call.
    """
    msgs = self.messages
    payload = msgs.NetworkEndpointGroupsAttachEndpointsRequest(
        networkEndpoints=self._GetEndpointMessageList(endpoints))
    request = (
        msgs.ComputeNetworkEndpointGroupsAttachNetworkEndpointsRequest(
            networkEndpointGroup=neg_ref.Name(),
            project=neg_ref.project,
            zone=neg_ref.zone,
            networkEndpointGroupsAttachEndpointsRequest=payload))
    return self._zonal_service.AttachNetworkEndpoints(request)

  def _DetachZonalEndpoints(self, neg_ref, endpoints):
    """Detaches network endpoints from a zonal network endpoint group.

    Returns:
      The result of the zonal service's DetachNetworkEndpoints call.
    """
    msgs = self.messages
    payload = msgs.NetworkEndpointGroupsDetachEndpointsRequest(
        networkEndpoints=self._GetEndpointMessageList(endpoints))
    request = (
        msgs.ComputeNetworkEndpointGroupsDetachNetworkEndpointsRequest(
            networkEndpointGroup=neg_ref.Name(),
            project=neg_ref.project,
            zone=neg_ref.zone,
            networkEndpointGroupsDetachEndpointsRequest=payload))
    return self._zonal_service.DetachNetworkEndpoints(request)

  def _AttachGlobalEndpoints(self, neg_ref, endpoints):
    """Attaches network endpoints to a global network endpoint group.

    Returns:
      The result of the global service's AttachNetworkEndpoints call.
    """
    msgs = self.messages
    payload = msgs.GlobalNetworkEndpointGroupsAttachEndpointsRequest(
        networkEndpoints=self._GetEndpointMessageList(endpoints))
    request_cls = (
        msgs.ComputeGlobalNetworkEndpointGroupsAttachNetworkEndpointsRequest)
    request = request_cls(
        networkEndpointGroup=neg_ref.Name(),
        project=neg_ref.project,
        globalNetworkEndpointGroupsAttachEndpointsRequest=payload)
    return self._global_service.AttachNetworkEndpoints(request)

  def _DetachGlobalEndpoints(self, neg_ref, endpoints):
    """Detaches network endpoints from a global network endpoint group.

    Returns:
      The result of the global service's DetachNetworkEndpoints call.
    """
    msgs = self.messages
    payload = msgs.GlobalNetworkEndpointGroupsDetachEndpointsRequest(
        networkEndpoints=self._GetEndpointMessageList(endpoints))
    request_cls = (
        msgs.ComputeGlobalNetworkEndpointGroupsDetachNetworkEndpointsRequest)
    request = request_cls(
        networkEndpointGroup=neg_ref.Name(),
        project=neg_ref.project,
        globalNetworkEndpointGroupsDetachEndpointsRequest=payload)
    return self._global_service.DetachNetworkEndpoints(request)

  def _GetEndpointMessageList(self, endpoints):
    """Convert endpoints to a list which can be passed in a request.

    Args:
      endpoints: iterable of mappings; the recognized keys are 'instance',
        'ip', 'port' and 'fqdn'. Any other keys are ignored.

    Returns:
      A list of NetworkEndpoint messages, one per input mapping, in order.
    """
    # (argument key, NetworkEndpoint field) pairs, applied in this order.
    key_to_field = (
        ('instance', 'instance'),
        ('ip', 'ipAddress'),
        ('port', 'port'),
        ('fqdn', 'fqdn'),
    )
    converted = []
    for endpoint_args in endpoints:
      endpoint_msg = self.messages.NetworkEndpoint()
      for arg_key, field_name in key_to_field:
        if arg_key in endpoint_args:
          setattr(endpoint_msg, field_name, endpoint_args.get(arg_key))
      converted.append(endpoint_msg)
    return converted

  def _ParseOperationLink(self, operation, collection):
    """Parses operation.selfLink into a reference in the given collection."""
    return self.resources.Parse(operation.selfLink, collection=collection)

  def _GetOperationsRef(self, operation):
    """Returns a compute.zoneOperations reference for the operation."""
    return self._ParseOperationLink(operation, 'compute.zoneOperations')

  def _GetGlobalOperationsRef(self, operation):
    """Returns a compute.globalOperations reference for the operation."""
    return self._ParseOperationLink(operation, 'compute.globalOperations')

  def _WaitForResult(self, operation_poller, operation_ref, message):
    """Waits on operation_ref via waiter.WaitFor; None when ref is falsy."""
    if not operation_ref:
      return None
    return waiter.WaitFor(operation_poller, operation_ref, message)

  def Update(self, neg_ref, add_endpoints=None, remove_endpoints=None):
    """Updates a Compute Network Endpoint Group.

    Attaches `add_endpoints`, then detaches `remove_endpoints` (each step is
    skipped when its argument is empty), then waits for the resulting
    operations in that order.

    Args:
      neg_ref: resource reference of the group to update.
      add_endpoints: endpoint mappings to attach, or None.
      remove_endpoints: endpoint mappings to detach, or None.

    Returns:
      The detach wait result if truthy, otherwise the attach wait result if
      truthy, otherwise None.
    """
    if hasattr(neg_ref, 'zone'):
      service = self._zonal_service
      attach_fn = self._AttachZonalEndpoints
      detach_fn = self._DetachZonalEndpoints
      to_ref = self._GetOperationsRef
    else:
      # NOTE(review): a reference with only a `region` attribute also takes
      # this global branch -- confirm regional groups never reach Update.
      service = self._global_service
      attach_fn = self._AttachGlobalEndpoints
      detach_fn = self._DetachGlobalEndpoints
      to_ref = self._GetGlobalOperationsRef
    operation_poller = poller.Poller(service)

    attach_ref = None
    if add_endpoints:
      attach_ref = to_ref(attach_fn(neg_ref, add_endpoints))
    detach_ref = None
    if remove_endpoints:
      detach_ref = to_ref(detach_fn(neg_ref, remove_endpoints))

    neg_name = neg_ref.Name()
    attached = self._WaitForResult(
        operation_poller, attach_ref,
        'Attaching {0} endpoints to [{1}].'.format(
            len(add_endpoints or ()), neg_name))
    detached = self._WaitForResult(
        operation_poller, detach_ref,
        'Detaching {0} endpoints from [{1}].'.format(
            len(remove_endpoints or ()), neg_name))
    # Prefer the detach result, falling back to the attach result.
    return detached or attached or None
258