1# -------------------------------------------------------------------------------------------- 2# Copyright (c) Microsoft Corporation. All rights reserved. 3# Licensed under the MIT License. See License.txt in the project root for license information. 4# -------------------------------------------------------------------------------------------- 5 6# pylint: disable=unused-argument, line-too-long 7from datetime import datetime, timedelta 8from importlib import import_module 9import re 10from dateutil.tz import tzutc # pylint: disable=import-error 11from msrestazure.azure_exceptions import CloudError 12from msrestazure.tools import resource_id, is_valid_resource_id, parse_resource_id # pylint: disable=import-error 13from knack.log import get_logger 14from knack.util import todict 15from six.moves.urllib.request import urlretrieve # pylint: disable=import-error 16from azure.core.exceptions import ResourceNotFoundError 17from azure.cli.core._profile import Profile 18from azure.cli.core.commands.client_factory import get_subscription_id 19from azure.cli.core.util import CLIError, sdk_no_wait 20from azure.cli.core.local_context import ALL 21from azure.mgmt.rdbms import postgresql, mysql, mariadb 22from azure.mgmt.rdbms.mysql.operations._servers_operations import ServersOperations as MySqlServersOperations 23from azure.mgmt.rdbms.postgresql.operations._location_based_performance_tier_operations import LocationBasedPerformanceTierOperations as PostgreSQLLocationOperations 24from azure.mgmt.rdbms.mariadb.operations._servers_operations import ServersOperations as MariaDBServersOperations 25from azure.mgmt.rdbms.mariadb.operations._location_based_performance_tier_operations import LocationBasedPerformanceTierOperations as MariaDBLocationOperations 26from ._client_factory import get_mariadb_management_client, get_mysql_management_client, cf_mysql_db, cf_mariadb_db, \ 27 get_postgresql_management_client, cf_postgres_check_resource_availability_sterling, \ 28 cf_mysql_check_resource_availability_sterling, cf_mariadb_check_resource_availability_sterling 29from ._flexible_server_util import generate_missing_parameters, generate_password, resolve_poller 30from ._util import parse_public_network_access_input, create_firewall_rule 31 32logger = get_logger(__name__) 33 34 35SKU_TIER_MAP = {'Basic': 'b', 'GeneralPurpose': 'gp', 'MemoryOptimized': 'mo'} 36DEFAULT_DB_NAME = 'defaultdb' 37 38 39# pylint: disable=too-many-locals, too-many-statements, raise-missing-from 40def _server_create(cmd, client, resource_group_name=None, server_name=None, sku_name=None, no_wait=False, 41 location=None, administrator_login=None, administrator_login_password=None, backup_retention=None, 42 geo_redundant_backup=None, ssl_enforcement=None, storage_mb=None, tags=None, version=None, auto_grow='Enabled', 43 assign_identity=False, public_network_access=None, infrastructure_encryption=None, minimal_tls_version=None): 44 provider = 'Microsoft.DBforPostgreSQL' 45 if isinstance(client, MySqlServersOperations): 46 provider = 'Microsoft.DBforMySQL' 47 elif isinstance(client, MariaDBServersOperations): 48 provider = 'Microsoft.DBforMariaDB' 49 50 server_result = firewall_id = None 51 administrator_login_password = generate_password(administrator_login_password) 52 engine_name = 'postgres' 53 pricing_link = 'https://aka.ms/postgres-pricing' 54 start_ip = end_ip = '' 55 56 if public_network_access is not None and str(public_network_access).lower() != 'enabled' and str(public_network_access).lower() != 'disabled': 57 if str(public_network_access).lower() == 'all': 58 start_ip, end_ip = '0.0.0.0', '255.255.255.255' 59 else: 60 start_ip, end_ip = parse_public_network_access_input(public_network_access) 61 # if anything but 'disabled' is passed on to the args, 62 # then the public network access value passed on to the API is Enabled. 63 public_network_access = 'Enabled' 64 65 # Check availability for server name if it is supplied by the user 66 if provider == 'Microsoft.DBforPostgreSQL': 67 # Populate desired parameters 68 location, resource_group_name, server_name = generate_missing_parameters(cmd, location, resource_group_name, 69 server_name, engine_name) 70 check_name_client = cf_postgres_check_resource_availability_sterling(cmd.cli_ctx, None) 71 name_availability_resquest = postgresql.models.NameAvailabilityRequest(name=server_name, type="Microsoft.DBforPostgreSQL/servers") 72 check_server_name_availability(check_name_client, name_availability_resquest) 73 logger.warning('Creating %s Server \'%s\' in group \'%s\'...', engine_name, server_name, resource_group_name) 74 logger.warning('Your server \'%s\' is using sku \'%s\' (Paid Tier). ' 75 'Please refer to %s for pricing details', server_name, sku_name, pricing_link) 76 parameters = postgresql.models.ServerForCreate( 77 sku=postgresql.models.Sku(name=sku_name), 78 properties=postgresql.models.ServerPropertiesForDefaultCreate( 79 administrator_login=administrator_login, 80 administrator_login_password=administrator_login_password, 81 version=version, 82 ssl_enforcement=ssl_enforcement, 83 minimal_tls_version=minimal_tls_version, 84 public_network_access=public_network_access, 85 infrastructure_encryption=infrastructure_encryption, 86 storage_profile=postgresql.models.StorageProfile( 87 backup_retention_days=backup_retention, 88 geo_redundant_backup=geo_redundant_backup, 89 storage_mb=storage_mb, 90 storage_autogrow=auto_grow)), 91 location=location, 92 tags=tags) 93 if assign_identity: 94 parameters.identity = postgresql.models.ResourceIdentity( 95 type=postgresql.models.IdentityType.system_assigned.value) 96 elif provider == 'Microsoft.DBforMySQL': 97 engine_name = 'mysql' 98 pricing_link = 'https://aka.ms/mysql-pricing' 99 location, resource_group_name, server_name = generate_missing_parameters(cmd, location, resource_group_name, 100 server_name, engine_name) 101 check_name_client = cf_mysql_check_resource_availability_sterling(cmd.cli_ctx, None) 102 name_availability_resquest = mysql.models.NameAvailabilityRequest(name=server_name, type="Microsoft.DBforMySQL/servers") 103 check_server_name_availability(check_name_client, name_availability_resquest) 104 logger.warning('Creating %s Server \'%s\' in group \'%s\'...', engine_name, server_name, resource_group_name) 105 logger.warning('Your server \'%s\' is using sku \'%s\' (Paid Tier). ' 106 'Please refer to %s for pricing details', server_name, sku_name, pricing_link) 107 parameters = mysql.models.ServerForCreate( 108 sku=mysql.models.Sku(name=sku_name), 109 properties=mysql.models.ServerPropertiesForDefaultCreate( 110 administrator_login=administrator_login, 111 administrator_login_password=administrator_login_password, 112 version=version, 113 ssl_enforcement=ssl_enforcement, 114 minimal_tls_version=minimal_tls_version, 115 public_network_access=public_network_access, 116 infrastructure_encryption=infrastructure_encryption, 117 storage_profile=mysql.models.StorageProfile( 118 backup_retention_days=backup_retention, 119 geo_redundant_backup=geo_redundant_backup, 120 storage_mb=storage_mb, 121 storage_autogrow=auto_grow)), 122 location=location, 123 tags=tags) 124 if assign_identity: 125 parameters.identity = mysql.models.ResourceIdentity(type=mysql.models.IdentityType.system_assigned.value) 126 elif provider == 'Microsoft.DBforMariaDB': 127 engine_name = 'mariadb' 128 pricing_link = 'https://aka.ms/mariadb-pricing' 129 location, resource_group_name, server_name = generate_missing_parameters(cmd, location, resource_group_name, 130 server_name, engine_name) 131 check_name_client = cf_mariadb_check_resource_availability_sterling(cmd.cli_ctx, None) 132 name_availability_resquest = mariadb.models.NameAvailabilityRequest(name=server_name, type="Microsoft.DBforMariaDB") 133 check_server_name_availability(check_name_client, name_availability_resquest) 134 logger.warning('Creating %s Server \'%s\' in group \'%s\'...', engine_name, server_name, resource_group_name) 135 logger.warning('Your server \'%s\' is using sku \'%s\' (Paid Tier). ' 136 'Please refer to %s for pricing details', server_name, sku_name, pricing_link) 137 parameters = mariadb.models.ServerForCreate( 138 sku=mariadb.models.Sku(name=sku_name), 139 properties=mariadb.models.ServerPropertiesForDefaultCreate( 140 administrator_login=administrator_login, 141 administrator_login_password=administrator_login_password, 142 version=version, 143 ssl_enforcement=ssl_enforcement, 144 public_network_access=public_network_access, 145 storage_profile=mariadb.models.StorageProfile( 146 backup_retention_days=backup_retention, 147 geo_redundant_backup=geo_redundant_backup, 148 storage_mb=storage_mb, 149 storage_autogrow=auto_grow)), 150 location=location, 151 tags=tags) 152 153 server_result = resolve_poller( 154 client.begin_create(resource_group_name, server_name, parameters), cmd.cli_ctx, 155 '{} Server Create'.format(engine_name)) 156 user = server_result.administrator_login 157 version = server_result.version 158 host = server_result.fully_qualified_domain_name 159 160 # Adding firewall rule 161 if public_network_access is not None and start_ip != '': 162 firewall_id = create_firewall_rule(cmd, resource_group_name, server_name, start_ip, end_ip, engine_name) 163 164 logger.warning('Make a note of your password. If you forget, you would have to ' 165 'reset your password with \'az %s server update -n %s -g %s -p <new-password>\'.', 166 engine_name, server_name, resource_group_name) 167 168 update_local_contexts(cmd, provider, server_name, resource_group_name, location, user) 169 170 if engine_name == 'postgres': 171 return form_response(server_result, administrator_login_password if administrator_login_password is not None else '*****', 172 host=host, 173 connection_string=create_postgresql_connection_string(server_name, host, user, administrator_login_password), 174 database_name=None, firewall_id=firewall_id) 175 # Serves both - MySQL and MariaDB 176 # Create mysql database if it does not exist 177 database_name = DEFAULT_DB_NAME 178 create_database(cmd, resource_group_name, server_name, database_name, engine_name) 179 return form_response(server_result, administrator_login_password if administrator_login_password is not None else '*****', 180 host=host, 181 connection_string=create_mysql_connection_string(server_name, host, database_name, user, administrator_login_password), 182 database_name=database_name, firewall_id=firewall_id) 183 184 185# Need to replace source server name with source server id, so customer server restore function 186# The parameter list should be the same as that in factory to use the ParametersContext 187# arguments and validators 188def _server_restore(cmd, client, resource_group_name, server_name, source_server, restore_point_in_time, no_wait=False): 189 provider = 'Microsoft.DBforPostgreSQL' 190 if isinstance(client, MySqlServersOperations): 191 provider = 'Microsoft.DBforMySQL' 192 elif isinstance(client, MariaDBServersOperations): 193 provider = 'Microsoft.DBforMariaDB' 194 195 parameters = None 196 if not is_valid_resource_id(source_server): 197 if len(source_server.split('/')) == 1: 198 source_server = resource_id( 199 subscription=get_subscription_id(cmd.cli_ctx), 200 resource_group=resource_group_name, 201 namespace=provider, 202 type='servers', 203 name=source_server) 204 else: 205 raise ValueError('The provided source-server {} is invalid.'.format(source_server)) 206 207 if provider == 'Microsoft.DBforMySQL': 208 parameters = mysql.models.ServerForCreate( 209 properties=mysql.models.ServerPropertiesForRestore( 210 source_server_id=source_server, 211 restore_point_in_time=restore_point_in_time), 212 location=None) 213 elif provider == 'Microsoft.DBforPostgreSQL': 214 parameters = postgresql.models.ServerForCreate( 215 properties=postgresql.models.ServerPropertiesForRestore( 216 source_server_id=source_server, 217 restore_point_in_time=restore_point_in_time), 218 location=None) 219 elif provider == 'Microsoft.DBforMariaDB': 220 parameters = mariadb.models.ServerForCreate( 221 properties=mariadb.models.ServerPropertiesForRestore( 222 source_server_id=source_server, 223 restore_point_in_time=restore_point_in_time), 224 location=None) 225 226 parameters.properties.source_server_id = source_server 227 parameters.properties.restore_point_in_time = restore_point_in_time 228 229 # Here is a workaround that we don't support cross-region restore currently, 230 # so the location must be set as the same as source server (not the resource group) 231 id_parts = parse_resource_id(source_server) 232 try: 233 source_server_object = client.get(id_parts['resource_group'], id_parts['name']) 234 parameters.location = source_server_object.location 235 except Exception as e: 236 raise ValueError('Unable to get source server: {}.'.format(str(e))) 237 238 return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters) 239 240 241# need to replace source server name with source server id, so customer server georestore function 242# The parameter list should be the same as that in factory to use the ParametersContext 243# auguments and validators 244def _server_georestore(cmd, client, resource_group_name, server_name, sku_name, location, source_server, 245 backup_retention=None, geo_redundant_backup=None, no_wait=False, **kwargs): 246 provider = 'Microsoft.DBforPostgreSQL' 247 if isinstance(client, MySqlServersOperations): 248 provider = 'Microsoft.DBforMySQL' 249 elif isinstance(client, MariaDBServersOperations): 250 provider = 'Microsoft.DBforMariaDB' 251 252 parameters = None 253 254 if not is_valid_resource_id(source_server): 255 if len(source_server.split('/')) == 1: 256 source_server = resource_id(subscription=get_subscription_id(cmd.cli_ctx), 257 resource_group=resource_group_name, 258 namespace=provider, 259 type='servers', 260 name=source_server) 261 else: 262 raise ValueError('The provided source-server {} is invalid.'.format(source_server)) 263 264 if provider == 'Microsoft.DBforMySQL': 265 parameters = mysql.models.ServerForCreate( 266 sku=mysql.models.Sku(name=sku_name), 267 properties=mysql.models.ServerPropertiesForGeoRestore( 268 storage_profile=mysql.models.StorageProfile( 269 backup_retention_days=backup_retention, 270 geo_redundant_backup=geo_redundant_backup), 271 source_server_id=source_server), 272 location=location) 273 elif provider == 'Microsoft.DBforPostgreSQL': 274 parameters = postgresql.models.ServerForCreate( 275 sku=postgresql.models.Sku(name=sku_name), 276 properties=postgresql.models.ServerPropertiesForGeoRestore( 277 storage_profile=postgresql.models.StorageProfile( 278 backup_retention_days=backup_retention, 279 geo_redundant_backup=geo_redundant_backup), 280 source_server_id=source_server), 281 location=location) 282 elif provider == 'Microsoft.DBforMariaDB': 283 parameters = mariadb.models.ServerForCreate( 284 sku=mariadb.models.Sku(name=sku_name), 285 properties=mariadb.models.ServerPropertiesForGeoRestore( 286 storage_profile=mariadb.models.StorageProfile( 287 backup_retention_days=backup_retention, 288 geo_redundant_backup=geo_redundant_backup), 289 source_server_id=source_server), 290 location=location) 291 292 parameters.properties.source_server_id = source_server 293 294 source_server_id_parts = parse_resource_id(source_server) 295 try: 296 source_server_object = client.get(source_server_id_parts['resource_group'], source_server_id_parts['name']) 297 if parameters.sku.name is None: 298 parameters.sku.name = source_server_object.sku.name 299 except Exception as e: 300 raise ValueError('Unable to get source server: {}.'.format(str(e))) 301 302 return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters) 303 304 305# Custom functions for server replica, will add PostgreSQL part after backend ready in future 306def _replica_create(cmd, client, resource_group_name, server_name, source_server, no_wait=False, location=None, sku_name=None, **kwargs): 307 provider = 'Microsoft.DBforPostgreSQL' 308 if isinstance(client, MySqlServersOperations): 309 provider = 'Microsoft.DBforMySQL' 310 elif isinstance(client, MariaDBServersOperations): 311 provider = 'Microsoft.DBforMariaDB' 312 # set source server id 313 if not is_valid_resource_id(source_server): 314 if len(source_server.split('/')) == 1: 315 source_server = resource_id(subscription=get_subscription_id(cmd.cli_ctx), 316 resource_group=resource_group_name, 317 namespace=provider, 318 type='servers', 319 name=source_server) 320 else: 321 raise CLIError('The provided source-server {} is invalid.'.format(source_server)) 322 323 source_server_id_parts = parse_resource_id(source_server) 324 try: 325 source_server_object = client.get(source_server_id_parts['resource_group'], source_server_id_parts['name']) 326 except CloudError as e: 327 raise CLIError('Unable to get source server: {}.'.format(str(e))) 328 329 if location is None: 330 location = source_server_object.location 331 332 if sku_name is None: 333 sku_name = source_server_object.sku.name 334 335 parameters = None 336 if provider == 'Microsoft.DBforMySQL': 337 parameters = mysql.models.ServerForCreate( 338 sku=mysql.models.Sku(name=sku_name), 339 properties=mysql.models.ServerPropertiesForReplica(source_server_id=source_server), 340 location=location) 341 elif provider == 'Microsoft.DBforPostgreSQL': 342 parameters = postgresql.models.ServerForCreate( 343 sku=postgresql.models.Sku(name=sku_name), 344 properties=postgresql.models.ServerPropertiesForReplica(source_server_id=source_server), 345 location=location) 346 elif provider == 'Microsoft.DBforMariaDB': 347 parameters = mariadb.models.ServerForCreate( 348 sku=mariadb.models.Sku(name=sku_name), 349 properties=mariadb.models.ServerPropertiesForReplica(source_server_id=source_server), 350 location=location) 351 352 return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters) 353 354 355def _replica_stop(client, resource_group_name, server_name): 356 try: 357 server_object = client.get(resource_group_name, server_name) 358 except Exception as e: 359 raise CLIError('Unable to get server: {}.'.format(str(e))) 360 361 if server_object.replication_role.lower() != "replica": 362 raise CLIError('Server {} is not a replica server.'.format(server_name)) 363 364 server_module_path = server_object.__module__ 365 module = import_module(server_module_path.replace('server', 'server_update_parameters')) 366 ServerUpdateParameters = getattr(module, 'ServerUpdateParameters') 367 368 params = ServerUpdateParameters(replication_role='None') 369 370 return client.begin_update(resource_group_name, server_name, params) 371 372 373def _server_update_custom_func(instance, 374 sku_name=None, 375 storage_mb=None, 376 backup_retention=None, 377 administrator_login_password=None, 378 ssl_enforcement=None, 379 tags=None, 380 auto_grow=None, 381 assign_identity=False, 382 public_network_access=None, 383 minimal_tls_version=None): 384 server_module_path = instance.__module__ 385 module = import_module(server_module_path.replace('server', 'server_update_parameters')) 386 ServerUpdateParameters = getattr(module, 'ServerUpdateParameters') 387 388 if sku_name: 389 instance.sku.name = sku_name 390 instance.sku.capacity = None 391 instance.sku.family = None 392 instance.sku.tier = None 393 else: 394 instance.sku = None 395 396 if storage_mb: 397 instance.storage_profile.storage_mb = storage_mb 398 399 if backup_retention: 400 instance.storage_profile.backup_retention_days = backup_retention 401 402 if auto_grow: 403 instance.storage_profile.storage_autogrow = auto_grow 404 405 params = ServerUpdateParameters(sku=instance.sku, 406 storage_profile=instance.storage_profile, 407 administrator_login_password=administrator_login_password, 408 version=None, 409 ssl_enforcement=ssl_enforcement, 410 tags=tags, 411 public_network_access=public_network_access, 412 minimal_tls_version=minimal_tls_version) 413 414 if assign_identity: 415 if server_module_path.find('postgres'): 416 if instance.identity is None: 417 instance.identity = postgresql.models.ResourceIdentity(type=postgresql.models.IdentityType.system_assigned.value) 418 params.identity = instance.identity 419 elif server_module_path.find('mysql'): 420 if instance.identity is None: 421 instance.identity = mysql.models.ResourceIdentity(type=mysql.models.IdentityType.system_assigned.value) 422 params.identity = instance.identity 423 424 return params 425 426 427def _server_mysql_upgrade(cmd, client, resource_group_name, server_name, target_server_version): 428 parameters = mysql.models.ServerUpgradeParameters( 429 target_server_version=target_server_version 430 ) 431 432 client.begin_upgrade(resource_group_name, server_name, parameters) 433 434 435def _server_mariadb_get(cmd, resource_group_name, server_name): 436 client = get_mariadb_management_client(cmd.cli_ctx) 437 return client.servers.get(resource_group_name, server_name) 438 439 440def _server_mysql_get(cmd, resource_group_name, server_name): 441 client = get_mysql_management_client(cmd.cli_ctx) 442 return client.servers.get(resource_group_name, server_name) 443 444 445def _server_stop(cmd, client, resource_group_name, server_name): 446 logger.warning("Server will be automatically started after 7 days " 447 "if you do not perform a manual start operation") 448 return client.begin_stop(resource_group_name, server_name) 449 450 451def _server_postgresql_get(cmd, resource_group_name, server_name): 452 client = get_postgresql_management_client(cmd.cli_ctx) 453 return client.servers.get(resource_group_name, server_name) 454 455 456def _server_update_get(client, resource_group_name, server_name): 457 return client.get(resource_group_name, server_name) 458 459 460def _server_update_set(client, resource_group_name, server_name, parameters): 461 return client.begin_update(resource_group_name, server_name, parameters) 462 463 464def _server_delete(cmd, client, resource_group_name, server_name): 465 database_engine = 'postgres' 466 if isinstance(client, MySqlServersOperations): 467 database_engine = 'mysql' 468 469 result = client.begin_delete(resource_group_name, server_name) 470 471 if cmd.cli_ctx.local_context.is_on: 472 local_context_file = cmd.cli_ctx.local_context._get_local_context_file() # pylint: disable=protected-access 473 local_context_file.remove_option('{}'.format(database_engine), 'server_name') 474 475 return result.result() 476 477 478def _get_sku_name(tier, family, capacity): 479 return '{}_{}_{}'.format(SKU_TIER_MAP[tier], family, str(capacity)) 480 481 482def _firewall_rule_create(client, resource_group_name, server_name, firewall_rule_name, start_ip_address, end_ip_address): 483 484 parameters = {'name': firewall_rule_name, 'start_ip_address': start_ip_address, 'end_ip_address': end_ip_address} 485 486 return client.begin_create_or_update(resource_group_name, server_name, firewall_rule_name, parameters) 487 488 489def _firewall_rule_custom_getter(client, resource_group_name, server_name, firewall_rule_name): 490 return client.get(resource_group_name, server_name, firewall_rule_name) 491 492 493def _firewall_rule_custom_setter(client, resource_group_name, server_name, firewall_rule_name, parameters): 494 return client.begin_create_or_update( 495 resource_group_name, 496 server_name, 497 firewall_rule_name, 498 parameters) 499 500 501def _firewall_rule_update_custom_func(instance, start_ip_address=None, end_ip_address=None): 502 if start_ip_address is not None: 503 instance.start_ip_address = start_ip_address 504 if end_ip_address is not None: 505 instance.end_ip_address = end_ip_address 506 return instance 507 508 509def _vnet_rule_create(client, resource_group_name, server_name, virtual_network_rule_name, virtual_network_subnet_id, ignore_missing_vnet_service_endpoint=None): 510 511 parameters = { 512 'name': virtual_network_rule_name, 513 'virtual_network_subnet_id': virtual_network_subnet_id, 514 'ignore_missing_vnet_service_endpoint': ignore_missing_vnet_service_endpoint 515 } 516 517 return client.begin_create_or_update(resource_group_name, server_name, virtual_network_rule_name, parameters) 518 519 520def _custom_vnet_update_getter(client, resource_group_name, server_name, virtual_network_rule_name): 521 return client.get(resource_group_name, server_name, virtual_network_rule_name) 522 523 524def _custom_vnet_update_setter(client, resource_group_name, server_name, virtual_network_rule_name, parameters): 525 return client.begin_create_or_update( 526 resource_group_name, 527 server_name, 528 virtual_network_rule_name, 529 parameters) 530 531 532def _vnet_rule_update_custom_func(instance, virtual_network_subnet_id, ignore_missing_vnet_service_endpoint=None): 533 534 instance.virtual_network_subnet_id = virtual_network_subnet_id 535 if ignore_missing_vnet_service_endpoint is not None: 536 instance.ignore_missing_vnet_service_endpoint = ignore_missing_vnet_service_endpoint 537 return instance 538 539 540def _configuration_update(client, resource_group_name, server_name, configuration_name, value=None, source=None): 541 542 parameters = { 543 'name': configuration_name, 544 'value': value, 545 'source': source 546 } 547 548 return client.begin_create_or_update(resource_group_name, server_name, configuration_name, parameters) 549 550 551def _db_create(client, resource_group_name, server_name, database_name, charset=None, collation=None): 552 553 parameters = { 554 'name': database_name, 555 'charset': charset, 556 'collation': collation 557 } 558 559 return client.begin_create_or_update(resource_group_name, server_name, database_name, parameters) 560 561 562# Custom functions for server logs 563def _download_log_files( 564 client, 565 resource_group_name, 566 server_name, 567 file_name): 568 569 # list all files 570 files = client.list_by_server(resource_group_name, server_name) 571 572 for f in files: 573 if f.name in file_name: 574 urlretrieve(f.url, f.name) 575 576 577def _list_log_files_with_filter(client, resource_group_name, server_name, filename_contains=None, 578 file_last_written=None, max_file_size=None): 579 580 # list all files 581 all_files = client.list_by_server(resource_group_name, server_name) 582 files = [] 583 584 if file_last_written is None: 585 file_last_written = 72 586 time_line = datetime.utcnow().replace(tzinfo=tzutc()) - timedelta(hours=file_last_written) 587 588 for f in all_files: 589 if f.last_modified_time < time_line: 590 continue 591 if filename_contains is not None and re.search(filename_contains, f.name) is None: 592 continue 593 if max_file_size is not None and f.size_in_kb > max_file_size: 594 continue 595 596 del f.created_time 597 files.append(f) 598 599 return files 600 601 602# Custom functions for list servers 603def _server_list_custom_func(client, resource_group_name=None): 604 if resource_group_name: 605 return client.list_by_resource_group(resource_group_name) 606 return client.list() 607 608 609# region private_endpoint 610def _update_private_endpoint_connection_status(cmd, client, resource_group_name, server_name, 611 private_endpoint_connection_name, is_approved=True, description=None): # pylint: disable=unused-argument 612 private_endpoint_connection = client.get(resource_group_name=resource_group_name, server_name=server_name, 613 private_endpoint_connection_name=private_endpoint_connection_name) 614 new_status = 'Approved' if is_approved else 'Rejected' 615 616 private_link_service_connection_state = { 617 'status': new_status, 618 'description': description 619 } 620 621 private_endpoint_connection.private_link_service_connection_state = private_link_service_connection_state 622 623 return client.begin_create_or_update(resource_group_name=resource_group_name, 624 server_name=server_name, 625 private_endpoint_connection_name=private_endpoint_connection_name, 626 parameters=private_endpoint_connection) 627 628 629def approve_private_endpoint_connection(cmd, client, resource_group_name, server_name, private_endpoint_connection_name, 630 description=None): 631 """Approve a private endpoint connection request for a server.""" 632 633 return _update_private_endpoint_connection_status( 634 cmd, client, resource_group_name, server_name, private_endpoint_connection_name, is_approved=True, 635 description=description) 636 637 638def reject_private_endpoint_connection(cmd, client, resource_group_name, server_name, private_endpoint_connection_name, 639 description=None): 640 """Reject a private endpoint connection request for a server.""" 641 642 return _update_private_endpoint_connection_status( 643 cmd, client, resource_group_name, server_name, private_endpoint_connection_name, is_approved=False, 644 description=description) 645 646 647def server_key_create(client, resource_group_name, server_name, kid): 648 649 """Create Server Key.""" 650 651 key_name = _get_server_key_name_from_uri(kid) 652 653 parameters = { 654 'uri': kid, 655 'server_key_type': "AzureKeyVault" 656 } 657 658 return client.begin_create_or_update(server_name, key_name, resource_group_name, parameters) 659 660 661def server_key_get(client, resource_group_name, server_name, kid): 662 663 """Get Server Key.""" 664 665 key_name = _get_server_key_name_from_uri(kid) 666 667 return client.get( 668 resource_group_name=resource_group_name, 669 server_name=server_name, 670 key_name=key_name) 671 672 673def server_key_delete(cmd, client, resource_group_name, server_name, kid): 674 675 """Drop Server Key.""" 676 key_name = _get_server_key_name_from_uri(kid) 677 678 return client.begin_delete( 679 resource_group_name=resource_group_name, 680 server_name=server_name, 681 key_name=key_name) 682 683 684def _get_server_key_name_from_uri(uri): 685 ''' 686 Gets the key's name to use as a server key. 687 688 The SQL server key API requires that the server key has a specific name 689 based on the vault, key and key version. 690 ''' 691 692 match = re.match(r'https://(.)+\.(managedhsm.azure.net|managedhsm-preview.azure.net|vault.azure.net|vault-int.azure-int.net|vault.azure.cn|managedhsm.azure.cn|vault.usgovcloudapi.net|managedhsm.usgovcloudapi.net|vault.microsoftazure.de|managedhsm.microsoftazure.de|vault.cloudapi.eaglex.ic.gov|vault.cloudapi.microsoft.scloud)(:443)?\/keys/[^\/]+\/[0-9a-zA-Z]+$', uri) 693 694 if match is None: 695 raise CLIError('The provided uri is invalid. Please provide a valid Azure Key Vault key id. For example: ' 696 '"https://YourVaultName.vault.azure.net/keys/YourKeyName/01234567890123456789012345678901" or "https://YourManagedHsmRegion.YourManagedHsmName.managedhsm.azure.net/keys/YourKeyName/01234567890123456789012345678901"') 697 698 vault = uri.split('.')[0].split('/')[-1] 699 key = uri.split('/')[-2] 700 version = uri.split('/')[-1] 701 return '{}_{}_{}'.format(vault, key, version) 702 703 704def server_ad_admin_set(client, resource_group_name, server_name, login=None, sid=None): 705 ''' 706 Sets a server's AD admin. 707 ''' 708 709 parameters = { 710 'login': login, 711 'sid': sid, 712 'tenant_id': _get_tenant_id() 713 } 714 715 return client.begin_create_or_update( 716 server_name=server_name, 717 resource_group_name=resource_group_name, 718 properties=parameters) 719 720 721def _get_tenant_id(): 722 ''' 723 Gets tenantId from current subscription. 724 ''' 725 profile = Profile() 726 sub = profile.get_subscription() 727 return sub['tenantId'] 728# endregion 729 730 731# region new create experience 732def create_mysql_connection_string(server_name, host, database_name, user_name, password): 733 connection_kwargs = { 734 'host': host, 735 'dbname': database_name, 736 'username': user_name, 737 'servername': server_name, 738 'password': password if password is not None else '{password}' 739 } 740 return 'mysql {dbname} --host {host} --user {username}@{servername} --password={password}'.format(**connection_kwargs) 741 742 743def create_database(cmd, resource_group_name, server_name, database_name, engine_name): 744 if engine_name == 'mysql': 745 # check for existing database, create if not present 746 database_client = cf_mysql_db(cmd.cli_ctx, None) 747 elif engine_name == 'mariadb': 748 database_client = cf_mariadb_db(cmd.cli_ctx, None) 749 parameters = { 750 'name': database_name, 751 'charset': 'utf8' 752 } 753 try: 754 database_client.get(resource_group_name, server_name, database_name) 755 except ResourceNotFoundError: 756 logger.warning('Creating %s database \'%s\'...', engine_name, database_name) 757 database_client.begin_create_or_update(resource_group_name, server_name, database_name, parameters) 758 759 760def form_response(server_result, password, host, connection_string, database_name=None, firewall_id=None): 761 result = todict(server_result) 762 result['connectionString'] = connection_string 763 result['password'] = password 764 if firewall_id is not None: 765 result['firewallName'] = firewall_id 766 if database_name is not None: 767 result['databaseName'] = database_name 768 return result 769 770 771def create_postgresql_connection_string(server_name, host, user, password): 772 connection_kwargs = { 773 'user': user, 774 'host': host, 775 'servername': server_name, 776 'password': password if password is not None else '{password}' 777 } 778 return 'postgres://{user}%40{servername}:{password}@{host}/postgres?sslmode=require'.format(**connection_kwargs) 779 780 781def check_server_name_availability(check_name_client, parameters): 782 server_availability = check_name_client.execute(parameters) 783 if not server_availability.name_available: 784 raise CLIError("The server name '{}' already exists.Please re-run command with some " 785 "other server name.".format(parameters.name)) 786 return True 787 788 789def update_local_contexts(cmd, provider, server_name, resource_group_name, location, user): 790 engine = 'postgres' 791 if provider == 'Microsoft.DBforMySQL': 792 engine = 'mysql' 793 elif provider == 'Microsoft.DBforMariaDB': 794 engine = 'mariadb' 795 796 if cmd.cli_ctx.local_context.is_on: 797 cmd.cli_ctx.local_context.set([engine], 'server_name', 798 server_name) # Setting the server name in the local context 799 cmd.cli_ctx.local_context.set([engine], 'administrator_login', 800 user) # Setting the server name in the local context 801 cmd.cli_ctx.local_context.set([ALL], 'location', 802 location) # Setting the location in the local context 803 cmd.cli_ctx.local_context.set([ALL], 'resource_group_name', resource_group_name) 804 805 806def get_connection_string(cmd, client, server_name='{server}', database_name='{database}', administrator_login='{username}', administrator_login_password='{password}'): 807 provider = 'MySQL' 808 if isinstance(client, PostgreSQLLocationOperations): 809 provider = 'PostgreSQL' 810 elif isinstance(client, MariaDBLocationOperations): 811 provider = 'MariaDB' 812 813 if provider == 'MySQL': 814 server_endpoint = cmd.cli_ctx.cloud.suffixes.mysql_server_endpoint 815 host = '{}{}'.format(server_name, server_endpoint) 816 result = { 817 'mysql_cmd': "mysql {database} --host {host} --user {user}@{server} --password={password}", 818 'ado.net': "Server={host}; Port=3306; Database={database}; Uid={user}@{server}; Pwd={password}", 819 'jdbc': "jdbc:mysql://{host}:3306/{database}?user={user}@{server}&password={password}", 820 'node.js': "var conn = mysql.createConnection({{host: '{host}', user: '{user}@{server}'," 821 "password: {password}, database: {database}, port: 3306}});", 822 'php': "host={host} port=3306 dbname={database} user={user}@{server} password={password}", 823 'python': "cnx = mysql.connector.connect(user='{user}@{server}', password='{password}', host='{host}', " 824 "port=3306, database='{database}')", 825 'ruby': "client = Mysql2::Client.new(username: '{user}@{server}', password: '{password}', " 826 "database: '{database}', host: '{host}', port: 3306)" 827 } 828 829 connection_kwargs = { 830 'host': host, 831 'user': administrator_login, 832 'password': administrator_login_password if administrator_login_password is not None else '{password}', 833 'database': database_name, 834 'server': server_name 835 } 836 837 for k, v in result.items(): 838 result[k] = v.format(**connection_kwargs) 839 840 if provider == 'PostgreSQL': 841 server_endpoint = cmd.cli_ctx.cloud.suffixes.postgresql_server_endpoint 842 host = '{}{}'.format(server_name, server_endpoint) 843 result = { 844 'psql_cmd': "postgresql://{user}@{server}:{password}@{host}/{database}?sslmode=require", 845 'C++ (libpq)': "host={host} port=5432 dbname={database} user={user}@{server} password={password} sslmode=require", 846 'ado.net': "Server={host};Database={database};Port=5432;User Id={user}@{server};Password={password};", 847 'jdbc': "jdbc:postgresql://{host}:5432/{database}?user={user}@{server}&password={password}", 848 'node.js': "var client = new pg.Client('postgres://{user}@{server}:{password}@{host}:5432/{database}');", 849 'php': "host={host} port=5432 dbname={database} user={user}@{server} password={password}", 850 'python': "cnx = psycopg2.connect(database='{database}', user='{user}@{server}', host='{host}', password='{password}', " 851 "port='5432')", 852 'ruby': "cnx = PG::Connection.new(:host => '{host}', :user => '{user}@{server}', :dbname => '{database}', " 853 ":port => '5432', :password => '{password}')" 854 } 855 856 connection_kwargs = { 857 'host': host, 858 'user': administrator_login, 859 'password': administrator_login_password if administrator_login_password is not None else '{password}', 860 'database': database_name, 861 'server': server_name 862 } 863 864 for k, v in result.items(): 865 result[k] = v.format(**connection_kwargs) 866 867 if provider == 'MariaDB': 868 server_endpoint = cmd.cli_ctx.cloud.suffixes.mariadb_server_endpoint 869 host = '{}{}'.format(server_name, server_endpoint) 870 result = { 871 'ado.net': "Server={host}; Port=3306; Database={database}; Uid={user}@{server}; Pwd={password}", 872 'jdbc': "jdbc:mariadb://{host}:3306/{database}?user={user}@{server}&password={password}", 873 'node.js': "var conn = mysql.createConnection({{host: '{host}', user: '{user}@{server}'," 874 "password: {password}, database: {database}, port: 3306}});", 875 'php': "host={host} port=3306 dbname={database} user={user}@{server} password={password}", 876 'python': "cnx = mysql.connector.connect(user='{user}@{server}', password='{password}', host='{host}', " 877 "port=3306, database='{database}')", 878 'ruby': "client = Mysql2::Client.new(username: '{user}@{server}', password: '{password}', " 879 "database: '{database}', host: '{host}', port: 3306)" 880 } 881 882 connection_kwargs = { 883 'host': host, 884 'user': administrator_login, 885 'password': administrator_login_password if administrator_login_password is not None else '{password}', 886 'database': database_name, 887 'server': server_name 888 } 889 890 for k, v in result.items(): 891 result[k] = v.format(**connection_kwargs) 892 893 return { 894 'connectionStrings': result 895 } 896