1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6# pylint: disable=unused-argument, line-too-long
7from datetime import datetime, timedelta
8from importlib import import_module
9import re
10from dateutil.tz import tzutc   # pylint: disable=import-error
11from msrestazure.azure_exceptions import CloudError
12from msrestazure.tools import resource_id, is_valid_resource_id, parse_resource_id  # pylint: disable=import-error
13from knack.log import get_logger
14from knack.util import todict
15from six.moves.urllib.request import urlretrieve  # pylint: disable=import-error
16from azure.core.exceptions import ResourceNotFoundError
17from azure.cli.core._profile import Profile
18from azure.cli.core.commands.client_factory import get_subscription_id
19from azure.cli.core.util import CLIError, sdk_no_wait
20from azure.cli.core.local_context import ALL
21from azure.mgmt.rdbms import postgresql, mysql, mariadb
22from azure.mgmt.rdbms.mysql.operations._servers_operations import ServersOperations as MySqlServersOperations
23from azure.mgmt.rdbms.postgresql.operations._location_based_performance_tier_operations import LocationBasedPerformanceTierOperations as PostgreSQLLocationOperations
24from azure.mgmt.rdbms.mariadb.operations._servers_operations import ServersOperations as MariaDBServersOperations
25from azure.mgmt.rdbms.mariadb.operations._location_based_performance_tier_operations import LocationBasedPerformanceTierOperations as MariaDBLocationOperations
26from ._client_factory import get_mariadb_management_client, get_mysql_management_client, cf_mysql_db, cf_mariadb_db, \
27    get_postgresql_management_client, cf_postgres_check_resource_availability_sterling, \
28    cf_mysql_check_resource_availability_sterling, cf_mariadb_check_resource_availability_sterling
29from ._flexible_server_util import generate_missing_parameters, generate_password, resolve_poller
30from ._util import parse_public_network_access_input, create_firewall_rule
31
logger = get_logger(__name__)


# Name of the database that _server_create creates on new MySQL and MariaDB servers.
DEFAULT_DB_NAME = 'defaultdb'
# Pricing-tier name -> short tier prefix used by _get_sku_name when composing a SKU name.
SKU_TIER_MAP = {'Basic': 'b', 'GeneralPurpose': 'gp', 'MemoryOptimized': 'mo'}
37
38
39# pylint: disable=too-many-locals, too-many-statements, raise-missing-from
def _server_create(cmd, client, resource_group_name=None, server_name=None, sku_name=None, no_wait=False,
                   location=None, administrator_login=None, administrator_login_password=None, backup_retention=None,
                   geo_redundant_backup=None, ssl_enforcement=None, storage_mb=None, tags=None, version=None, auto_grow='Enabled',
                   assign_identity=False, public_network_access=None, infrastructure_encryption=None, minimal_tls_version=None):
    """Create a single server (PostgreSQL, MySQL or MariaDB) and return a summary of it.

    The engine is chosen from the type of ``client``. Missing location / resource group / server
    name values are generated, the server name is checked for availability, the server is created
    and awaited, an optional firewall rule is added, and for MySQL / MariaDB a default database is
    created.

    :param public_network_access: 'Enabled', 'Disabled', 'all', or an IP / IP range accepted by
        ``parse_public_network_access_input``. Any value other than enabled/disabled also creates a
        firewall rule for that range and is sent to the service as 'Enabled'.
    :param no_wait: accepted for interface compatibility; creation is always awaited here.
    :returns: the result of ``form_response`` for the created server.
    """
    provider = 'Microsoft.DBforPostgreSQL'
    if isinstance(client, MySqlServersOperations):
        provider = 'Microsoft.DBforMySQL'
    elif isinstance(client, MariaDBServersOperations):
        provider = 'Microsoft.DBforMariaDB'

    # Per-engine settings: CLI engine name, pricing page, SDK models module,
    # name-availability client factory and the resource type sent in the name check.
    engine_name, pricing_link, models, check_name_client_factory, name_check_type = {
        'Microsoft.DBforPostgreSQL': ('postgres', 'https://aka.ms/postgres-pricing', postgresql.models,
                                      cf_postgres_check_resource_availability_sterling,
                                      'Microsoft.DBforPostgreSQL/servers'),
        'Microsoft.DBforMySQL': ('mysql', 'https://aka.ms/mysql-pricing', mysql.models,
                                 cf_mysql_check_resource_availability_sterling,
                                 'Microsoft.DBforMySQL/servers'),
        # NOTE(review): unlike the other engines this type has no '/servers' suffix; kept as-is —
        # confirm against the MariaDB CheckNameAvailability API.
        'Microsoft.DBforMariaDB': ('mariadb', 'https://aka.ms/mariadb-pricing', mariadb.models,
                                   cf_mariadb_check_resource_availability_sterling,
                                   'Microsoft.DBforMariaDB'),
    }[provider]

    firewall_id = None
    administrator_login_password = generate_password(administrator_login_password)
    start_ip = end_ip = ''

    if public_network_access is not None:
        network_access = str(public_network_access).lower()
        if network_access not in ('enabled', 'disabled'):
            if network_access == 'all':
                start_ip, end_ip = '0.0.0.0', '255.255.255.255'
            else:
                start_ip, end_ip = parse_public_network_access_input(public_network_access)
            # Any value other than enabled/disabled describes an IP range: the range becomes a
            # firewall rule below and the public network access value sent to the API is 'Enabled'.
            public_network_access = 'Enabled'

    # Fill in location / resource group / server name when not supplied, then verify that the
    # (possibly generated) server name is available.
    location, resource_group_name, server_name = generate_missing_parameters(cmd, location, resource_group_name,
                                                                             server_name, engine_name)
    check_name_client = check_name_client_factory(cmd.cli_ctx, None)
    name_availability_request = models.NameAvailabilityRequest(name=server_name, type=name_check_type)
    check_server_name_availability(check_name_client, name_availability_request)
    logger.warning('Creating %s Server \'%s\' in group \'%s\'...', engine_name, server_name, resource_group_name)
    logger.warning('Your server \'%s\' is using sku \'%s\' (Paid Tier). '
                   'Please refer to %s  for pricing details', server_name, sku_name, pricing_link)

    storage_profile = models.StorageProfile(
        backup_retention_days=backup_retention,
        geo_redundant_backup=geo_redundant_backup,
        storage_mb=storage_mb,
        storage_autogrow=auto_grow)
    if provider == 'Microsoft.DBforMariaDB':
        # The MariaDB request does not carry minimal TLS version, infrastructure encryption or identity.
        properties = models.ServerPropertiesForDefaultCreate(
            administrator_login=administrator_login,
            administrator_login_password=administrator_login_password,
            version=version,
            ssl_enforcement=ssl_enforcement,
            public_network_access=public_network_access,
            storage_profile=storage_profile)
    else:
        properties = models.ServerPropertiesForDefaultCreate(
            administrator_login=administrator_login,
            administrator_login_password=administrator_login_password,
            version=version,
            ssl_enforcement=ssl_enforcement,
            minimal_tls_version=minimal_tls_version,
            public_network_access=public_network_access,
            infrastructure_encryption=infrastructure_encryption,
            storage_profile=storage_profile)
    parameters = models.ServerForCreate(
        sku=models.Sku(name=sku_name),
        properties=properties,
        location=location,
        tags=tags)
    if assign_identity and provider != 'Microsoft.DBforMariaDB':
        parameters.identity = models.ResourceIdentity(type=models.IdentityType.system_assigned.value)

    server_result = resolve_poller(
        client.begin_create(resource_group_name, server_name, parameters), cmd.cli_ctx,
        '{} Server Create'.format(engine_name))
    user = server_result.administrator_login
    version = server_result.version
    host = server_result.fully_qualified_domain_name

    # A firewall rule is only needed when an IP range was supplied above.
    if public_network_access is not None and start_ip != '':
        firewall_id = create_firewall_rule(cmd, resource_group_name, server_name, start_ip, end_ip, engine_name)

    logger.warning('Make a note of your password. If you forget, you would have to '
                   'reset your password with \'az %s server update -n %s -g %s -p <new-password>\'.',
                   engine_name, server_name, resource_group_name)

    update_local_contexts(cmd, provider, server_name, resource_group_name, location, user)

    password_display = administrator_login_password if administrator_login_password is not None else '*****'
    if engine_name == 'postgres':
        return form_response(server_result, password_display,
                             host=host,
                             connection_string=create_postgresql_connection_string(server_name, host, user, administrator_login_password),
                             database_name=None, firewall_id=firewall_id)
    # MySQL and MariaDB: create the default database and include it in the response.
    database_name = DEFAULT_DB_NAME
    create_database(cmd, resource_group_name, server_name, database_name, engine_name)
    return form_response(server_result, password_display,
                         host=host,
                         connection_string=create_mysql_connection_string(server_name, host, database_name, user, administrator_login_password),
                         database_name=database_name, firewall_id=firewall_id)
183
184
# Custom server-restore function: the source server name is replaced with the source server ID.
# The parameter list must match the one in the command factory so that the ParametersContext
# arguments and validators apply.
def _server_restore(cmd, client, resource_group_name, server_name, source_server, restore_point_in_time, no_wait=False):
    """Restore ``source_server`` into a new server ``server_name`` at ``restore_point_in_time``.

    ``source_server`` may be a full resource ID or a bare server name in ``resource_group_name``.
    The new server is placed in the source server's location.

    :raises ValueError: if ``source_server`` is malformed or the source server cannot be retrieved.
    :returns: the value of ``sdk_no_wait`` for ``client.begin_create``.
    """
    provider = 'Microsoft.DBforPostgreSQL'
    models = postgresql.models
    if isinstance(client, MySqlServersOperations):
        provider = 'Microsoft.DBforMySQL'
        models = mysql.models
    elif isinstance(client, MariaDBServersOperations):
        provider = 'Microsoft.DBforMariaDB'
        models = mariadb.models

    if not is_valid_resource_id(source_server):
        if len(source_server.split('/')) == 1:
            # A bare server name: resolve it within the target resource group.
            source_server = resource_id(
                subscription=get_subscription_id(cmd.cli_ctx),
                resource_group=resource_group_name,
                namespace=provider,
                type='servers',
                name=source_server)
        else:
            raise ValueError('The provided source-server {} is invalid.'.format(source_server))

    parameters = models.ServerForCreate(
        properties=models.ServerPropertiesForRestore(
            source_server_id=source_server,
            restore_point_in_time=restore_point_in_time),
        location=None)

    # Workaround: cross-region restore is not supported, so the location must be the
    # source server's location (not the resource group's).
    id_parts = parse_resource_id(source_server)
    try:
        source_server_object = client.get(id_parts['resource_group'], id_parts['name'])
    except Exception as e:
        raise ValueError('Unable to get source server: {}.'.format(str(e))) from e
    parameters.location = source_server_object.location

    return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters)
239
240
# Custom server geo-restore function: the source server name is replaced with the source server ID.
# The parameter list must match the one in the command factory so that the ParametersContext
# arguments and validators apply.
def _server_georestore(cmd, client, resource_group_name, server_name, sku_name, location, source_server,
                       backup_retention=None, geo_redundant_backup=None, no_wait=False, **kwargs):
    """Geo-restore ``source_server`` into a new server ``server_name`` in ``location``.

    ``source_server`` may be a full resource ID or a bare server name in ``resource_group_name``.
    When ``sku_name`` is None, the source server's SKU name is reused.

    :raises ValueError: if ``source_server`` is malformed or the source server cannot be retrieved.
    :returns: the value of ``sdk_no_wait`` for ``client.begin_create``.
    """
    provider = 'Microsoft.DBforPostgreSQL'
    models = postgresql.models
    if isinstance(client, MySqlServersOperations):
        provider = 'Microsoft.DBforMySQL'
        models = mysql.models
    elif isinstance(client, MariaDBServersOperations):
        provider = 'Microsoft.DBforMariaDB'
        models = mariadb.models

    if not is_valid_resource_id(source_server):
        if len(source_server.split('/')) == 1:
            # A bare server name: resolve it within the target resource group.
            source_server = resource_id(subscription=get_subscription_id(cmd.cli_ctx),
                                        resource_group=resource_group_name,
                                        namespace=provider,
                                        type='servers',
                                        name=source_server)
        else:
            raise ValueError('The provided source-server {} is invalid.'.format(source_server))

    parameters = models.ServerForCreate(
        sku=models.Sku(name=sku_name),
        properties=models.ServerPropertiesForGeoRestore(
            storage_profile=models.StorageProfile(
                backup_retention_days=backup_retention,
                geo_redundant_backup=geo_redundant_backup),
            source_server_id=source_server),
        location=location)

    # The source server is always fetched, which also validates that it exists.
    source_server_id_parts = parse_resource_id(source_server)
    try:
        source_server_object = client.get(source_server_id_parts['resource_group'], source_server_id_parts['name'])
    except Exception as e:
        raise ValueError('Unable to get source server: {}.'.format(str(e))) from e

    if parameters.sku.name is None:
        parameters.sku.name = source_server_object.sku.name

    return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters)
303
304
# Custom functions for server replicas (PostgreSQL, MySQL and MariaDB).
def _replica_create(cmd, client, resource_group_name, server_name, source_server, no_wait=False, location=None, sku_name=None, **kwargs):
    """Create ``server_name`` as a read replica of ``source_server``.

    ``source_server`` may be a full resource ID or a bare server name in ``resource_group_name``.
    ``location`` and ``sku_name`` default to the source server's values.

    :raises CLIError: if ``source_server`` is malformed or the source server cannot be retrieved.
    :returns: the value of ``sdk_no_wait`` for ``client.begin_create``.
    """
    provider = 'Microsoft.DBforPostgreSQL'
    if isinstance(client, MySqlServersOperations):
        provider = 'Microsoft.DBforMySQL'
    elif isinstance(client, MariaDBServersOperations):
        provider = 'Microsoft.DBforMariaDB'

    # Normalize the source server to a full resource ID.
    if not is_valid_resource_id(source_server):
        if len(source_server.split('/')) == 1:
            source_server = resource_id(subscription=get_subscription_id(cmd.cli_ctx),
                                        resource_group=resource_group_name,
                                        namespace=provider,
                                        type='servers',
                                        name=source_server)
        else:
            raise CLIError('The provided source-server {} is invalid.'.format(source_server))

    source_server_id_parts = parse_resource_id(source_server)
    try:
        source_server_object = client.get(source_server_id_parts['resource_group'], source_server_id_parts['name'])
    except (CloudError, ResourceNotFoundError) as e:
        # These clients use azure-core style LRO methods (begin_*), whose ``get`` raises
        # azure.core exceptions such as ResourceNotFoundError rather than msrestazure's CloudError.
        raise CLIError('Unable to get source server: {}.'.format(str(e))) from e

    if location is None:
        location = source_server_object.location
    if sku_name is None:
        sku_name = source_server_object.sku.name

    models = {
        'Microsoft.DBforMySQL': mysql.models,
        'Microsoft.DBforPostgreSQL': postgresql.models,
        'Microsoft.DBforMariaDB': mariadb.models,
    }[provider]
    parameters = models.ServerForCreate(
        sku=models.Sku(name=sku_name),
        properties=models.ServerPropertiesForReplica(source_server_id=source_server),
        location=location)

    return sdk_no_wait(no_wait, client.begin_create, resource_group_name, server_name, parameters)
353
354
def _replica_stop(client, resource_group_name, server_name):
    """Stop replication on a replica by updating its replication role to 'None'.

    :raises CLIError: if the server cannot be retrieved or is not a replica.
    :returns: the result of ``client.begin_update``.
    """
    try:
        server_object = client.get(resource_group_name, server_name)
    except Exception as e:
        raise CLIError('Unable to get server: {}.'.format(str(e))) from e

    # replication_role may be unset (None); treat that like any other non-replica role
    # instead of failing with AttributeError on .lower().
    if (server_object.replication_role or '').lower() != "replica":
        raise CLIError('Server {} is not a replica server.'.format(server_name))

    # Load ServerUpdateParameters from the models module that defined the returned server's class.
    server_module_path = server_object.__module__
    module = import_module(server_module_path.replace('server', 'server_update_parameters'))
    ServerUpdateParameters = getattr(module, 'ServerUpdateParameters')

    params = ServerUpdateParameters(replication_role='None')

    return client.begin_update(resource_group_name, server_name, params)
371
372
def _server_update_custom_func(instance,
                               sku_name=None,
                               storage_mb=None,
                               backup_retention=None,
                               administrator_login_password=None,
                               ssl_enforcement=None,
                               tags=None,
                               auto_grow=None,
                               assign_identity=False,
                               public_network_access=None,
                               minimal_tls_version=None):
    """Build the ServerUpdateParameters for a server update from the current server ``instance``.

    ``instance.sku`` and ``instance.storage_profile`` are modified in place and then copied into
    the returned parameters.

    :param sku_name: new SKU name; when given, the SKU's capacity/family/tier are cleared,
        otherwise ``instance.sku`` is set to None.
    :param assign_identity: attach a system-assigned identity (PostgreSQL and MySQL servers only).
    :returns: a ServerUpdateParameters instance from the same models module as ``instance``.
    """
    server_module_path = instance.__module__
    # Load ServerUpdateParameters from the models module that defined ``instance``'s class.
    module = import_module(server_module_path.replace('server', 'server_update_parameters'))
    ServerUpdateParameters = getattr(module, 'ServerUpdateParameters')

    if sku_name:
        instance.sku.name = sku_name
        # Clear the other SKU fields so only the new name is sent
        # (presumably the service derives them from the name — confirm).
        instance.sku.capacity = None
        instance.sku.family = None
        instance.sku.tier = None
    else:
        instance.sku = None

    if storage_mb:
        instance.storage_profile.storage_mb = storage_mb

    if backup_retention:
        instance.storage_profile.backup_retention_days = backup_retention

    if auto_grow:
        instance.storage_profile.storage_autogrow = auto_grow

    params = ServerUpdateParameters(sku=instance.sku,
                                    storage_profile=instance.storage_profile,
                                    administrator_login_password=administrator_login_password,
                                    version=None,
                                    ssl_enforcement=ssl_enforcement,
                                    tags=tags,
                                    public_network_access=public_network_access,
                                    minimal_tls_version=minimal_tls_version)

    if assign_identity:
        # Use substring tests: str.find() returns -1 (truthy) when the text is absent, which
        # would route every engine through the PostgreSQL branch.
        if 'postgres' in server_module_path:
            identity_models = postgresql.models
        elif 'mysql' in server_module_path:
            identity_models = mysql.models
        else:
            identity_models = None  # other engines (e.g. MariaDB): no identity is attached
        if identity_models is not None:
            if instance.identity is None:
                instance.identity = identity_models.ResourceIdentity(
                    type=identity_models.IdentityType.system_assigned.value)
            params.identity = instance.identity

    return params
425
426
def _server_mysql_upgrade(cmd, client, resource_group_name, server_name, target_server_version):
    """Start upgrading a MySQL server to ``target_server_version``.

    Returns None.
    """
    # NOTE(review): the poller from begin_upgrade is discarded, so the caller neither waits for
    # the upgrade nor sees its outcome — confirm this is intended.
    upgrade_request = mysql.models.ServerUpgradeParameters(target_server_version=target_server_version)
    client.begin_upgrade(resource_group_name, server_name, upgrade_request)
433
434
def _server_mariadb_get(cmd, resource_group_name, server_name):
    """Return a MariaDB server, using a MariaDB management client built from ``cmd.cli_ctx``."""
    return get_mariadb_management_client(cmd.cli_ctx).servers.get(resource_group_name, server_name)
438
439
def _server_mysql_get(cmd, resource_group_name, server_name):
    """Return a MySQL server, using a MySQL management client built from ``cmd.cli_ctx``."""
    return get_mysql_management_client(cmd.cli_ctx).servers.get(resource_group_name, server_name)
443
444
def _server_stop(cmd, client, resource_group_name, server_name):
    """Warn about the automatic restart, then begin stopping the server.

    Returns the result of ``client.begin_stop``.
    """
    auto_start_notice = ("Server will be automatically started after 7 days "
                         "if you do not perform a manual start operation")
    logger.warning(auto_start_notice)
    return client.begin_stop(resource_group_name, server_name)
449
450
def _server_postgresql_get(cmd, resource_group_name, server_name):
    """Return a PostgreSQL server, using a PostgreSQL management client built from ``cmd.cli_ctx``."""
    return get_postgresql_management_client(cmd.cli_ctx).servers.get(resource_group_name, server_name)
454
455
def _server_update_get(client, resource_group_name, server_name):
    """Fetch the current server; paired with _server_update_set for server updates."""
    current_server = client.get(resource_group_name, server_name)
    return current_server
458
459
def _server_update_set(client, resource_group_name, server_name, parameters):
    """Send the prepared update ``parameters``; returns the result of ``client.begin_update``."""
    update_result = client.begin_update(resource_group_name, server_name, parameters)
    return update_result
462
463
def _server_delete(cmd, client, resource_group_name, server_name):
    """Delete a server, wait for completion, then forget it in the local context.

    :returns: the result of the delete poller.
    """
    # Local-context section name for each engine. MariaDB previously fell through to 'postgres',
    # clearing the wrong section.
    # NOTE(review): assumes the MariaDB section is named 'mariadb', like the engine name used in
    # _server_create — confirm against update_local_contexts.
    database_engine = 'postgres'
    if isinstance(client, MySqlServersOperations):
        database_engine = 'mysql'
    elif isinstance(client, MariaDBServersOperations):
        database_engine = 'mariadb'

    result = client.begin_delete(resource_group_name, server_name).result()

    # Only forget the server once the deletion has completed without raising.
    if cmd.cli_ctx.local_context.is_on:
        local_context_file = cmd.cli_ctx.local_context._get_local_context_file()  # pylint: disable=protected-access
        local_context_file.remove_option(database_engine, 'server_name')

    return result
476
477
def _get_sku_name(tier, family, capacity):
    """Compose a SKU name such as 'gp_Gen5_2' from a pricing tier, hardware family and capacity."""
    return '{prefix}_{family}_{capacity}'.format(prefix=SKU_TIER_MAP[tier],
                                                 family=family,
                                                 capacity=str(capacity))
480
481
def _firewall_rule_create(client, resource_group_name, server_name, firewall_rule_name, start_ip_address, end_ip_address):
    """Create or overwrite a firewall rule; returns the result of ``client.begin_create_or_update``."""
    return client.begin_create_or_update(
        resource_group_name, server_name, firewall_rule_name,
        {'name': firewall_rule_name,
         'start_ip_address': start_ip_address,
         'end_ip_address': end_ip_address})
487
488
def _firewall_rule_custom_getter(client, resource_group_name, server_name, firewall_rule_name):
    """Fetch an existing firewall rule (getter for the firewall-rule update command)."""
    rule = client.get(resource_group_name, server_name, firewall_rule_name)
    return rule
491
492
def _firewall_rule_custom_setter(client, resource_group_name, server_name, firewall_rule_name, parameters):
    """Write back an updated firewall rule (setter for the firewall-rule update command)."""
    return client.begin_create_or_update(resource_group_name, server_name, firewall_rule_name, parameters)
499
500
def _firewall_rule_update_custom_func(instance, start_ip_address=None, end_ip_address=None):
    """Apply any supplied start/end IP address to the firewall rule ``instance`` and return it."""
    for attribute, new_value in (('start_ip_address', start_ip_address),
                                 ('end_ip_address', end_ip_address)):
        if new_value is not None:
            setattr(instance, attribute, new_value)
    return instance
507
508
def _vnet_rule_create(client, resource_group_name, server_name, virtual_network_rule_name, virtual_network_subnet_id, ignore_missing_vnet_service_endpoint=None):
    """Create or overwrite a virtual network rule; returns the result of ``client.begin_create_or_update``."""
    return client.begin_create_or_update(
        resource_group_name, server_name, virtual_network_rule_name,
        {'name': virtual_network_rule_name,
         'virtual_network_subnet_id': virtual_network_subnet_id,
         'ignore_missing_vnet_service_endpoint': ignore_missing_vnet_service_endpoint})
518
519
def _custom_vnet_update_getter(client, resource_group_name, server_name, virtual_network_rule_name):
    """Fetch an existing virtual network rule (getter for the vnet-rule update command)."""
    vnet_rule = client.get(resource_group_name, server_name, virtual_network_rule_name)
    return vnet_rule
522
523
def _custom_vnet_update_setter(client, resource_group_name, server_name, virtual_network_rule_name, parameters):
    """Write back an updated virtual network rule (setter for the vnet-rule update command)."""
    return client.begin_create_or_update(resource_group_name, server_name, virtual_network_rule_name, parameters)
530
531
def _vnet_rule_update_custom_func(instance, virtual_network_subnet_id, ignore_missing_vnet_service_endpoint=None):
    """Point the vnet rule ``instance`` at a new subnet and return it.

    ``ignore_missing_vnet_service_endpoint`` is only changed when explicitly supplied.
    """
    if ignore_missing_vnet_service_endpoint is not None:
        instance.ignore_missing_vnet_service_endpoint = ignore_missing_vnet_service_endpoint
    instance.virtual_network_subnet_id = virtual_network_subnet_id
    return instance
538
539
def _configuration_update(client, resource_group_name, server_name, configuration_name, value=None, source=None):
    """Set a server configuration value; returns the result of ``client.begin_create_or_update``."""
    return client.begin_create_or_update(
        resource_group_name, server_name, configuration_name,
        {'name': configuration_name, 'value': value, 'source': source})
549
550
def _db_create(client, resource_group_name, server_name, database_name, charset=None, collation=None):
    """Create or overwrite a database; returns the result of ``client.begin_create_or_update``."""
    return client.begin_create_or_update(
        resource_group_name, server_name, database_name,
        {'name': database_name, 'charset': charset, 'collation': collation})
560
561
562# Custom functions for server logs
def _download_log_files(
        client,
        resource_group_name,
        server_name,
        file_name):
    """Download each server log file whose name is contained in ``file_name``.

    Each file is saved under its own name via ``urlretrieve``.
    """
    requested = (log_file for log_file in client.list_by_server(resource_group_name, server_name)
                 if log_file.name in file_name)
    for log_file in requested:
        urlretrieve(log_file.url, log_file.name)
575
576
def _list_log_files_with_filter(client, resource_group_name, server_name, filename_contains=None,
                                file_last_written=None, max_file_size=None):
    """List a server's log files, filtered by age, name and size.

    :param filename_contains: regular expression searched (``re.search``) in each file name.
    :param file_last_written: only include files modified within this many hours (default 72).
    :param max_file_size: exclude files whose ``size_in_kb`` exceeds this value.
    :returns: list of matching log-file objects, each with its ``created_time`` attribute deleted.
    """
    # list all files
    all_files = client.list_by_server(resource_group_name, server_name)
    files = []

    if file_last_written is None:
        file_last_written = 72
    # Timezone-aware "now" in UTC; datetime.utcnow() is deprecated since Python 3.12.
    time_line = datetime.now(tzutc()) - timedelta(hours=file_last_written)

    for f in all_files:
        if f.last_modified_time < time_line:
            continue
        if filename_contains is not None and re.search(filename_contains, f.name) is None:
            continue
        if max_file_size is not None and f.size_in_kb > max_file_size:
            continue

        # created_time is dropped from the returned objects.
        del f.created_time
        files.append(f)

    return files
600
601
602# Custom functions for list servers
def _server_list_custom_func(client, resource_group_name=None):
    """List servers in ``resource_group_name`` when given, otherwise in the whole subscription scope of ``client``."""
    if not resource_group_name:
        return client.list()
    return client.list_by_resource_group(resource_group_name)
607
608
609# region private_endpoint
def _update_private_endpoint_connection_status(cmd, client, resource_group_name, server_name,
                                               private_endpoint_connection_name, is_approved=True, description=None):  # pylint: disable=unused-argument
    """Set a private endpoint connection to 'Approved' or 'Rejected' and write it back.

    Returns the result of ``client.begin_create_or_update``.
    """
    connection = client.get(resource_group_name=resource_group_name, server_name=server_name,
                            private_endpoint_connection_name=private_endpoint_connection_name)

    connection.private_link_service_connection_state = {
        'status': 'Approved' if is_approved else 'Rejected',
        'description': description
    }

    return client.begin_create_or_update(resource_group_name=resource_group_name,
                                         server_name=server_name,
                                         private_endpoint_connection_name=private_endpoint_connection_name,
                                         parameters=connection)
627
628
def approve_private_endpoint_connection(cmd, client, resource_group_name, server_name, private_endpoint_connection_name,
                                        description=None):
    """Approve a private endpoint connection request for a server.

    Delegates to ``_update_private_endpoint_connection_status`` with ``is_approved=True``
    and returns its result.
    """
    return _update_private_endpoint_connection_status(
        cmd=cmd,
        client=client,
        resource_group_name=resource_group_name,
        server_name=server_name,
        private_endpoint_connection_name=private_endpoint_connection_name,
        is_approved=True,
        description=description)
636
637
def reject_private_endpoint_connection(cmd, client, resource_group_name, server_name, private_endpoint_connection_name,
                                       description=None):
    """Reject a private endpoint connection request for a server.

    Delegates to ``_update_private_endpoint_connection_status`` with ``is_approved=False``
    and returns its result.
    """
    return _update_private_endpoint_connection_status(
        cmd=cmd,
        client=client,
        resource_group_name=resource_group_name,
        server_name=server_name,
        private_endpoint_connection_name=private_endpoint_connection_name,
        is_approved=False,
        description=description)
645
646
def server_key_create(client, resource_group_name, server_name, kid):
    """Create a server key backed by the Azure Key Vault key identified by ``kid``.

    The key name is derived from ``kid`` by ``_get_server_key_name_from_uri``, which raises
    CLIError when ``kid`` is not an accepted key URI.

    :return: the result of ``client.begin_create_or_update``.
    """
    server_key = {'uri': kid, 'server_key_type': "AzureKeyVault"}
    # Arguments are passed positionally as (server_name, key_name, resource_group_name, parameters).
    # NOTE(review): presumably this mirrors the SDK operation's parameter order — confirm before
    # switching to keyword arguments.
    return client.begin_create_or_update(server_name, _get_server_key_name_from_uri(kid),
                                         resource_group_name, server_key)
659
660
def server_key_get(client, resource_group_name, server_name, kid):
    """Fetch the server key whose name is derived from the Key Vault key URI ``kid``."""
    lookup = dict(resource_group_name=resource_group_name,
                  server_name=server_name,
                  key_name=_get_server_key_name_from_uri(kid))
    return client.get(**lookup)
671
672
def server_key_delete(cmd, client, resource_group_name, server_name, kid):
    """Drop the server key whose name is derived from the Key Vault key URI ``kid``.

    ``cmd`` is not used. Returns the result of ``client.begin_delete``.
    """
    target = dict(resource_group_name=resource_group_name,
                  server_name=server_name,
                  key_name=_get_server_key_name_from_uri(kid))
    return client.begin_delete(**target)
682
683
684def _get_server_key_name_from_uri(uri):
685    '''
686    Gets the key's name to use as a server key.
687
688    The SQL server key API requires that the server key has a specific name
689    based on the vault, key and key version.
690    '''
691
692    match = re.match(r'https://(.)+\.(managedhsm.azure.net|managedhsm-preview.azure.net|vault.azure.net|vault-int.azure-int.net|vault.azure.cn|managedhsm.azure.cn|vault.usgovcloudapi.net|managedhsm.usgovcloudapi.net|vault.microsoftazure.de|managedhsm.microsoftazure.de|vault.cloudapi.eaglex.ic.gov|vault.cloudapi.microsoft.scloud)(:443)?\/keys/[^\/]+\/[0-9a-zA-Z]+$', uri)
693
694    if match is None:
695        raise CLIError('The provided uri is invalid. Please provide a valid Azure Key Vault key id. For example: '
696                       '"https://YourVaultName.vault.azure.net/keys/YourKeyName/01234567890123456789012345678901" or "https://YourManagedHsmRegion.YourManagedHsmName.managedhsm.azure.net/keys/YourKeyName/01234567890123456789012345678901"')
697
698    vault = uri.split('.')[0].split('/')[-1]
699    key = uri.split('/')[-2]
700    version = uri.split('/')[-1]
701    return '{}_{}_{}'.format(vault, key, version)
702
703
def server_ad_admin_set(client, resource_group_name, server_name, login=None, sid=None):
    '''
    Sets a server's AD admin.

    The admin is described by ``login`` and ``sid``; the tenant id is taken from
    the current CLI subscription via ``_get_tenant_id``.

    :return: the result of ``client.begin_create_or_update``.
    '''
    admin_properties = dict(login=login, sid=sid, tenant_id=_get_tenant_id())
    return client.begin_create_or_update(
        server_name=server_name,
        resource_group_name=resource_group_name,
        properties=admin_properties)
719
720
721def _get_tenant_id():
722    '''
723    Gets tenantId from current subscription.
724    '''
725    profile = Profile()
726    sub = profile.get_subscription()
727    return sub['tenantId']
728# endregion
729
730
731# region new create experience
def create_mysql_connection_string(server_name, host, database_name, user_name, password):
    """Build a ``mysql`` command line for connecting to the server.

    The user is rendered as ``<user_name>@<server_name>``; a ``password`` of None is
    rendered as the literal placeholder ``{password}``.
    """
    shown_password = '{password}' if password is None else password
    return 'mysql {dbname} --host {host} --user {username}@{servername} --password={password}'.format(
        dbname=database_name, host=host, username=user_name, servername=server_name, password=shown_password)
741
742
def create_database(cmd, resource_group_name, server_name, database_name, engine_name):
    """Create a utf8 database on a MySQL or MariaDB server unless it already exists.

    :param cmd: CLI command object; its ``cli_ctx`` is used to build the database client.
    :param engine_name: ``'mysql'`` or ``'mariadb'``.
    :raises CLIError: if ``engine_name`` is neither ``'mysql'`` nor ``'mariadb'``.
    """
    if engine_name == 'mysql':
        # check for existing database, create if not present
        database_client = cf_mysql_db(cmd.cli_ctx, None)
    elif engine_name == 'mariadb':
        database_client = cf_mariadb_db(cmd.cli_ctx, None)
    else:
        # Without this branch an unsupported engine left database_client unbound and
        # failed below with an UnboundLocalError instead of a clear CLI error.
        raise CLIError("Creating a database is not supported for engine '{}'; "
                       "expected 'mysql' or 'mariadb'.".format(engine_name))
    parameters = {
        'name': database_name,
        'charset': 'utf8'
    }
    try:
        database_client.get(resource_group_name, server_name, database_name)
    except ResourceNotFoundError:
        logger.warning('Creating %s database \'%s\'...', engine_name, database_name)
        # The return value of begin_create_or_update is discarded, so this function does not
        # wait on it. NOTE(review): confirm that not waiting for completion is intended.
        database_client.begin_create_or_update(resource_group_name, server_name, database_name, parameters)
758
759
def form_response(server_result, password, host, connection_string, database_name=None, firewall_id=None):
    """Convert ``server_result`` to a dict and attach the create-command extras.

    ``connectionString`` and ``password`` are always set; ``firewallName`` and
    ``databaseName`` are added only when the matching argument is not None.
    ``host`` is accepted for interface compatibility but not used.
    """
    response = todict(server_result)
    response.update(connectionString=connection_string, password=password)
    for key, value in (('firewallName', firewall_id), ('databaseName', database_name)):
        if value is not None:
            response[key] = value
    return response
769
770
def create_postgresql_connection_string(server_name, host, user, password):
    """Build a ``postgres://`` URI targeting the server's ``postgres`` database over SSL.

    The ``@`` between user and server name is written percent-encoded (``%40``); a
    ``password`` of None is rendered as the literal placeholder ``{password}``.
    """
    secret = password if password is not None else '{password}'
    template = 'postgres://{user}%40{servername}:{password}@{host}/postgres?sslmode=require'
    return template.format(user=user, servername=server_name, password=secret, host=host)
779
780
def check_server_name_availability(check_name_client, parameters):
    """Ensure the requested server name is available.

    :param check_name_client: client whose ``execute(parameters)`` returns an object with ``name_available``.
    :param parameters: name-availability request; its ``name`` is quoted in the error message.
    :return: True when the name is available.
    :raises CLIError: when the service reports the name as unavailable.
    """
    server_availability = check_name_client.execute(parameters)
    if not server_availability.name_available:
        # Message previously read "already exists.Please" (missing space).
        raise CLIError("The server name '{}' already exists. Please re-run command with some "
                       "other server name.".format(parameters.name))
    return True
787
788
def update_local_contexts(cmd, provider, server_name, resource_group_name, location, user):
    """Record the new server's details in the CLI local context when it is enabled.

    ``server_name`` and ``administrator_login`` are stored under the engine scope
    ('mysql' for Microsoft.DBforMySQL, 'mariadb' for Microsoft.DBforMariaDB,
    'postgres' for any other provider); ``location`` and ``resource_group_name``
    are stored under the shared ``ALL`` scope.
    """
    local_context = cmd.cli_ctx.local_context
    if not local_context.is_on:
        return

    engine_by_provider = {
        'Microsoft.DBforMySQL': 'mysql',
        'Microsoft.DBforMariaDB': 'mariadb',
    }
    engine = engine_by_provider.get(provider, 'postgres')

    local_context.set([engine], 'server_name', server_name)
    local_context.set([engine], 'administrator_login', user)  # admin login, engine-scoped
    local_context.set([ALL], 'location', location)
    local_context.set([ALL], 'resource_group_name', resource_group_name)
804
805
def get_connection_string(cmd, client, server_name='{server}', database_name='{database}', administrator_login='{username}', administrator_login_password='{password}'):
    """Return sample connection strings for a MySQL, PostgreSQL or MariaDB server.

    The engine is chosen from the type of ``client``: PostgreSQL for
    ``PostgreSQLLocationOperations``, MariaDB for ``MariaDBLocationOperations``,
    MySQL otherwise. The host is ``server_name`` followed by the cloud's endpoint
    suffix for that engine. Arguments left at their defaults stay as ``{placeholder}``
    text; a password of None is also rendered as ``{password}``.

    :return: ``{'connectionStrings': {<client/driver name>: <connection string>}}``.
    """
    suffixes = cmd.cli_ctx.cloud.suffixes
    if isinstance(client, PostgreSQLLocationOperations):
        server_endpoint = suffixes.postgresql_server_endpoint
        # In URIs the '@' inside the "<user>@<server>" login must be percent-encoded;
        # libpq otherwise ends the userinfo at the first '@' and mis-parses the host.
        templates = {
            'psql_cmd': "postgresql://{user}%40{server}:{password}@{host}/{database}?sslmode=require",
            'C++ (libpq)': "host={host} port=5432 dbname={database} user={user}@{server} password={password} sslmode=require",
            'ado.net': "Server={host};Database={database};Port=5432;User Id={user}@{server};Password={password};",
            'jdbc': "jdbc:postgresql://{host}:5432/{database}?user={user}@{server}&password={password}",
            'node.js': "var client = new pg.Client('postgres://{user}%40{server}:{password}@{host}:5432/{database}');",
            'php': "host={host} port=5432 dbname={database} user={user}@{server} password={password}",
            'python': "cnx = psycopg2.connect(database='{database}', user='{user}@{server}', host='{host}', password='{password}', "
                      "port='5432')",
            'ruby': "cnx = PG::Connection.new(:host => '{host}', :user => '{user}@{server}', :dbname => '{database}', "
                    ":port => '5432', :password => '{password}')"
        }
    elif isinstance(client, MariaDBLocationOperations):
        server_endpoint = suffixes.mariadb_server_endpoint
        templates = {
            'ado.net': "Server={host}; Port=3306; Database={database}; Uid={user}@{server}; Pwd={password}",
            'jdbc': "jdbc:mariadb://{host}:3306/{database}?user={user}@{server}&password={password}",
            # password/database are quoted so the generated snippet is valid JavaScript
            'node.js': "var conn = mysql.createConnection({{host: '{host}', user: '{user}@{server}', "
                       "password: '{password}', database: '{database}', port: 3306}});",
            'php': "host={host} port=3306 dbname={database} user={user}@{server} password={password}",
            'python': "cnx = mysql.connector.connect(user='{user}@{server}', password='{password}', host='{host}', "
                      "port=3306, database='{database}')",
            'ruby': "client = Mysql2::Client.new(username: '{user}@{server}', password: '{password}', "
                    "database: '{database}', host: '{host}', port: 3306)"
        }
    else:
        server_endpoint = suffixes.mysql_server_endpoint
        templates = {
            'mysql_cmd': "mysql {database} --host {host} --user {user}@{server} --password={password}",
            'ado.net': "Server={host}; Port=3306; Database={database}; Uid={user}@{server}; Pwd={password}",
            'jdbc': "jdbc:mysql://{host}:3306/{database}?user={user}@{server}&password={password}",
            # password/database are quoted so the generated snippet is valid JavaScript
            'node.js': "var conn = mysql.createConnection({{host: '{host}', user: '{user}@{server}', "
                       "password: '{password}', database: '{database}', port: 3306}});",
            'php': "host={host} port=3306 dbname={database} user={user}@{server} password={password}",
            'python': "cnx = mysql.connector.connect(user='{user}@{server}', password='{password}', host='{host}', "
                      "port=3306, database='{database}')",
            'ruby': "client = Mysql2::Client.new(username: '{user}@{server}', password: '{password}', "
                    "database: '{database}', host: '{host}', port: 3306)"
        }

    connection_kwargs = {
        'host': '{}{}'.format(server_name, server_endpoint),
        'user': administrator_login,
        'password': administrator_login_password if administrator_login_password is not None else '{password}',
        'database': database_name,
        'server': server_name
    }

    return {
        'connectionStrings': {name: template.format(**connection_kwargs) for name, template in templates.items()}
    }
896