1# -*- coding: utf-8 -*- #
2# Copyright 2020 Google LLC. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#    http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15"""Database Migration Service connection profiles API."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import unicode_literals
20
21from apitools.base.py import list_pager
22
23from googlecloudsdk.api_lib.database_migration import api_util
24from googlecloudsdk.calliope import exceptions
25from googlecloudsdk.command_lib.util.args import labels_util
26
27
class ConnectionProfilesClient(object):
  """Client for connection profiles service in the API.

  Builds ConnectionProfile messages from parsed command-line arguments and
  sends Create/Patch/List requests through the
  projects_locations_connectionProfiles service of the API client.
  """

  def __init__(self, release_track):
    """Sets up the API client, messages module and resource parser.

    Args:
      release_track: the release track passed to the api_util helpers to
        select the API version, client, messages and resource parser.
    """
    self._api_version = api_util.GetApiVersion(release_track)
    self.client = api_util.GetClientInstance(release_track)
    self.messages = api_util.GetMessagesModule(release_track)
    self._service = self.client.projects_locations_connectionProfiles
    self.resource_parser = api_util.GetResourceParser(release_track)

  def _ClientCertificateArgName(self):
    """Returns the name of the client certificate arg for this API version."""
    if self._api_version == 'v1alpha2':
      return 'certificate'
    return 'client_certificate'

  def _InstanceArgName(self):
    """Returns the name of the Cloud SQL instance arg for this API version."""
    if self._api_version == 'v1alpha2':
      return 'instance'
    return 'cloudsql_instance'

  def _ValidateArgs(self, args):
    """Validates the arguments before a request is built."""
    self._ValidateSslConfigArgs(args)

  def _ValidateSslConfigArgs(self, args):
    """Validates the format of every SSL related argument."""
    self._ValidateCertificateFormat(args, 'ca_certificate')
    self._ValidateCertificateFormat(args, self._ClientCertificateArgName())
    self._ValidateCertificateFormat(args, 'private_key')

  def _ValidateCertificateFormat(self, args, field):
    """Checks that a certificate or key argument looks like PEM data.

    Only the first and last lines of the stripped value are inspected: both
    have to start with '-----'.

    Args:
      args: the parsed arguments; may lack the `field` attribute entirely.
      field: str, the name of the argument to check.

    Returns:
      True when the argument is absent, not specified, or passes the check.

    Raises:
      exceptions.InvalidArgumentException: if the value does not start and end
        with a '-----' line.
    """
    if not hasattr(args, field) or not args.IsSpecified(field):
      return True
    certificate = getattr(args, field)
    cert = certificate.strip()
    cert_lines = cert.split('\n')
    if (not cert_lines[0].startswith('-----')
        or not cert_lines[-1].startswith('-----')):
      # NOTE(review): the message echoes the raw value, which for
      # 'private_key' is key material - confirm this is acceptable.
      raise exceptions.InvalidArgumentException(
          field,
          'The certificate does not appear to be in PEM format:\n{0}'
          .format(cert))
    return True

  def _GetSslConfig(self, args):
    """Returns an SslConfig message built from the SSL arguments."""
    return self.messages.SslConfig(
        clientKey=args.private_key,
        clientCertificate=args.GetValue(self._ClientCertificateArgName()),
        caCertificate=args.ca_certificate)

  def _UpdateSslConfig(self, connection_profile, args, update_fields):
    """Fills connection_profile and update_fields with SSL data from args.

    Args:
      connection_profile: the connection profile message to modify in place;
        its `mysql` field must be set.
      args: the parsed arguments.
      update_fields: list of str, extended in place with the mask path of
        every SSL field that was changed.
    """
    client_certificate_arg = self._ClientCertificateArgName()
    ca_specified = args.IsSpecified('ca_certificate')
    key_specified = args.IsSpecified('private_key')
    cert_specified = args.IsSpecified(client_certificate_arg)
    if not (ca_specified or key_specified or cert_specified):
      return

    # An existing profile may carry no SSL config at all; create an empty one
    # so the assignments below do not fail on None.
    if connection_profile.mysql.ssl is None:
      connection_profile.mysql.ssl = self.messages.SslConfig()
    ssl_config = connection_profile.mysql.ssl

    if ca_specified:
      ssl_config.caCertificate = args.ca_certificate
      update_fields.append('mysql.ssl.caCertificate')
    if key_specified:
      ssl_config.clientKey = args.private_key
      update_fields.append('mysql.ssl.clientKey')
    if cert_specified:
      ssl_config.clientCertificate = args.GetValue(client_certificate_arg)
      update_fields.append('mysql.ssl.clientCertificate')

  def _GetMySqlConnectionProfile(self, args):
    """Returns a MySqlConnectionProfile message built from args."""
    ssl_config = self._GetSslConfig(args)
    return self.messages.MySqlConnectionProfile(
        host=args.host,
        port=args.port,
        username=args.username,
        password=args.password,
        ssl=ssl_config,
        cloudSqlId=args.GetValue(self._InstanceArgName()))

  def _UpdateMySqlConnectionProfile(
      self, connection_profile, args, update_fields):
    """Updates MySQL connection profile.

    Args:
      connection_profile: the connection profile message to modify in place;
        its `mysql` field must be set.
      args: the parsed arguments.
      update_fields: list of str, extended in place with the mask path of
        every field that was changed.
    """
    if args.IsSpecified('host'):
      connection_profile.mysql.host = args.host
      update_fields.append('mysql.host')
    if args.IsSpecified('port'):
      connection_profile.mysql.port = args.port
      update_fields.append('mysql.port')
    if args.IsSpecified('username'):
      connection_profile.mysql.username = args.username
      update_fields.append('mysql.username')
    if args.IsSpecified('password'):
      connection_profile.mysql.password = args.password
      update_fields.append('mysql.password')
    if args.IsSpecified(self._InstanceArgName()):
      connection_profile.mysql.cloudSqlId = args.GetValue(
          self._InstanceArgName())
      # NOTE(review): the mask path says 'instance' while the message field
      # set above is 'cloudSqlId' - confirm which path the API expects.
      update_fields.append('mysql.instance')
    self._UpdateSslConfig(connection_profile, args, update_fields)

  def _GetProvider(self, cp_type, provider):
    """Returns the provider enum for a name, or UNSPECIFIED for None."""
    if provider is None:
      return cp_type.ProviderValueValuesEnum.DATABASE_PROVIDER_UNSPECIFIED
    return cp_type.ProviderValueValuesEnum.lookup_by_name(provider)

  def _GetActivationPolicy(self, cp_type, policy):
    """Returns the activation policy enum, or UNSPECIFIED for None."""
    if policy is None:
      return cp_type.ActivationPolicyValueValuesEnum.SQL_ACTIVATION_POLICY_UNSPECIFIED
    return cp_type.ActivationPolicyValueValuesEnum.lookup_by_name(policy)

  def _GetDatabaseVersion(self, cp_type, version):
    """Returns the database version enum looked up by its name."""
    return cp_type.DatabaseVersionValueValuesEnum.lookup_by_name(version)

  def _GetAuthorizedNetworks(self, networks):
    """Returns one SqlAclEntry per network.

    Args:
      networks: iterable of network values, or None when none were given.

    Returns:
      A list of SqlAclEntry messages; empty when `networks` is None or empty.
    """
    acl_entry = self.messages.SqlAclEntry
    return [
        acl_entry(value=network)
        for network in networks or []
    ]

  def _GetIpConfig(self, args):
    """Returns a SqlIpConfig message built from args."""
    return self.messages.SqlIpConfig(
        enableIpv4=args.enable_ip_v4,
        privateNetwork=args.private_network,
        requireSsl=args.require_ssl,
        authorizedNetworks=self._GetAuthorizedNetworks(args.authorized_networks)
    )

  def _GetDataDiskType(self, cp_type, data_disk_type):
    """Returns the data disk type enum, or UNSPECIFIED for None."""
    if data_disk_type is None:
      return cp_type.DataDiskTypeValueValuesEnum.SQL_DATA_DISK_TYPE_UNSPECIFIED
    return cp_type.DataDiskTypeValueValuesEnum.lookup_by_name(data_disk_type)

  def _GetCloudSqlSettings(self, args):
    """Returns a CloudSqlSettings message built from args."""
    cp_type = self.messages.CloudSqlSettings
    source_id = args.CONCEPTS.source_id.Parse().RelativeName()
    user_labels_value = labels_util.ParseCreateArgs(
        args, cp_type.UserLabelsValue, 'user_labels')
    database_flags = labels_util.ParseCreateArgs(
        args, cp_type.DatabaseFlagsValue, 'database_flags')
    return self.messages.CloudSqlSettings(
        databaseVersion=self._GetDatabaseVersion(
            cp_type, args.database_version),
        userLabels=user_labels_value,
        tier=args.tier,
        storageAutoResizeLimit=args.storage_auto_resize_limit,
        activationPolicy=self._GetActivationPolicy(
            cp_type, args.activation_policy),
        ipConfig=self._GetIpConfig(args),
        autoStorageIncrease=args.auto_storage_increase,
        databaseFlags=database_flags,
        dataDiskType=self._GetDataDiskType(cp_type, args.data_disk_type),
        dataDiskSizeGb=args.data_disk_size,
        zone=args.zone,
        sourceId=source_id
    )

  def _GetCloudSqlConnectionProfile(self, args):
    """Returns a CloudSqlConnectionProfile message built from args."""
    settings = self._GetCloudSqlSettings(args)
    return self.messages.CloudSqlConnectionProfile(settings=settings)

  def _GetConnectionProfile(self, cp_type, connection_profile_id, args):
    """Returns a connection profile according to type.

    Args:
      cp_type: str, 'MYSQL' or 'CLOUDSQL'; any other value produces a profile
        with neither the mysql nor the cloudsql field set.
      connection_profile_id: str, used as the name of the profile.
      args: the parsed arguments.

    Returns:
      A ConnectionProfile message in the CREATING state.
    """
    connection_profile_type = self.messages.ConnectionProfile
    provider = self._GetProvider(connection_profile_type, args.provider)
    labels = labels_util.ParseCreateArgs(
        args, connection_profile_type.LabelsValue)
    params = {}
    if cp_type == 'MYSQL':
      mysql_connection_profile = self._GetMySqlConnectionProfile(args)
      params['mysql'] = mysql_connection_profile
    elif cp_type == 'CLOUDSQL':
      cloudsql_connection_profile = self._GetCloudSqlConnectionProfile(args)
      params['cloudsql'] = cloudsql_connection_profile
    return connection_profile_type(
        name=connection_profile_id,
        labels=labels,
        state=connection_profile_type.StateValueValuesEnum.CREATING,
        displayName=args.display_name,
        provider=provider,
        **params)

  def _GetExistingConnectionProfile(self, name):
    """Fetches the connection profile called `name` from the service."""
    get_req = self.messages.DatamigrationProjectsLocationsConnectionProfilesGetRequest(
        name=name
    )
    return self._service.Get(get_req)

  def _UpdateLabels(self, connection_profile, args):
    """Updates labels of the connection profile.

    Args:
      connection_profile: the connection profile message to modify in place.
      args: the parsed arguments holding the label update/remove/clear flags.

    Returns:
      True if the labels of the connection profile were changed.
    """
    add_labels = labels_util.GetUpdateLabelsDictFromArgs(args)
    remove_labels = labels_util.GetRemoveLabelsListFromArgs(args)
    value_type = self.messages.ConnectionProfile.LabelsValue
    update_result = labels_util.Diff(
        additions=add_labels,
        subtractions=remove_labels,
        clear=args.clear_labels
    ).Apply(value_type, connection_profile.labels)
    if update_result.needs_update:
      connection_profile.labels = update_result.labels
      return True
    return False

  # TODO(b/177659340): Avoid raising python build-in exceptions.
  def _GetUpdatedConnectionProfile(self, connection_profile, args):
    """Returns updated connection profile and list of updated fields.

    Args:
      connection_profile: the existing connection profile message; it is
        modified in place.
      args: the parsed arguments.

    Returns:
      A (connection_profile, update_fields) tuple where update_fields is the
      list of mask paths of the fields that were changed.

    Raises:
      AttributeError: if the connection profile has no mysql field.
    """
    update_fields = []
    if args.IsSpecified('display_name'):
      connection_profile.displayName = args.display_name
      update_fields.append('displayName')
    if connection_profile.mysql is not None:
      self._UpdateMySqlConnectionProfile(connection_profile,
                                         args,
                                         update_fields)
    else:
      raise AttributeError(
          'The connection profile requested does not contain mysql object. '
          'Currently only mysql connection profile is supported.')
    # A label change has to be listed in the update mask as well, otherwise
    # the patch request does not name it among the fields to overwrite.
    if self._UpdateLabels(connection_profile, args):
      update_fields.append('labels')
    return connection_profile, update_fields

  def Create(self, parent_ref, connection_profile_id, cp_type, args=None):
    """Creates a connection profile.

    Args:
      parent_ref: a Resource reference to a parent
        datamigration.projects.locations resource for this connection
        profile.
      connection_profile_id: str, the name of the resource to create.
      cp_type: str, the type of the connection profile ('MYSQL' or
        'CLOUDSQL').
      args: argparse.Namespace, The arguments that this command was
          invoked with.

    Returns:
      Operation: the operation for creating the connection profile.
    """
    self._ValidateArgs(args)

    connection_profile = self._GetConnectionProfile(cp_type,
                                                    connection_profile_id, args)

    request_id = api_util.GenerateRequestId()
    create_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesCreateRequest
    create_req = create_req_type(
        connectionProfile=connection_profile,
        connectionProfileId=connection_profile.name,
        parent=parent_ref,
        requestId=request_id
        )

    return self._service.Create(create_req)

  def Update(self, name, args=None):
    """Updates a connection profile.

    The existing profile is fetched first, modified according to the
    specified arguments and sent back in a patch request whose update mask
    lists the changed fields.

    Args:
      name: str, the reference of the connection profile to
          update.
      args: argparse.Namespace, The arguments that this command was
          invoked with.

    Returns:
      Operation: the operation for updating the connection profile.
    """
    self._ValidateArgs(args)

    current_cp = self._GetExistingConnectionProfile(name)

    updated_cp, update_fields = self._GetUpdatedConnectionProfile(
        current_cp, args)

    request_id = api_util.GenerateRequestId()
    update_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesPatchRequest
    update_req = update_req_type(
        connectionProfile=updated_cp,
        name=updated_cp.name,
        updateMask=','.join(update_fields),
        requestId=request_id
    )

    return self._service.Patch(update_req)

  def List(self, project_id, args):
    """Get the list of connection profiles in a project.

    Args:
      project_id: The project ID to retrieve
      args: parsed command line arguments; region, filter, sort_by, limit and
        page_size are read from it.

    Returns:
      An iterator over all the matching connection profiles.
    """
    location_ref = self.resource_parser.Create(
        'datamigration.projects.locations',
        projectsId=project_id,
        locationsId=args.region)

    list_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesListRequest
    list_req = list_req_type(
        parent=location_ref.RelativeName(),
        filter=args.filter,
        orderBy=','.join(args.sort_by) if args.sort_by else None)

    return list_pager.YieldFromList(
        service=self._service,
        request=list_req,
        limit=args.limit,
        batch_size=args.page_size,
        field='connectionProfiles',
        batch_size_attribute='pageSize')

  def GetUri(self, name):
    """Get the URL string for a connection profile.

    Args:
      name: connection profile's full name.

    Returns:
      URL of the connection profile resource
    """

    uri = self.resource_parser.ParseRelativeName(
        name,
        collection='datamigration.projects.locations.connectionProfiles')
    return uri.SelfLink()
342
343