# -*- coding: utf-8 -*- #
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Database Migration Service connection profiles API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals

from apitools.base.py import list_pager

from googlecloudsdk.api_lib.database_migration import api_util
from googlecloudsdk.calliope import exceptions
from googlecloudsdk.command_lib.util.args import labels_util


class ConnectionProfilesClient(object):
  """Client for connection profiles service in the API."""

  def __init__(self, release_track):
    """Sets up the API client, messages module and resource parser.

    Args:
      release_track: the release track passed through to the api_util
        helpers to select the API version, client and messages module.
    """
    self._api_version = api_util.GetApiVersion(release_track)
    self.client = api_util.GetClientInstance(release_track)
    self.messages = api_util.GetMessagesModule(release_track)
    self._service = self.client.projects_locations_connectionProfiles
    self.resource_parser = api_util.GetResourceParser(release_track)

  def _ClientCertificateArgName(self):
    """Returns the arg name holding the client certificate for this version."""
    if self._api_version == 'v1alpha2':
      return 'certificate'
    return 'client_certificate'

  def _InstanceArgName(self):
    """Returns the arg name holding the Cloud SQL instance for this version."""
    if self._api_version == 'v1alpha2':
      return 'instance'
    return 'cloudsql_instance'

  def _ValidateArgs(self, args):
    """Runs all argument validations; currently only the SSL ones."""
    self._ValidateSslConfigArgs(args)

  def _ValidateSslConfigArgs(self, args):
    """Validates the format of every SSL-related arg (CA, client cert, key)."""
    self._ValidateCertificateFormat(args, 'ca_certificate')
    self._ValidateCertificateFormat(args, self._ClientCertificateArgName())
    self._ValidateCertificateFormat(args, 'private_key')

  def _ValidateCertificateFormat(self, args, field):
    """Checks that the value of a certificate/key arg looks like PEM.

    Only the first and last lines of the stripped value are inspected: both
    must start with '-----'.

    Args:
      args: argparse.Namespace, the arguments that the command was invoked
        with. May be any object; if it has no attribute named `field` the
        check is skipped.
      field: str, the name of the arg to validate.

    Returns:
      True when there was nothing to validate (the arg is absent or was not
      specified); None when the value passed the check. Callers in this class
      ignore the return value.

    Raises:
      exceptions.InvalidArgumentException: if the first or last line of the
        value does not start with '-----'.
    """
    if not hasattr(args, field) or not args.IsSpecified(field):
      return True
    cert = getattr(args, field).strip()
    cert_lines = cert.split('\n')
    # NOTE(review): the error message echoes the raw value, which for the
    # 'private_key' field means key material ends up in the exception text.
    # Left unchanged here; consider omitting the value for that field.
    if (not cert_lines[0].startswith('-----')
        or not cert_lines[-1].startswith('-----')):
      raise exceptions.InvalidArgumentException(
          field,
          'The certificate does not appear to be in PEM format:\n{0}'
          .format(cert))

  def _GetSslConfig(self, args):
    """Builds an SslConfig message from the SSL-related args."""
    return self.messages.SslConfig(
        clientKey=args.private_key,
        clientCertificate=args.GetValue(self._ClientCertificateArgName()),
        caCertificate=args.ca_certificate)

  def _UpdateSslConfig(self, connection_profile, args, update_fields):
    """Fills connection_profile and update_fields with SSL data from args.

    Args:
      connection_profile: the ConnectionProfile message to modify in place;
        its `mysql` field must already be set.
      args: argparse.Namespace, the arguments that the command was invoked
        with.
      update_fields: list of str, extended in place with the field-mask path
        of every SSL field that was changed.
    """
    cert_arg = self._ClientCertificateArgName()
    ssl_specified = any(
        args.IsSpecified(arg_name)
        for arg_name in ('ca_certificate', 'private_key', cert_arg))
    if not ssl_specified:
      return

    # An existing profile may carry no SSL sub-message at all; create one
    # lazily so the assignments below do not fail on None.
    if connection_profile.mysql.ssl is None:
      connection_profile.mysql.ssl = self.messages.SslConfig()
    ssl_config = connection_profile.mysql.ssl

    if args.IsSpecified('ca_certificate'):
      ssl_config.caCertificate = args.ca_certificate
      update_fields.append('mysql.ssl.caCertificate')
    if args.IsSpecified('private_key'):
      ssl_config.clientKey = args.private_key
      update_fields.append('mysql.ssl.clientKey')
    if args.IsSpecified(cert_arg):
      ssl_config.clientCertificate = args.GetValue(cert_arg)
      update_fields.append('mysql.ssl.clientCertificate')

  def _GetMySqlConnectionProfile(self, args):
    """Builds a MySqlConnectionProfile message from args."""
    ssl_config = self._GetSslConfig(args)
    return self.messages.MySqlConnectionProfile(
        host=args.host,
        port=args.port,
        username=args.username,
        password=args.password,
        ssl=ssl_config,
        cloudSqlId=args.GetValue(self._InstanceArgName()))

  def _UpdateMySqlConnectionProfile(
      self, connection_profile, args, update_fields):
    """Updates MySQL connection profile.

    Args:
      connection_profile: the ConnectionProfile message to modify in place;
        its `mysql` field must already be set.
      args: argparse.Namespace, the arguments that the command was invoked
        with.
      update_fields: list of str, extended in place with the field-mask path
        of every field that was changed.
    """
    if args.IsSpecified('host'):
      connection_profile.mysql.host = args.host
      update_fields.append('mysql.host')
    if args.IsSpecified('port'):
      connection_profile.mysql.port = args.port
      update_fields.append('mysql.port')
    if args.IsSpecified('username'):
      connection_profile.mysql.username = args.username
      update_fields.append('mysql.username')
    if args.IsSpecified('password'):
      connection_profile.mysql.password = args.password
      update_fields.append('mysql.password')
    instance_arg = self._InstanceArgName()
    if args.IsSpecified(instance_arg):
      connection_profile.mysql.cloudSqlId = args.GetValue(instance_arg)
      # NOTE(review): the mask path is 'mysql.instance' while the message
      # field written above is `cloudSqlId` - confirm which path the API
      # expects before changing either.
      update_fields.append('mysql.instance')
    self._UpdateSslConfig(connection_profile, args, update_fields)

  def _GetProvider(self, cp_type, provider):
    """Maps a provider name to the enum value; None maps to UNSPECIFIED."""
    if provider is None:
      return cp_type.ProviderValueValuesEnum.DATABASE_PROVIDER_UNSPECIFIED
    return cp_type.ProviderValueValuesEnum.lookup_by_name(provider)

  def _GetActivationPolicy(self, cp_type, policy):
    """Maps a policy name to the enum value; None maps to UNSPECIFIED."""
    if policy is None:
      return cp_type.ActivationPolicyValueValuesEnum.SQL_ACTIVATION_POLICY_UNSPECIFIED
    return cp_type.ActivationPolicyValueValuesEnum.lookup_by_name(policy)

  def _GetDatabaseVersion(self, cp_type, version):
    """Maps a database version name to the corresponding enum value."""
    return cp_type.DatabaseVersionValueValuesEnum.lookup_by_name(version)

  def _GetAuthorizedNetworks(self, networks):
    """Wraps each network value in a SqlAclEntry message.

    Args:
      networks: iterable of network values, or None.

    Returns:
      A list of SqlAclEntry messages; empty when `networks` is None or empty.
    """
    acl_entry = self.messages.SqlAclEntry
    # `networks` is None when the arg was not supplied; treat it as empty
    # rather than failing on iteration.
    return [acl_entry(value=network) for network in networks or []]

  def _GetIpConfig(self, args):
    """Builds a SqlIpConfig message from the IP-related args."""
    return self.messages.SqlIpConfig(
        enableIpv4=args.enable_ip_v4,
        privateNetwork=args.private_network,
        requireSsl=args.require_ssl,
        authorizedNetworks=self._GetAuthorizedNetworks(
            args.authorized_networks))

  def _GetDataDiskType(self, cp_type, data_disk_type):
    """Maps a disk type name to the enum value; None maps to UNSPECIFIED."""
    if data_disk_type is None:
      return cp_type.DataDiskTypeValueValuesEnum.SQL_DATA_DISK_TYPE_UNSPECIFIED
    return cp_type.DataDiskTypeValueValuesEnum.lookup_by_name(data_disk_type)

  def _GetCloudSqlSettings(self, args):
    """Builds a CloudSqlSettings message from args."""
    cp_type = self.messages.CloudSqlSettings
    source_id = args.CONCEPTS.source_id.Parse().RelativeName()
    user_labels_value = labels_util.ParseCreateArgs(
        args, cp_type.UserLabelsValue, 'user_labels')
    database_flags = labels_util.ParseCreateArgs(
        args, cp_type.DatabaseFlagsValue, 'database_flags')
    return cp_type(
        databaseVersion=self._GetDatabaseVersion(
            cp_type, args.database_version),
        userLabels=user_labels_value,
        tier=args.tier,
        storageAutoResizeLimit=args.storage_auto_resize_limit,
        activationPolicy=self._GetActivationPolicy(
            cp_type, args.activation_policy),
        ipConfig=self._GetIpConfig(args),
        autoStorageIncrease=args.auto_storage_increase,
        databaseFlags=database_flags,
        dataDiskType=self._GetDataDiskType(cp_type, args.data_disk_type),
        dataDiskSizeGb=args.data_disk_size,
        zone=args.zone,
        sourceId=source_id)

  def _GetCloudSqlConnectionProfile(self, args):
    """Builds a CloudSqlConnectionProfile message wrapping the settings."""
    settings = self._GetCloudSqlSettings(args)
    return self.messages.CloudSqlConnectionProfile(settings=settings)

  def _GetConnectionProfile(self, cp_type, connection_profile_id, args):
    """Returns a connection profile according to type.

    Args:
      cp_type: str, 'MYSQL' or 'CLOUDSQL'. Any other value yields a profile
        with neither the `mysql` nor the `cloudsql` field populated.
      connection_profile_id: str, used as the `name` of the message.
      args: argparse.Namespace, the arguments that the command was invoked
        with.

    Returns:
      A ConnectionProfile message in state CREATING.
    """
    connection_profile_type = self.messages.ConnectionProfile
    provider = self._GetProvider(connection_profile_type, args.provider)
    labels = labels_util.ParseCreateArgs(
        args, connection_profile_type.LabelsValue)
    params = {}
    if cp_type == 'MYSQL':
      params['mysql'] = self._GetMySqlConnectionProfile(args)
    elif cp_type == 'CLOUDSQL':
      params['cloudsql'] = self._GetCloudSqlConnectionProfile(args)
    return connection_profile_type(
        name=connection_profile_id,
        labels=labels,
        state=connection_profile_type.StateValueValuesEnum.CREATING,
        displayName=args.display_name,
        provider=provider,
        **params)

  def _GetExistingConnectionProfile(self, name):
    """Fetches the connection profile with the given name from the API."""
    get_req = self.messages.DatamigrationProjectsLocationsConnectionProfilesGetRequest(
        name=name
    )
    return self._service.Get(get_req)

  def _UpdateLabels(self, connection_profile, args):
    """Updates labels of the connection profile.

    Applies the add/remove/clear label args on top of the profile's current
    labels and stores the result on the message when anything changed.

    Args:
      connection_profile: the ConnectionProfile message to modify in place.
      args: argparse.Namespace, the arguments that the command was invoked
        with.
    """
    add_labels = labels_util.GetUpdateLabelsDictFromArgs(args)
    remove_labels = labels_util.GetRemoveLabelsListFromArgs(args)
    value_type = self.messages.ConnectionProfile.LabelsValue
    update_result = labels_util.Diff(
        additions=add_labels,
        subtractions=remove_labels,
        clear=args.clear_labels
    ).Apply(value_type, connection_profile.labels)
    if update_result.needs_update:
      connection_profile.labels = update_result.labels

  # TODO(b/177659340): Avoid raising python build-in exceptions.
  def _GetUpdatedConnectionProfile(self, connection_profile, args):
    """Returns updated connection profile and list of updated fields.

    Args:
      connection_profile: the existing ConnectionProfile message; modified in
        place.
      args: argparse.Namespace, the arguments that the command was invoked
        with.

    Returns:
      A (connection_profile, update_fields) tuple, where update_fields is the
      list of field-mask paths that were changed.

    Raises:
      AttributeError: if the profile has no `mysql` object.
    """
    update_fields = []
    if args.IsSpecified('display_name'):
      connection_profile.displayName = args.display_name
      update_fields.append('displayName')
    if connection_profile.mysql is None:
      raise AttributeError(
          'The connection profile requested does not contain mysql object. '
          'Currently only mysql connection profile is supported.')
    self._UpdateMySqlConnectionProfile(
        connection_profile, args, update_fields)
    # NOTE(review): label changes are applied to the message but no 'labels'
    # path is added to update_fields - confirm the API applies them without
    # a mask entry.
    self._UpdateLabels(connection_profile, args)
    return connection_profile, update_fields

  def Create(self, parent_ref, connection_profile_id, cp_type, args=None):
    """Creates a connection profile.

    Args:
      parent_ref: a Resource reference to a parent
        datamigration.projects.locations resource for this connection
        profile.
      connection_profile_id: str, the name of the resource to create.
      cp_type: str, the type of the connection profile ('MYSQL' or
        'CLOUDSQL').
      args: argparse.Namespace, The arguments that this command was
        invoked with.

    Returns:
      Operation: the operation for creating the connection profile.
    """
    self._ValidateArgs(args)

    connection_profile = self._GetConnectionProfile(
        cp_type, connection_profile_id, args)

    request_id = api_util.GenerateRequestId()
    create_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesCreateRequest
    create_req = create_req_type(
        connectionProfile=connection_profile,
        connectionProfileId=connection_profile.name,
        parent=parent_ref,
        requestId=request_id
    )

    return self._service.Create(create_req)

  def Update(self, name, args=None):
    """Updates a connection profile.

    Args:
      name: str, the reference of the connection profile to
        update.
      args: argparse.Namespace, The arguments that this command was
        invoked with.

    Returns:
      Operation: the operation for updating the connection profile.

    Raises:
      AttributeError: if the existing profile has no `mysql` object.
    """
    self._ValidateArgs(args)

    current_cp = self._GetExistingConnectionProfile(name)

    updated_cp, update_fields = self._GetUpdatedConnectionProfile(
        current_cp, args)

    request_id = api_util.GenerateRequestId()
    update_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesPatchRequest
    update_req = update_req_type(
        connectionProfile=updated_cp,
        name=updated_cp.name,
        updateMask=','.join(update_fields),
        requestId=request_id
    )

    return self._service.Patch(update_req)

  def List(self, project_id, args):
    """Get the list of connection profiles in a project.

    Args:
      project_id: The project ID to retrieve
      args: parsed command line arguments

    Returns:
      An iterator over all the matching connection profiles.
    """
    location_ref = self.resource_parser.Create(
        'datamigration.projects.locations',
        projectsId=project_id,
        locationsId=args.region)

    list_req_type = self.messages.DatamigrationProjectsLocationsConnectionProfilesListRequest
    list_req = list_req_type(
        parent=location_ref.RelativeName(),
        filter=args.filter,
        orderBy=','.join(args.sort_by) if args.sort_by else None)

    # self._service is the same projects_locations_connectionProfiles service
    # used by Create/Update/Get.
    return list_pager.YieldFromList(
        service=self._service,
        request=list_req,
        limit=args.limit,
        batch_size=args.page_size,
        field='connectionProfiles',
        batch_size_attribute='pageSize')

  def GetUri(self, name):
    """Get the URL string for a connection profile.

    Args:
      name: connection profile's full name.

    Returns:
      URL of the connection profile resource
    """
    uri = self.resource_parser.ParseRelativeName(
        name,
        collection='datamigration.projects.locations.connectionProfiles')
    return uri.SelfLink()